-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add tests for transform presets, and various fixes #7223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
6d946cf
c845fd4
3d56ea2
6a67591
e0ae71d
6e0b04e
98df638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import itertools | ||
import re | ||
from collections import defaultdict | ||
|
||
import numpy as np | ||
|
||
|
@@ -1993,7 +1994,8 @@ def test__transform(self, inpt): | |
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) | ||
@pytest.mark.parametrize("label_type", (torch.Tensor, int)) | ||
@pytest.mark.parametrize("dataset_return_type", (dict, tuple)) | ||
def test_classif_preset(image_type, label_type, dataset_return_type): | ||
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) | ||
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): | ||
|
||
image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8)) | ||
if image_type is PIL.Image: | ||
|
@@ -2020,10 +2022,10 @@ def test_classif_preset(image_type, label_type, dataset_return_type): | |
transforms.TrivialAugmentWide(), | ||
transforms.AugMix(), | ||
transforms.AutoAugment(), | ||
transforms.ToImageTensor(), | ||
to_tensor(), | ||
# TODO: ConvertImageDtype is a pass-through on PIL images, is that | ||
# intended? This results in a failure if ToImageTensor() is called | ||
# after it, because the image would still be uint8 which make Normalize | ||
# intended? This results in a failure if we convert to tensor after | ||
# it, because the image would still be uint8 which make Normalize | ||
# fail. | ||
transforms.ConvertImageDtype(torch.float), | ||
transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]), | ||
|
@@ -2043,3 +2045,102 @@ def test_classif_preset(image_type, label_type, dataset_return_type): | |
|
||
assert out_image.shape[-2:] == (224, 224) | ||
assert out_label == label | ||
|
||
|
||
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
    """Smoke-test the detection reference presets.

    Builds the transform pipeline corresponding to each ``data_augmentation``
    preset, runs a synthetic detection sample (image + label + boxes + masks)
    through it, and checks the output image type, the label type, and that the
    number of boxes/masks/labels is preserved.
    """
    if data_augmentation == "hflip":
        t = [
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "lsj":
        t = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: replaced FixedSizeCrop with RandomCrop, because we're
            # leaving FixedSizeCrop in prototype for now, and it expects Label
            # classes which we won't release yet.
            # transforms.FixedSizeCrop(
            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
            # ),
            # TODO: I have to set a very low size for the crop (30, 30),
            # otherwise we'd get an error saying the crop is larger than the
            # image. This means RandomCrop doesn't do the same thing as
            # FixedSizeCrop and we need to figure out the key differences
            transforms.RandomCrop((30, 30)),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "multiscale":
        t = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssd":
        t = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    elif data_augmentation == "ssdlite":
        t = [
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(p=1),
            to_tensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    t = transforms.Compose(t)

    num_boxes = 5
    H = W = 250

    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))
    if label_type is list:
        label = label.tolist()

    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
    # Boxes are generated as XYXY with x1y1 in the top-left half of the image,
    # then clamped so x2y2 stay inside the (H, W) canvas.
    boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.clamp(min=0, max=min(H, W))
    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))

    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    # When the pipeline uses ToTensor and the input wasn't already an Image
    # datapoint, the output is expected to be a plain tensor; in every other
    # combination it should come out as a datapoints.Image.
    if to_tensor is transforms.ToTensor and image_type is not datapoints.Image:
        assert is_simple_tensor(out["image"])
    else:
        assert isinstance(out["image"], datapoints.Image)
    assert isinstance(out["label"], type(sample["label"]))

    # Normalize the label to a tensor so the shape comparison below works for
    # both the list and tensor label_type parametrizations.
    out["label"] = torch.tensor(out["label"])
    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -169,7 +169,8 @@ def _permute_channels( | |
if isinstance(orig_inpt, PIL.Image.Image): | ||
inpt = F.pil_to_tensor(inpt) | ||
|
||
output = inpt[..., permutation, :, :] | ||
# TODO: Find a better fix than as_subclass??? | ||
output = inpt[..., permutation, :, :].as_subclass(type(inpt)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am concerned that this slipped through the cracks: passing an Image would result in a pure Tensor. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We do for the functional part, but this is a transform. We need to be extra careful for any kind of functionality that is implemented directly on the transform rather than in a dispatcher. |
||
|
||
if isinstance(orig_inpt, PIL.Image.Image): | ||
output = F.to_image_pil(output) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test or the one below may not be your preferred test style. I'm happy to come back to this when we have more time, but considering the very pressing timeline with the upcoming release, I kindly request that we focus on the substance (i.e. what the tests are testing + how we fix them + correctness) rather than style / perf / optimizations.