diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 81bfa74acce..79a2b591a59 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs): ) +@pytest.mark.parametrize( + "config", + [config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")], + ids=lambda config: config.legacy_cls.__name__, +) +def test_get_params_alias(config): + assert config.prototype_cls.get_params is config.legacy_cls.get_params + + +@pytest.mark.parametrize( + ("transform_cls", "args_kwargs"), + [ + (prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), + (prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), + (prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), + (prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), + (prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), + ( + prototype_transforms.RandomAffine, + ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), + ), + (prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), + (prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), + (prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), + (prototype_transforms.AutoAugment, ArgsKwargs(5)), + ], +) +def test_get_params_jit(transform_cls, args_kwargs): + args, kwargs = args_kwargs + + torch.jit.script(transform_cls.get_params)(*args, **kwargs) + + @pytest.mark.parametrize( ("config", "args_kwargs"), [ diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index a1fb3846a24..18678a5265a 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -56,10 +56,19 @@ def extra_repr(self) -> str: return ", ".join(extra) - # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation - # to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details. + # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things: + # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on + # the v2 transform. See `__init_subclass__` for details. + # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__` + # for details. _v1_transform_cls: Optional[Type[nn.Module]] = None + def __init_subclass__(cls) -> None: + # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. + # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. + if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): + cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined] + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current # v2 transform instance. It does two things: