Skip to content

Commit decb191

Browse files
authored
[proto] Small optimization for gaussian_blur functional op (#6762)
* Use softmax in _get_gaussian_kernel1d * Revert "Use softmax in _get_gaussian_kernel1d" This reverts commit eb8fba3. * Code update * Relaxed tolerance in consistency tests for GaussianBlur and ElasticTransform * Code review updates * Update test_prototype_transforms_consistency.py
1 parent 149edda commit decb191

File tree

2 files changed

+53
-18
lines changed

2 files changed

+53
-18
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -308,22 +308,28 @@ def __init__(
308308
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
309309
],
310310
),
311-
ConsistencyConfig(
312-
prototype_transforms.ElasticTransform,
313-
legacy_transforms.ElasticTransform,
314-
[
315-
ArgsKwargs(),
316-
ArgsKwargs(alpha=20.0),
317-
ArgsKwargs(alpha=(15.3, 27.2)),
318-
ArgsKwargs(sigma=3.0),
319-
ArgsKwargs(sigma=(2.5, 3.9)),
320-
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
321-
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
322-
ArgsKwargs(fill=1),
323-
],
324-
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
325-
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]),
326-
),
311+
*[
312+
ConsistencyConfig(
313+
prototype_transforms.ElasticTransform,
314+
legacy_transforms.ElasticTransform,
315+
[
316+
ArgsKwargs(),
317+
ArgsKwargs(alpha=20.0),
318+
ArgsKwargs(alpha=(15.3, 27.2)),
319+
ArgsKwargs(sigma=3.0),
320+
ArgsKwargs(sigma=(2.5, 3.9)),
321+
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
322+
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
323+
ArgsKwargs(fill=1),
324+
],
325+
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
326+
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
327+
# We updated gaussian blur kernel generation with a faster and numerically more stable version
328+
# This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
329+
closeness_kwargs=ckw,
330+
)
331+
for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
332+
],
327333
ConsistencyConfig(
328334
prototype_transforms.GaussianBlur,
329335
legacy_transforms.GaussianBlur,
@@ -333,6 +339,7 @@ def __init__(
333339
ArgsKwargs(kernel_size=3, sigma=0.7),
334340
ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
335341
],
342+
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
336343
),
337344
ConsistencyConfig(
338345
prototype_transforms.RandomAffine,
@@ -506,7 +513,6 @@ def check_call_consistency(
506513
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
507514

508515
image_tensor = torch.Tensor(image)
509-
510516
try:
511517
torch.manual_seed(0)
512518
output_legacy_tensor = legacy_transform(image_tensor)

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import math
12
from typing import List, Optional, Union
23

34
import PIL.Image
45
import torch
6+
from torch.nn.functional import conv2d, pad as torch_pad
57
from torchvision.prototype import features
68
from torchvision.transforms import functional_tensor as _FT
79
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
@@ -32,6 +34,22 @@ def normalize(
3234
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
3335

3436

37+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
38+
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
39+
x = torch.linspace(-lim, lim, steps=kernel_size)
40+
kernel1d = torch.softmax(-x.pow_(2), dim=0)
41+
return kernel1d
42+
43+
44+
def _get_gaussian_kernel2d(
45+
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
46+
) -> torch.Tensor:
47+
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
48+
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
49+
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
50+
return kernel2d
51+
52+
3553
def gaussian_blur_image_tensor(
3654
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
3755
) -> torch.Tensor:
@@ -70,7 +88,18 @@ def gaussian_blur_image_tensor(
7088
else:
7189
needs_unsquash = False
7290

73-
output = _FT.gaussian_blur(image, kernel_size, sigma)
91+
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
92+
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
93+
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
94+
95+
image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
96+
97+
# padding = (left, right, top, bottom)
98+
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
99+
output = torch_pad(image, padding, mode="reflect")
100+
output = conv2d(output, kernel, groups=output.shape[-3])
101+
102+
output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
74103

75104
if needs_unsquash:
76105
output = output.reshape(shape)

0 commit comments

Comments
 (0)