From 0460e44bbde4ff31fa2d0bc0b16dd602bba6b956 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 22:46:23 +0100
Subject: [PATCH 1/6] add proper smoke test for prototype transforms

---
 test/test_prototype_transforms.py | 206 +++++++++++++++++++++++++++---
 1 file changed, 186 insertions(+), 20 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 0ed51c44d77..e10fe284983 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,4 +1,5 @@
 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
@@ -26,9 +27,11 @@
     make_video,
     make_videos,
 )
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -92,27 +95,190 @@ def parametrize_from_transforms(*transforms):
     return parametrize(transforms_with_inputs)


+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
+                continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    c, h, w = query_chw(input.values())
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
+
+
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input
+
+
 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
+                    # because we neither know the spatial size nor the device at this point
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
     )
-    def test_common(self, transform, input):
-        transform(input)
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
+        for output_item, input_item in zip(output_flat, input_flat):
+            if check_type(input_item, transform._transformed_types):
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item

     @parametrize(
         [

From e96ab2680cdd52111b7b25ec28aa412f14ff6c80 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 22:48:35 +0100
Subject: [PATCH 2/6] cleanup

---
 test/test_prototype_transforms.py | 27 ---------------------------
 1 file changed, 27 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index e10fe284983..86aa498570f 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -21,7 +21,6 @@
     make_image,
     make_images,
     make_label,
-    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
@@ -69,32 +68,6 @@ def parametrize(transforms_with_inputs):
     )


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
-                continue
-            else:
-                if output is inputs[0]:
-                    continue
-
-            transforms_with_inputs.append((transform, inputs))
-
-    return parametrize(transforms_with_inputs)
-
-
 def auto_augment_adapter(transform, input, device):
     adapted_input = {}
     image_or_video_found = False

From 5c08d7fb2ed564b04fbfb9917c5bdaff83239d18 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 23:01:25 +0100
Subject: [PATCH 3/6] add plain object

---
 test/test_prototype_transforms.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 86aa498570f..fc16305c008 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -229,6 +229,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
             none=None,
             str="str",
             path=pathlib.Path.cwd(),
+            object=object(),
             tensor=torch.empty(5),
             array=np.empty(5),
         )

From dc4c8ee702ac39220339b27a6a5ec5211d5e15d8 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 13 Feb 2023 23:17:47 +0100
Subject: [PATCH 4/6] fix chw query for LinearTransform

---
 test/test_prototype_transforms.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index fc16305c008..b3671e2a2b0 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -85,7 +85,14 @@ def auto_augment_adapter(transform, input, device):


 def linear_transformation_adapter(transform, input, device):
-    c, h, w = query_chw(input.values())
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
     num_elements = c * h * w
     transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
     transform.mean_vector = torch.randn((num_elements,), device=device)
From aff68ed31dc8574dcce0c5012fcb6c8a7218c8da Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 14 Feb 2023 12:12:46 +0100
Subject: [PATCH 5/6] improve check strictness

---
 test/test_prototype_transforms.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b3671e2a2b0..afb1980096a 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -255,8 +255,10 @@ def test_common(self, transform, adapter, container_type, image_or_video, device

         assert output_spec == input_spec

-        for output_item, input_item in zip(output_flat, input_flat):
-            if check_type(input_item, transform._transformed_types):
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
                 assert type(output_item) is type(input_item)
             else:
                 assert output_item is input_item

From ed02db71bc1ca292d40773d5570a81d5d0f4390c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 14 Feb 2023 14:49:41 +0100
Subject: [PATCH 6/6] fix flakiness

---
 test/test_prototype_transforms.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index afb1980096a..7030d2d1b2e 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -145,7 +145,7 @@ class TestSmoke:
             (transforms.RandomVerticalFlip(p=1.0), None),
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
-            (transforms.ScaleJitter((16, 16)), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
             (transforms.ClampBoundingBoxes(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertDtype(), None),
@@ -250,6 +250,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
         input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
         input = tree_unflatten(input_flat, input_spec)

+        torch.manual_seed(0)
         output = transform(input)
         output_flat, output_spec = tree_flatten(output)
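
Not part of the patches above: a minimal sketch of the torch.utils._pytree round trip that the new
smoke test is built on. tree_flatten and tree_unflatten are the same helpers imported in PATCH 1/6;
the sample container and its values are made up purely for illustration.

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    sample = {"image": torch.rand(3, 16, 16), "label": 0, "note": "str"}
    flat, spec = tree_flatten(sample)      # leaves in container order: [tensor, 0, "str"], plus a spec of the dict layout
    restored = tree_unflatten(flat, spec)  # rebuilds a dict with the same keys and leaf order
    assert tree_flatten(restored)[1] == spec

The test uses the same pattern: flatten the heterogeneous input, run the transform, flatten the
output, and require that the structure (spec) is unchanged while each leaf is either transformed to
the same type or passed through untouched.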