Add tests for transform presets, and various fixes #7223


Status: Merged (7 commits, Feb 13, 2023)
Changes from 1 commit
109 changes: 105 additions & 4 deletions test/test_prototype_transforms.py
@@ -1,5 +1,6 @@
import itertools
import re
+from collections import defaultdict

import numpy as np

@@ -1993,7 +1994,8 @@ def test__transform(self, inpt):
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
-def test_classif_preset(image_type, label_type, dataset_return_type):
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
Member Author commented:

This test or the one below may not be your preferred test style. I'm happy to come back to this when we have more time, but considering the very pressing timeline with the upcoming release, I kindly request that we focus on the substance (i.e. what the tests are testing + how we fix them + correctness) rather than style / perf / optimizations.


    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
    if image_type is PIL.Image:
@@ -2020,10 +2022,10 @@ def test_classif_preset(image_type, label_type, dataset_return_type):
            transforms.TrivialAugmentWide(),
            transforms.AugMix(),
            transforms.AutoAugment(),
-           transforms.ToImageTensor(),
+           to_tensor(),
            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
-           # intended? This results in a failure if ToImageTensor() is called
-           # after it, because the image would still be uint8 which make Normalize
+           # intended? This results in a failure if we convert to tensor after
+           # it, because the image would still be uint8 which makes Normalize
            # fail.
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
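
A minimal sketch of the pitfall this TODO describes (untested; it assumes the prototype API exercised in this diff):

```python
import PIL.Image
import torch
from torchvision.prototype import transforms

pil_image = PIL.Image.new("RGB", (8, 8))
# ConvertImageDtype is a pass-through on PIL inputs, so nothing changes here:
out = transforms.ConvertImageDtype(torch.float)(pil_image)
# Converting to a tensor only afterwards therefore still yields uint8...
tensor = transforms.ToImageTensor()(out)
# ...and Normalize, which requires a float tensor, is expected to raise here:
transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])(tensor)
```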
@@ -2043,3 +2045,102 @@ def test_classif_preset(image_type, label_type, dataset_return_type):

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
if data_augmentation == "hflip":
t = [
transforms.RandomHorizontalFlip(p=1),
to_tensor(),
transforms.ConvertImageDtype(torch.float),
]
elif data_augmentation == "lsj":
t = [
transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
# Note: replaced FixedSizeCrop with RandomCrop, becuase we're
# leaving FixedSizeCrop in prototype for now, and it expects Label
# classes which we won't release yet.
# transforms.FixedSizeCrop(
# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
# ),
# TODO: I have to set a very low size for the crop (30, 30),
# otherwise we'd get an error saying the crop is larger than the
# image. This means RandomCrop doesn't do the same thing as
# FixedSizeCrop and we need ot figure out the key differences
transforms.RandomCrop((30, 30)),
Collaborator commented:

RandomCrop doesn't do any padding by default, but FixedSizeCrop does. You probably only need to set pad_if_needed=True.
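
If so, a sketch of that tweak (untested; pad_if_needed is part of RandomCrop's standard signature), which would let the test keep a (1024, 1024) crop:

```python
# Pads inputs smaller than the target size instead of raising, which is
# closer to FixedSizeCrop's behavior than a bare RandomCrop.
transforms.RandomCrop((1024, 1024), pad_if_needed=True)
```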

            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "multiscale":
        t = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssd":
        t = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssdlite":
        t = [
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    t = transforms.Compose(t)

    num_boxes = 5
    H = W = 250

    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))
    if label_type is list:
        label = label.tolist()

    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clamp(min=0, max=min(H, W))
    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))

    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
        assert is_simple_tensor(out["image"])
    else:
        assert isinstance(out["image"], datapoints.Image)
    assert isinstance(out["label"], type(sample["label"]))

    out["label"] = torch.tensor(out["label"])
    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/_color.py
@@ -169,7 +169,8 @@ def _permute_channels(
        if isinstance(orig_inpt, PIL.Image.Image):
            inpt = F.pil_to_tensor(inpt)

-       output = inpt[..., permutation, :, :]
+       # TODO: Find a better fix than as_subclass???
+       output = inpt[..., permutation, :, :].as_subclass(type(inpt))
@NicolasHug (Member, Author) commented on Feb 12, 2023:

I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor.
@pmeier don't we have tests that make sure the types are preserved? Are we ensuring in the tests that all the random transforms are actually tested (i.e. set p=1)?

Collaborator replied:

> don't we have tests that make sure the types are preserved?

We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher.
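
A minimal sketch of the behavior under discussion (untested; it assumes the prototype datapoints import path):

```python
import torch
from torchvision.prototype import datapoints

image = datapoints.Image(torch.rand(3, 4, 4))

# Generic tensor ops unwrap datapoints, so indexing returns a plain Tensor:
permuted = image[..., [2, 0, 1], :, :]
assert type(permuted) is torch.Tensor

# as_subclass re-wraps the result in the original datapoint type, no copy:
rewrapped = permuted.as_subclass(type(image))
assert type(rewrapped) is datapoints.Image
```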


        if isinstance(orig_inpt, PIL.Image.Image):
            output = F.to_image_pil(output)