diff --git a/test/test_transforms.py b/test/test_transforms.py index 8423bf99ee3..f8101a1d862 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -532,7 +532,13 @@ def test_convert_image_dtype_float_to_float(self): for output_dtype in output_dtypes: with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -546,6 +552,7 @@ def test_convert_image_dtype_float_to_int(self): for output_dtype in int_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( input_dtype == torch.float64 and output_dtype == torch.int64 @@ -554,6 +561,10 @@ def test_convert_image_dtype_float_to_int(self): transform(input_image) else: output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, torch.iinfo(output_dtype).max @@ -567,7 +578,13 @@ def test_convert_image_dtype_int_to_float(self): for output_dtype in float_dtypes(): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0.0, 1.0 @@ -586,7 +603,13 @@ def test_convert_image_dtype_int_to_int(self): with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + script_diff = output_image_script - output_image + self.assertTrue(script_diff.abs().max() < 1e-6) actual_min, actual_max = output_image.tolist() desired_min, desired_max = 0, output_max diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index e49ff063dc8..c9609774c24 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -113,6 +113,11 @@ def pil_to_tensor(pic): return img +def _is_floating_point(dtype: torch.dtype) -> bool: + # helper function since torch.dtype.is_floating_point is not scriptable + return isinstance(dtype, (torch.float32, torch.float, torch.float64, torch.double)) + + def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """Convert a tensor image to the given ``dtype`` and scale the values accordingly @@ -137,9 +142,12 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - if image.dtype == dtype: return image - if image.dtype.is_floating_point: + input_is_float = _is_floating_point(image.dtype) + output_is_float = _is_floating_point(dtype) + + if input_is_float: # float to float - if dtype.is_floating_point: + if output_is_float: return image.to(dtype) # float to int @@ -153,7 +161,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) else: # int to float - if dtype.is_floating_point: + if output_is_float: max = torch.iinfo(image.dtype).max image = image.to(dtype) return image / max