diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 7d2f1d735ea..212755068d9 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -308,22 +308,28 @@ def __init__(
             ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.ElasticTransform,
-        legacy_transforms.ElasticTransform,
-        [
-            ArgsKwargs(),
-            ArgsKwargs(alpha=20.0),
-            ArgsKwargs(alpha=(15.3, 27.2)),
-            ArgsKwargs(sigma=3.0),
-            ArgsKwargs(sigma=(2.5, 3.9)),
-            ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
-            ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
-            ArgsKwargs(fill=1),
-        ],
-        # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]),
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.ElasticTransform,
+            legacy_transforms.ElasticTransform,
+            [
+                ArgsKwargs(),
+                ArgsKwargs(alpha=20.0),
+                ArgsKwargs(alpha=(15.3, 27.2)),
+                ArgsKwargs(sigma=3.0),
+                ArgsKwargs(sigma=(2.5, 3.9)),
+                ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
+                ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+                ArgsKwargs(fill=1),
+            ],
+            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
+            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
+            # We updated the gaussian blur kernel generation with a faster and numerically more stable version.
+            # This makes float32 accumulation visible in the elastic transform, so we relax the consistency tolerance.
+            closeness_kwargs=ckw,
+        )
+        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
+    ],
     ConsistencyConfig(
         prototype_transforms.GaussianBlur,
         legacy_transforms.GaussianBlur,
@@ -333,6 +339,7 @@ def __init__(
             ArgsKwargs(kernel_size=3, sigma=0.7),
             ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
         ],
+        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
     ),
     ConsistencyConfig(
         prototype_transforms.RandomAffine,
@@ -506,7 +513,6 @@ def check_call_consistency(
         image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
 
         image_tensor = torch.Tensor(image)
-
         try:
             torch.manual_seed(0)
             output_legacy_tensor = legacy_transform(image_tensor)
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 5b2dd135a60..fa4a6e9be73 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -1,7 +1,9 @@
+import math
 from typing import List, Optional, Union
 
 import PIL.Image
 import torch
+from torch.nn.functional import conv2d, pad as torch_pad
 from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
@@ -32,6 +34,22 @@ def normalize(
     return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
 
 
+def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
+    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
+    x = torch.linspace(-lim, lim, steps=kernel_size)
+    kernel1d = torch.softmax(-x.pow_(2), dim=0)
+    return kernel1d
+
+
+def _get_gaussian_kernel2d(
+    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
+    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
+    kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
+    return kernel2d
+
+
 def gaussian_blur_image_tensor(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
@@ -70,7 +88,18 @@ def gaussian_blur_image_tensor(
     else:
         needs_unsquash = False
 
-    output = _FT.gaussian_blur(image, kernel_size, sigma)
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
+    kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
+
+    image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
+
+    # padding = (left, right, top, bottom)
+    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
+    output = torch_pad(image, padding, mode="reflect")
+    output = conv2d(output, kernel, groups=output.shape[-3])
+
+    output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
 
     if needs_unsquash:
         output = output.reshape(shape)
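
Reviewer note (a sketch, not part of the patch): the new _get_gaussian_kernel1d builds the kernel with torch.softmax instead of the usual exp-then-normalize. The two are mathematically identical, because with the rescaling x' = x / (sqrt(2) * sigma) we get softmax(-x'**2) = exp(-0.5 * (x / sigma)**2) / sum_j exp(-0.5 * (x_j / sigma)**2), which is exactly the normalized Gaussian pdf; the speed and stability benefits claimed in the patch comments plausibly come from softmax's fused, max-subtracted implementation. A minimal equivalence check, assuming kernel1d_classic below mirrors the legacy formulation in torchvision.transforms.functional_tensor:

import math

import torch


def kernel1d_softmax(kernel_size: int, sigma: float) -> torch.Tensor:
    # New formulation from this patch: softmax over -x**2, with x rescaled by sqrt(2) * sigma.
    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
    x = torch.linspace(-lim, lim, steps=kernel_size)
    return torch.softmax(-x.pow(2), dim=0)


def kernel1d_classic(kernel_size: int, sigma: float) -> torch.Tensor:
    # Assumed legacy-style formulation: evaluate the Gaussian pdf, then normalize to sum to 1.
    half = (kernel_size - 1) * 0.5
    x = torch.linspace(-half, half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


# Both paths should agree to floating-point precision for kernel sizes and
# sigmas like the ones exercised by the consistency tests above.
for ks, s in [(3, 0.7), (5, 1.4), (23, 8.0)]:
    torch.testing.assert_close(kernel1d_softmax(ks, s), kernel1d_classic(ks, s))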