introduce heuristic for simple tensor handling of transforms v2 #7170

Merged 13 commits on Feb 8, 2023

106 changes: 92 additions & 14 deletions test/test_prototype_transforms.py
@@ -1,15 +1,16 @@
import itertools
import re

import numpy as np

import PIL.Image

import pytest
import torch

import torchvision.prototype.transforms.utils
from common_utils import assert_equal, cpu_and_gpu
from common_utils import cpu_and_gpu
from prototype_common_utils import (
assert_equal,
DEFAULT_EXTRA_DIMS,
make_bounding_box,
make_bounding_boxes,
@@ -25,7 +26,7 @@
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms.utils import check_type
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input):
transform(input)


@pytest.mark.parametrize(
"flat_inputs",
itertools.permutations(
[
next(make_vanilla_tensor_images()),
next(make_vanilla_tensor_images()),
next(make_pil_images()),
make_image(),
next(make_videos()),
],
3,
),
)
def test_simple_tensor_heuristic(flat_inputs):
def split_on_simple_tensor(to_split):
# This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
# 1. The first simple tensor. If none is present, this will be `None`
# 2. A list of the remaining simple tensors
# 3. A list of all other items
simple_tensors = []
others = []
# Splitting always happens on the original `flat_inputs`, so that any erroneous type changes made by the
# transform cannot affect the splitting.
for item, inpt in zip(to_split, flat_inputs):
(simple_tensors if is_simple_tensor(inpt) else others).append(item)
return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others

class CopyCloneTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()

@staticmethod
def was_applied(output, inpt):
identity = output is inpt
if identity:
return False

# Make sure nothing fishy is going on
assert_equal(output, inpt)
return True

first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)

transform = CopyCloneTransform()
transformed_sample = transform(flat_inputs)

first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)

if first_simple_tensor_input is not None:
if other_inputs:
assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
else:
assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)

for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
assert not transform.was_applied(output, inpt)

for input, output in zip(other_inputs, other_outputs):
assert transform.was_applied(output, input)


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
@@ -1760,17 +1822,17 @@ def test__transform(self, mocker):
[
(
torch.float64,
{torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
{datapoints.Video: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
),
(
{torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
),
],
)
def test_to_dtype(dtype, expected_dtypes):
sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
video=make_video(dtype=torch.int64),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
@@ -1793,22 +1855,27 @@ def test_to_dtype(dtype, expected_dtypes):
assert transformed_value is value


@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_to_dtype_plain_tensor_warning(other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})


@pytest.mark.parametrize(
("dims", "inverse_dims"),
[
(
{torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
{torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
{datapoints.Image: (2, 1, 0), datapoints.Video: None},
{datapoints.Image: (2, 1, 0), datapoints.Video: None},
),
(
{torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
{torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
{datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
{datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
),
],
)
def test_permute_dimensions(dims, inverse_dims):
sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(),
bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
video=make_video(),
@@ -1833,16 +1900,21 @@ def test_permute_dimensions(dims, inverse_dims):
assert transformed_value is value


@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_permute_dimensions_plain_tensor_warning(other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


@pytest.mark.parametrize(
"dims",
[
(-1, -2),
{torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
{datapoints.Image: (1, 2), datapoints.Video: None},
],
)
def test_transpose_dimensions(dims):
sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(),
bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
video=make_video(),
@@ -1868,6 +1940,12 @@ def test_transpose_dimensions(dims):
assert transformed_value is value


@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_transpose_dimensions_plain_tensor_warning(other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


class TestUniformTemporalSubsample:
@pytest.mark.parametrize(
"inpt",
19 changes: 19 additions & 0 deletions torchvision/prototype/transforms/_misc.py
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image
@@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
dtype = _get_defaultdict(dtype)
if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dims = dims

def _transform(
@@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dims = dims

def _transform(
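
For context, a minimal sketch of how the new warning surfaces to users; the same check guards ToDtype, PermuteDimensions, and TransposeDimensions (assuming the prototype namespace from this PR):

import torch
from torchvision.prototype import datapoints, transforms

# Passing values for both `torch.Tensor` and `datapoints.Image` (or `datapoints.Video`)
# emits a UserWarning, because the plain-tensor entry is ignored whenever an image or
# video is present in the input.
transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float64})
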
35 changes: 31 additions & 4 deletions torchvision/prototype/transforms/_transform.py
@@ -7,7 +7,8 @@
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.transforms.utils import check_type
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once


@@ -37,9 +38,35 @@ def forward(self, *inputs: Any) -> Any:

params = self._get_params(flat_inputs)

flat_outputs = [
self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
]
# Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
# 2. If there is no explicit image or video in the sample, only the first simple tensor encountered is
# transformed as an image, while the rest are passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
#
# This heuristic stems from two requirements:
# 1. We need to keep BC for single input simple tensors and treat them as images.
# 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
# return supplemental numerical data as tensors that cannot be transformed as images.
#
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
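#
# As a hypothetical illustration of the two cases (sample structures made up):
#   {"image": datapoints.Image(...), "aux": torch.tensor(...)}
#     -> "aux" is passed through, because an explicit image is present (case 1).
#   {"a": torch.rand(3, 8, 8), "b": torch.rand(3, 8, 8)}
#     -> only "a", the first simple tensor in depth-first order, is transformed (case 2).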
flat_outputs = []
transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True

if not check_type(inpt, self._transformed_types):
needs_transform = False
elif is_simple_tensor(inpt):
if transform_simple_tensor:
transform_simple_tensor = False
else:
needs_transform = False

flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)

return tree_unflatten(flat_outputs, spec)

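
Taken together, a minimal end-to-end sketch of the heuristic's user-facing behavior (hypothetical shapes and transform choice; assuming the prototype namespace from this PR):

import torch
from torchvision.prototype import datapoints, transforms

flip = transforms.RandomHorizontalFlip(p=1.0)

# Case 1: an explicit Image is present, so the plain tensor is passed through untouched.
image = datapoints.Image(torch.rand(3, 8, 8))
aux = torch.arange(4)
flipped_image, untouched_aux = flip(image, aux)
assert untouched_aux is aux

# Case 2: no Image, Video, or PIL image is present, so only the first plain tensor is
# flipped as an image; the second one is passed through.
first, second = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
flipped_first, untouched_second = flip(first, second)
assert untouched_second is second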