Skip to content

make transforms v2 JIT scriptable #7135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,106 @@ def test_call_consistency(config, args_kwargs):
)


@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs or [None])
    ],
)
def test_jit_and_get_params_consistency(config, args_kwargs):
    """Check JIT consistency between a prototype (v2) transform and its legacy (v1) counterpart.

    Two properties are verified per config:

    1. If the legacy transform exposes ``get_params``, the prototype transform must
       expose a scriptable ``get_params`` as well (aliased via ``_v1_transform_cls``).
    2. Scripting both transforms and calling them on the same images under the same
       RNG seed must produce closely matching outputs.
    """

    def check_get_params(prototype_cls, legacy_cls):
        # Use the passed-in class rather than reaching back to `config.prototype_cls`,
        # so the helper is self-contained (they are the same object at the call site).
        if not hasattr(prototype_cls, "get_params"):
            raise AssertionError(
                "The legacy transform defines a `get_params`, but the prototype does not. "
                "Did you forget to set the `_v1_transform_cls` class attribute on the prototype transform?"
            )

        try:
            torch.jit.script(legacy_cls.get_params)
        except Exception as exc:
            # A non-scriptable legacy `get_params` means the test config is wrong,
            # not the prototype transform under test.
            raise pytest.UsageError("The `get_params` method of the legacy transform cannot be scripted!") from exc

        try:
            torch.jit.script(prototype_cls.get_params)
        except Exception as exc:
            raise AssertionError(
                "Can't script the prototype `get_params` method. "
                "This means there is a bug in the automatic aliasing from the corresponding legacy transform."
            ) from exc

    def check_call(prototype_transform_eager, legacy_transform_eager, *, images, closeness_kwargs):
        try:
            legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
        except Exception as exc:
            msg = str(exc)

            # Some of the transform parameter variations cannot be used while scripting. For example, `Resize(size=1)`
            # is not scriptable. One has to use `Resize(size=[1])` instead. To avoid creating a second set of
            # parameters, we just abort the JIT call check in such a case.
            if (
                re.search(
                    r"Expected a value of type 'List\[int\]' for argument '\w+' but instead found type 'int'", msg
                )
                is not None
            ):
                return

            # This error happens when `torch.jit.script` hits an unguarded `_log_api_usage_once`. Since that means that
            # the transform doesn't support scripting in general, we abort the JIT call check.
            if re.search(r"'Any' object has no attribute or method '__module__'", msg) is not None:
                return

            raise pytest.UsageError("The legacy transform cannot be scripted!") from exc

        try:
            prototype_transform_scripted = torch.jit.script(prototype_transform_eager)
        except Exception as exc:
            raise AssertionError("The prototype transform cannot be scripted!") from exc

        for image in images:
            # Strip the datapoint subclass so both transforms see a plain tensor.
            image = image.as_subclass(torch.Tensor)
            image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

            try:
                # Reseed before each call so random transforms draw identical parameters.
                torch.manual_seed(0)
                output_legacy_scripted = legacy_transform_scripted(image)
            except Exception as exc:
                raise AssertionError(
                    f"Calling the scripted legacy transform with {image_repr} raised the error above!"
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_scripted = prototype_transform_scripted(image)
            except Exception as exc:
                raise AssertionError(
                    f"Calling the scripted prototype transform with {image_repr} raised the error above!"
                ) from exc

            assert_close(
                output_prototype_scripted,
                output_legacy_scripted,
                msg=lambda msg: f"JIT runtime consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )

    if hasattr(config.legacy_cls, "get_params"):
        check_get_params(config.prototype_cls, config.legacy_cls)

    if args_kwargs is not None:
        args, kwargs = args_kwargs
        check_call(
            config.prototype_cls(*args, **kwargs),
            config.legacy_cls(*args, **kwargs),
            images=make_images(**config.make_images_kwargs),
            closeness_kwargs=config.closeness_kwargs,
        )


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
Expand Down
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
Expand All @@ -16,6 +16,14 @@


class RandomErasing(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomErasing

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    """Translate v2 constructor parameters to their v1 equivalents.

    The v1 ``RandomErasing`` uses the string ``"random"`` where the v2
    transform uses ``None`` for ``value``, so the sentinel is mapped here.
    """
    extracted = super()._extract_params_for_v1_transform()
    extracted["value"] = "random" if extracted["value"] is None else extracted["value"]
    return extracted

_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

def __init__(
Expand Down
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
Expand Down Expand Up @@ -161,6 +161,8 @@ def _apply_image_or_video_transform(


class AutoAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.AutoAugment

_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down Expand Up @@ -315,6 +317,7 @@ def forward(self, *inputs: Any) -> Any:


class RandAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.RandAugment
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down Expand Up @@ -375,6 +378,7 @@ def forward(self, *inputs: Any) -> Any:


class TrivialAugmentWide(_AutoAugmentBase):
_v1_transform_cls = _transforms.TrivialAugmentWide
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
Expand Down Expand Up @@ -425,6 +429,8 @@ def forward(self, *inputs: Any) -> Any:


class AugMix(_AutoAugmentBase):
_v1_transform_cls = _transforms.AugMix

_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import PIL.Image
import torch

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

Expand All @@ -12,6 +12,8 @@


class Grayscale(Transform):
_v1_transform_cls = _transforms.Grayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -28,6 +30,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomGrayscale(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomGrayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -47,6 +51,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ColorJitter(Transform):
_v1_transform_cls = _transforms.ColorJitter

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    """Translate v2 constructor parameters to their v1 equivalents.

    The v1 ``ColorJitter`` does not accept ``None`` for its jitter factors;
    every unset (falsy) parameter is replaced with ``0``, which disables the
    corresponding jitter.
    """
    translated: Dict[str, Any] = {}
    for name, value in super()._extract_params_for_v1_transform().items():
        translated[name] = value or 0
    return translated

def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
Expand Down Expand Up @@ -194,16 +203,22 @@ def _transform(


class RandomEqualize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomEqualize

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Histogram-equalize the input. `params` is unused: this transform
    # samples no per-call parameters (only the apply/skip coin flip,
    # handled by _RandomApplyTransform).
    return F.equalize(inpt)


class RandomInvert(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomInvert

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Invert the colors of the input. `params` is unused: this transform
    # has no sampled parameters beyond the apply/skip decision made by
    # _RandomApplyTransform.
    return F.invert(inpt)


class RandomPosterize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPosterize

def __init__(self, bits: int, p: float = 0.5) -> None:
    # `p` is the probability of applying the transform (handled by the
    # _RandomApplyTransform base); `bits` is the number of bits to keep
    # per channel when posterizing.
    super().__init__(p=p)
    self.bits = bits
Expand All @@ -213,6 +228,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomSolarize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomSolarize

def __init__(self, threshold: float, p: float = 0.5) -> None:
    # `p` is the probability of applying the transform (handled by the
    # _RandomApplyTransform base); pixels at or above `threshold` are
    # inverted when the transform fires.
    super().__init__(p=p)
    self.threshold = threshold
Expand All @@ -222,11 +239,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomAutocontrast(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAutocontrast

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Autocontrast the input. `params` is unused: the only randomness is
    # the apply/skip decision made by _RandomApplyTransform.
    return F.autocontrast(inpt)


class RandomAdjustSharpness(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAdjustSharpness

def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
    # `p` is the probability of applying the transform (handled by the
    # _RandomApplyTransform base); `sharpness_factor` scales sharpness
    # (1.0 leaves the input unchanged).
    super().__init__(p=p)
    self.sharpness_factor = sharpness_factor
Expand Down
27 changes: 27 additions & 0 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import PIL.Image
import torch

from torchvision import transforms as _transforms
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
Expand All @@ -27,16 +28,22 @@


class RandomHorizontalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomHorizontalFlip

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Flip the input horizontally. `params` is unused: the apply/skip
    # coin flip is handled by _RandomApplyTransform.
    return F.horizontal_flip(inpt)


class RandomVerticalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomVerticalFlip

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Flip the input vertically. `params` is unused: the apply/skip
    # coin flip is handled by _RandomApplyTransform.
    return F.vertical_flip(inpt)


class Resize(Transform):
_v1_transform_cls = _transforms.Resize

def __init__(
self,
size: Union[int, Sequence[int]],
Expand Down Expand Up @@ -66,6 +73,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class CenterCrop(Transform):
_v1_transform_cls = _transforms.CenterCrop

def __init__(self, size: Union[int, Sequence[int]]):
    # `_setup_size` normalizes `size` to an (h, w) pair, accepting a single
    # int (square crop) or a two-element sequence.
    super().__init__()
    self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
Expand All @@ -75,6 +84,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomResizedCrop(Transform):
_v1_transform_cls = _transforms.RandomResizedCrop

def __init__(
self,
size: Union[int, Sequence[int]],
Expand Down Expand Up @@ -171,6 +182,8 @@ class FiveCrop(Transform):
torch.Size([5])
"""

_v1_transform_cls = _transforms.FiveCrop

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -197,6 +210,8 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""

_v1_transform_cls = _transforms.TenCrop

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -220,6 +235,8 @@ def _transform(


class Pad(Transform):
_v1_transform_cls = _transforms.Pad

def __init__(
self,
padding: Union[int, Sequence[int]],
Expand Down Expand Up @@ -282,6 +299,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomRotation(Transform):
_v1_transform_cls = _transforms.RandomRotation

def __init__(
self,
degrees: Union[numbers.Number, Sequence],
Expand Down Expand Up @@ -319,6 +338,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomAffine(Transform):
_v1_transform_cls = _transforms.RandomAffine

def __init__(
self,
degrees: Union[numbers.Number, Sequence],
Expand Down Expand Up @@ -396,6 +417,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomCrop(Transform):
_v1_transform_cls = _transforms.RandomCrop

def __init__(
self,
size: Union[int, Sequence[int]],
Expand Down Expand Up @@ -488,6 +511,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomPerspective(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPerspective

def __init__(
self,
distortion_scale: float = 0.5,
Expand Down Expand Up @@ -547,6 +572,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ElasticTransform(Transform):
_v1_transform_cls = _transforms.ElasticTransform

def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/transforms/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

Expand All @@ -27,6 +28,8 @@ def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> da


class ConvertDtype(Transform):
_v1_transform_cls = _transforms.ConvertImageDtype

_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
Expand Down
Loading