diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml index daff383b097..d3a9fbf1e3e 100644 --- a/.github/workflows/prototype-tests.yml +++ b/.github/workflows/prototype-tests.yml @@ -43,14 +43,14 @@ jobs: id: setup run: exit 0 - - name: Run prototype features tests + - name: Run prototype datapoints tests shell: bash run: | pytest \ --durations=20 \ - --cov=torchvision/prototype/features \ + --cov=torchvision/prototype/datapoints \ --cov-report=term-missing \ - test/test_prototype_features*.py + test/test_prototype_datapoints*.py - name: Run prototype transforms tests if: success() || ( failure() && steps.setup.conclusion == 'success' ) diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 61cf065e4d4..18664eb0945 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -15,7 +15,7 @@ from datasets_utils import combinations_grid from torch.nn.functional import one_hot from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor from torchvision.transforms.functional_tensor import _max_value as get_max_value @@ -238,7 +238,7 @@ def load(self, device): @dataclasses.dataclass class ImageLoader(TensorLoader): - color_space: features.ColorSpace + color_space: datapoints.ColorSpace spatial_size: Tuple[int, int] = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False) @@ -248,10 +248,10 @@ def __post_init__(self): NUM_CHANNELS_MAP = { - features.ColorSpace.GRAY: 1, - features.ColorSpace.GRAY_ALPHA: 2, - features.ColorSpace.RGB: 3, - features.ColorSpace.RGB_ALPHA: 4, + datapoints.ColorSpace.GRAY: 1, + datapoints.ColorSpace.GRAY_ALPHA: 2, + datapoints.ColorSpace.RGB: 3, + datapoints.ColorSpace.RGB_ALPHA: 4, } @@ -265,7 +265,7 @@ def get_num_channels(color_space): def make_image_loader( size="random", *, - color_space=features.ColorSpace.RGB, + color_space=datapoints.ColorSpace.RGB, extra_dims=(), dtype=torch.float32, constant_alpha=True, @@ -276,9 +276,9 @@ def make_image_loader( def fn(shape, dtype, device): max_value = get_max_value(dtype) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) - if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha: + if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha: data[..., -1, :, :] = max_value - return features.Image(data, color_space=color_space) + return datapoints.Image(data, color_space=color_space) return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) @@ -290,10 +290,10 @@ def make_image_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, + datapoints.ColorSpace.GRAY, + datapoints.ColorSpace.GRAY_ALPHA, + datapoints.ColorSpace.RGB, + datapoints.ColorSpace.RGB_ALPHA, ), extra_dims=DEFAULT_EXTRA_DIMS, dtypes=(torch.float32, torch.uint8), @@ -306,7 +306,7 @@ def make_image_loaders( make_images = from_loaders(make_image_loaders) -def make_image_loader_for_interpolation(size="random", *, color_space=features.ColorSpace.RGB, dtype=torch.uint8): +def 
make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8): size = _parse_spatial_size(size) num_channels = get_num_channels(color_space) @@ -318,24 +318,24 @@ def fn(shape, dtype, device): .resize((width, height)) .convert( { - features.ColorSpace.GRAY: "L", - features.ColorSpace.GRAY_ALPHA: "LA", - features.ColorSpace.RGB: "RGB", - features.ColorSpace.RGB_ALPHA: "RGBA", + datapoints.ColorSpace.GRAY: "L", + datapoints.ColorSpace.GRAY_ALPHA: "LA", + datapoints.ColorSpace.RGB: "RGB", + datapoints.ColorSpace.RGB_ALPHA: "RGBA", }[color_space] ) ) image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype) - return features.Image(image_tensor, color_space=color_space) + return datapoints.Image(image_tensor, color_space=color_space) return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) def make_image_loaders_for_interpolation( sizes=((233, 147),), - color_spaces=(features.ColorSpace.RGB,), + color_spaces=(datapoints.ColorSpace.RGB,), dtypes=(torch.uint8,), ): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): @@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation( @dataclasses.dataclass class BoundingBoxLoader(TensorLoader): - format: features.BoundingBoxFormat + format: datapoints.BoundingBoxFormat spatial_size: Tuple[int, int] @@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32): if isinstance(format, str): - format = features.BoundingBoxFormat[format] + format = datapoints.BoundingBoxFormat[format] if format not in { - features.BoundingBoxFormat.XYXY, - features.BoundingBoxFormat.XYWH, - features.BoundingBoxFormat.CXCYWH, + datapoints.BoundingBoxFormat.XYXY, + datapoints.BoundingBoxFormat.XYWH, + datapoints.BoundingBoxFormat.CXCYWH, }: raise pytest.UsageError(f"Can't make bounding box in format {format}") @@ -378,19 +378,19 @@ def fn(shape, dtype, device): raise pytest.UsageError() if any(dim == 0 for dim in extra_dims): - return features.BoundingBox( + return datapoints.BoundingBox( torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size ) height, width = spatial_size - if format == features.BoundingBoxFormat.XYXY: + if format == datapoints.BoundingBoxFormat.XYXY: x1 = torch.randint(0, width // 2, extra_dims) y1 = torch.randint(0, height // 2, extra_dims) x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 parts = (x1, y1, x2, y2) - elif format == features.BoundingBoxFormat.XYWH: + elif format == datapoints.BoundingBoxFormat.XYWH: x = torch.randint(0, width // 2, extra_dims) y = torch.randint(0, height // 2, extra_dims) w = randint_with_tensor_bounds(1, width - x) @@ -403,7 +403,7 @@ def fn(shape, dtype, device): h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) parts = (cx, cy, w, h) - return features.BoundingBox( + return datapoints.BoundingBox( torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size ) @@ -416,7 +416,7 @@ def fn(shape, dtype, device): def make_bounding_box_loaders( *, extra_dims=DEFAULT_EXTRA_DIMS, - formats=tuple(features.BoundingBoxFormat), + formats=tuple(datapoints.BoundingBoxFormat), spatial_size="random", dtypes=(torch.float32, torch.int64), ): @@ -456,7 +456,7 @@ def fn(shape, dtype, device): # The idiom 
`make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123 data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) - return features.Label(data, categories=categories) + return datapoints.Label(data, categories=categories) return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) @@ -480,7 +480,7 @@ def fn(shape, dtype, device): # since `one_hot` only supports int64 label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) data = one_hot(label, num_classes=num_categories).to(dtype) - return features.OneHotLabel(data, categories=categories) + return datapoints.OneHotLabel(data, categories=categories) return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) @@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim def fn(shape, dtype, device): data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device) - return features.Mask(data) + return datapoints.Mask(data) return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype) @@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext def fn(shape, dtype, device): data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device) - return features.Mask(data) + return datapoints.Mask(data) return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype) @@ -583,7 +583,7 @@ class VideoLoader(ImageLoader): def make_video_loader( size="random", *, - color_space=features.ColorSpace.RGB, + color_space=datapoints.ColorSpace.RGB, num_frames="random", extra_dims=(), dtype=torch.uint8, @@ -593,7 +593,7 @@ def make_video_loader( def fn(shape, dtype, device): video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) - return features.Video(video, color_space=color_space) + return datapoints.Video(video, color_space=color_space) return VideoLoader( fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space @@ -607,8 +607,8 @@ def make_video_loaders( *, sizes=DEFAULT_SPATIAL_SIZES, color_spaces=( - features.ColorSpace.GRAY, - features.ColorSpace.RGB, + datapoints.ColorSpace.GRAY, + datapoints.ColorSpace.RGB, ), num_frames=(1, 0, "random"), extra_dims=DEFAULT_EXTRA_DIMS, diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 8a9f5148e2f..b92278fef56 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -4,7 +4,7 @@ import torchvision.prototype.transforms.functional as F from prototype_common_utils import InfoBase, TestMark from prototype_transforms_kernel_infos import KERNEL_INFOS -from torchvision.prototype import features +from torchvision.prototype import datapoints __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] @@ -139,20 +139,20 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.horizontal_flip, kernels={ - features.Image: F.horizontal_flip_image_tensor, - features.Video: F.horizontal_flip_video, - features.BoundingBox: F.horizontal_flip_bounding_box, - features.Mask: F.horizontal_flip_mask, + datapoints.Image: F.horizontal_flip_image_tensor, + datapoints.Video: F.horizontal_flip_video, + 
datapoints.BoundingBox: F.horizontal_flip_bounding_box, + datapoints.Mask: F.horizontal_flip_mask, }, pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"), ), DispatcherInfo( F.resize, kernels={ - features.Image: F.resize_image_tensor, - features.Video: F.resize_video, - features.BoundingBox: F.resize_bounding_box, - features.Mask: F.resize_mask, + datapoints.Image: F.resize_image_tensor, + datapoints.Video: F.resize_video, + datapoints.BoundingBox: F.resize_bounding_box, + datapoints.Mask: F.resize_mask, }, pil_kernel_info=PILKernelInfo(F.resize_image_pil), test_marks=[ @@ -162,10 +162,10 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.affine, kernels={ - features.Image: F.affine_image_tensor, - features.Video: F.affine_video, - features.BoundingBox: F.affine_bounding_box, - features.Mask: F.affine_mask, + datapoints.Image: F.affine_image_tensor, + datapoints.Video: F.affine_video, + datapoints.BoundingBox: F.affine_bounding_box, + datapoints.Mask: F.affine_mask, }, pil_kernel_info=PILKernelInfo(F.affine_image_pil), test_marks=[ @@ -179,20 +179,20 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.vertical_flip, kernels={ - features.Image: F.vertical_flip_image_tensor, - features.Video: F.vertical_flip_video, - features.BoundingBox: F.vertical_flip_bounding_box, - features.Mask: F.vertical_flip_mask, + datapoints.Image: F.vertical_flip_image_tensor, + datapoints.Video: F.vertical_flip_video, + datapoints.BoundingBox: F.vertical_flip_bounding_box, + datapoints.Mask: F.vertical_flip_mask, }, pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"), ), DispatcherInfo( F.rotate, kernels={ - features.Image: F.rotate_image_tensor, - features.Video: F.rotate_video, - features.BoundingBox: F.rotate_bounding_box, - features.Mask: F.rotate_mask, + datapoints.Image: F.rotate_image_tensor, + datapoints.Video: F.rotate_video, + datapoints.BoundingBox: F.rotate_bounding_box, + datapoints.Mask: F.rotate_mask, }, pil_kernel_info=PILKernelInfo(F.rotate_image_pil), test_marks=[ @@ -204,30 +204,30 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.crop, kernels={ - features.Image: F.crop_image_tensor, - features.Video: F.crop_video, - features.BoundingBox: F.crop_bounding_box, - features.Mask: F.crop_mask, + datapoints.Image: F.crop_image_tensor, + datapoints.Video: F.crop_video, + datapoints.BoundingBox: F.crop_bounding_box, + datapoints.Mask: F.crop_mask, }, pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), ), DispatcherInfo( F.resized_crop, kernels={ - features.Image: F.resized_crop_image_tensor, - features.Video: F.resized_crop_video, - features.BoundingBox: F.resized_crop_bounding_box, - features.Mask: F.resized_crop_mask, + datapoints.Image: F.resized_crop_image_tensor, + datapoints.Video: F.resized_crop_video, + datapoints.BoundingBox: F.resized_crop_bounding_box, + datapoints.Mask: F.resized_crop_mask, }, pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil), ), DispatcherInfo( F.pad, kernels={ - features.Image: F.pad_image_tensor, - features.Video: F.pad_video, - features.BoundingBox: F.pad_bounding_box, - features.Mask: F.pad_mask, + datapoints.Image: F.pad_image_tensor, + datapoints.Video: F.pad_video, + datapoints.BoundingBox: F.pad_bounding_box, + datapoints.Mask: F.pad_mask, }, pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), test_marks=[ @@ -251,10 +251,10 @@ def 
fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.perspective, kernels={ - features.Image: F.perspective_image_tensor, - features.Video: F.perspective_video, - features.BoundingBox: F.perspective_bounding_box, - features.Mask: F.perspective_mask, + datapoints.Image: F.perspective_image_tensor, + datapoints.Video: F.perspective_video, + datapoints.BoundingBox: F.perspective_bounding_box, + datapoints.Mask: F.perspective_mask, }, pil_kernel_info=PILKernelInfo(F.perspective_image_pil), test_marks=[ @@ -264,20 +264,20 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.elastic, kernels={ - features.Image: F.elastic_image_tensor, - features.Video: F.elastic_video, - features.BoundingBox: F.elastic_bounding_box, - features.Mask: F.elastic_mask, + datapoints.Image: F.elastic_image_tensor, + datapoints.Video: F.elastic_video, + datapoints.BoundingBox: F.elastic_bounding_box, + datapoints.Mask: F.elastic_mask, }, pil_kernel_info=PILKernelInfo(F.elastic_image_pil), ), DispatcherInfo( F.center_crop, kernels={ - features.Image: F.center_crop_image_tensor, - features.Video: F.center_crop_video, - features.BoundingBox: F.center_crop_bounding_box, - features.Mask: F.center_crop_mask, + datapoints.Image: F.center_crop_image_tensor, + datapoints.Video: F.center_crop_video, + datapoints.BoundingBox: F.center_crop_bounding_box, + datapoints.Mask: F.center_crop_mask, }, pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), test_marks=[ @@ -287,8 +287,8 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.gaussian_blur, kernels={ - features.Image: F.gaussian_blur_image_tensor, - features.Video: F.gaussian_blur_video, + datapoints.Image: F.gaussian_blur_image_tensor, + datapoints.Video: F.gaussian_blur_video, }, pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil), test_marks=[ @@ -299,56 +299,56 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.equalize, kernels={ - features.Image: F.equalize_image_tensor, - features.Video: F.equalize_video, + datapoints.Image: F.equalize_image_tensor, + datapoints.Video: F.equalize_video, }, pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"), ), DispatcherInfo( F.invert, kernels={ - features.Image: F.invert_image_tensor, - features.Video: F.invert_video, + datapoints.Image: F.invert_image_tensor, + datapoints.Video: F.invert_video, }, pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"), ), DispatcherInfo( F.posterize, kernels={ - features.Image: F.posterize_image_tensor, - features.Video: F.posterize_video, + datapoints.Image: F.posterize_image_tensor, + datapoints.Video: F.posterize_video, }, pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"), ), DispatcherInfo( F.solarize, kernels={ - features.Image: F.solarize_image_tensor, - features.Video: F.solarize_video, + datapoints.Image: F.solarize_image_tensor, + datapoints.Video: F.solarize_video, }, pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"), ), DispatcherInfo( F.autocontrast, kernels={ - features.Image: F.autocontrast_image_tensor, - features.Video: F.autocontrast_video, + datapoints.Image: F.autocontrast_image_tensor, + datapoints.Video: F.autocontrast_video, }, pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"), ), DispatcherInfo( F.adjust_sharpness, kernels={ - features.Image: F.adjust_sharpness_image_tensor, - features.Video: F.adjust_sharpness_video, + datapoints.Image: 
F.adjust_sharpness_image_tensor, + datapoints.Video: F.adjust_sharpness_video, }, pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"), ), DispatcherInfo( F.erase, kernels={ - features.Image: F.erase_image_tensor, - features.Video: F.erase_video, + datapoints.Image: F.erase_image_tensor, + datapoints.Video: F.erase_video, }, pil_kernel_info=PILKernelInfo(F.erase_image_pil), test_marks=[ @@ -358,48 +358,48 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.adjust_brightness, kernels={ - features.Image: F.adjust_brightness_image_tensor, - features.Video: F.adjust_brightness_video, + datapoints.Image: F.adjust_brightness_image_tensor, + datapoints.Video: F.adjust_brightness_video, }, pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"), ), DispatcherInfo( F.adjust_contrast, kernels={ - features.Image: F.adjust_contrast_image_tensor, - features.Video: F.adjust_contrast_video, + datapoints.Image: F.adjust_contrast_image_tensor, + datapoints.Video: F.adjust_contrast_video, }, pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"), ), DispatcherInfo( F.adjust_gamma, kernels={ - features.Image: F.adjust_gamma_image_tensor, - features.Video: F.adjust_gamma_video, + datapoints.Image: F.adjust_gamma_image_tensor, + datapoints.Video: F.adjust_gamma_video, }, pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"), ), DispatcherInfo( F.adjust_hue, kernels={ - features.Image: F.adjust_hue_image_tensor, - features.Video: F.adjust_hue_video, + datapoints.Image: F.adjust_hue_image_tensor, + datapoints.Video: F.adjust_hue_video, }, pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"), ), DispatcherInfo( F.adjust_saturation, kernels={ - features.Image: F.adjust_saturation_image_tensor, - features.Video: F.adjust_saturation_video, + datapoints.Image: F.adjust_saturation_image_tensor, + datapoints.Video: F.adjust_saturation_video, }, pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"), ), DispatcherInfo( F.five_crop, kernels={ - features.Image: F.five_crop_image_tensor, - features.Video: F.five_crop_video, + datapoints.Image: F.five_crop_image_tensor, + datapoints.Video: F.five_crop_video, }, pil_kernel_info=PILKernelInfo(F.five_crop_image_pil), test_marks=[ @@ -410,8 +410,8 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.ten_crop, kernels={ - features.Image: F.ten_crop_image_tensor, - features.Video: F.ten_crop_video, + datapoints.Image: F.ten_crop_image_tensor, + datapoints.Video: F.ten_crop_video, }, test_marks=[ xfail_jit_python_scalar_arg("size"), @@ -422,8 +422,8 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.normalize, kernels={ - features.Image: F.normalize_image_tensor, - features.Video: F.normalize_video, + datapoints.Image: F.normalize_image_tensor, + datapoints.Video: F.normalize_video, }, test_marks=[ skip_dispatch_feature, @@ -434,8 +434,8 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.convert_dtype, kernels={ - features.Image: F.convert_dtype_image_tensor, - features.Video: F.convert_dtype_video, + datapoints.Image: F.convert_dtype_image_tensor, + datapoints.Video: F.convert_dtype_video, }, test_marks=[ skip_dispatch_feature, @@ -444,7 +444,7 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.uniform_temporal_subsample, kernels={ 
- features.Video: F.uniform_temporal_subsample_video, + datapoints.Video: F.uniform_temporal_subsample_video, }, test_marks=[ skip_dispatch_feature, diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index b5618c53dcd..8849365ea85 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -26,7 +26,7 @@ TestMark, ) from torch.utils._pytree import tree_map -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding __all__ = ["KernelInfo", "KERNEL_INFOS"] @@ -176,7 +176,7 @@ def reference_inputs_horizontal_flip_image_tensor(): def sample_inputs_horizontal_flip_bounding_box(): for bounding_box_loader in make_bounding_box_loaders( - formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] + formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32] ): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size @@ -258,13 +258,13 @@ def _get_resize_sizes(spatial_size): def sample_inputs_resize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] + sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] ): for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) for image_loader, interpolation in itertools.product( - make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB]), + make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]), [ F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR, @@ -468,7 +468,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): def sample_inputs_affine_image_tensor(): make_affine_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): @@ -499,7 +499,7 @@ def reference_inputs_affine_image_tensor(): def sample_inputs_affine_bounding_box(): for bounding_box_loader, affine_params in itertools.product( - make_bounding_box_loaders(formats=[features.BoundingBoxFormat.XYXY]), _DIVERSE_AFFINE_PARAMS + make_bounding_box_loaders(formats=[datapoints.BoundingBoxFormat.XYXY]), _DIVERSE_AFFINE_PARAMS ): yield ArgsKwargs( bounding_box_loader, @@ -537,7 +537,7 @@ def transform(bbox, affine_matrix_, format_): # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 in_dtype = bbox.dtype bbox_xyxy = F.convert_format_bounding_box( - bbox.float(), old_format=format_, new_format=features.BoundingBoxFormat.XYXY, inplace=True + bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True ) points = np.array( [ @@ -557,7 +557,7 @@ def transform(bbox, affine_matrix_, format_): ], ) out_bbox = F.convert_format_bounding_box( - out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format_, inplace=True + out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True ) return out_bbox.to(dtype=in_dtype) @@ -652,7 +652,7 @@ def sample_inputs_affine_video(): def 
sample_inputs_convert_format_bounding_box(): - formats = list(features.BoundingBoxFormat) + formats = list(datapoints.BoundingBoxFormat) for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) @@ -681,7 +681,7 @@ def reference_inputs_convert_format_bounding_box(): def sample_inputs_convert_color_space_image_tensor(): color_spaces = sorted( - set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value + set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value ) for old_color_space, new_color_space in cycle_over(color_spaces): @@ -697,7 +697,7 @@ def sample_inputs_convert_color_space_image_tensor(): @pil_reference_wrapper def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space): - color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode) + color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode) if color_space_pil != old_color_space: raise pytest.UsageError( f"Converting the tensor image into an PIL image changed the colorspace " @@ -715,7 +715,7 @@ def reference_inputs_convert_color_space_image_tensor(): def sample_inputs_convert_color_space_video(): - color_spaces = [features.ColorSpace.GRAY, features.ColorSpace.RGB] + color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB] for old_color_space, new_color_space in cycle_over(color_spaces): for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]): @@ -754,7 +754,7 @@ def reference_inputs_vertical_flip_image_tensor(): def sample_inputs_vertical_flip_bounding_box(): for bounding_box_loader in make_bounding_box_loaders( - formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] + formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32] ): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size @@ -817,7 +817,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size): def sample_inputs_rotate_image_tensor(): make_rotate_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader in make_rotate_image_loaders(): @@ -899,7 +899,7 @@ def sample_inputs_rotate_video(): def sample_inputs_crop_image_tensor(): for image_loader, params in itertools.product( - make_image_loaders(sizes=[(16, 17)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), [ dict(top=4, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8), @@ -1085,7 +1085,7 @@ def sample_inputs_resized_crop_video(): def sample_inputs_pad_image_tensor(): make_pad_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] + make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] ) for image_loader, padding in itertools.product( @@ -1401,7 +1401,7 @@ def sample_inputs_elastic_video(): def sample_inputs_center_crop_image_tensor(): for image_loader, output_size in itertools.product( 
- make_image_loaders(sizes=[(16, 17)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), [ # valid `output_size` types for which cropping is applied to both dimensions *[5, (4,), (2, 3), [6], [3, 2]], @@ -1488,7 +1488,7 @@ def sample_inputs_center_crop_video(): def sample_inputs_gaussian_blur_image_tensor(): make_gaussian_blur_image_loaders = functools.partial( - make_image_loaders, sizes=[(7, 33)], color_spaces=[features.ColorSpace.RGB] + make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB] ) for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): @@ -1527,7 +1527,7 @@ def sample_inputs_gaussian_blur_video(): def sample_inputs_equalize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader) @@ -1555,7 +1555,7 @@ def make_beta_distributed_image(shape, dtype, device, *, alpha, beta): spatial_size = (256, 256) for dtype, color_space, fn in itertools.product( [torch.uint8], - [features.ColorSpace.GRAY, features.ColorSpace.RGB], + [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB], [ lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.full( @@ -1611,14 +1611,14 @@ def sample_inputs_equalize_video(): def sample_inputs_invert_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader) def reference_inputs_invert_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader) @@ -1651,7 +1651,7 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) @@ -1659,7 +1659,7 @@ def sample_inputs_posterize_image_tensor(): def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _POSTERIZE_BITS, ): @@ -1698,14 +1698,14 @@ def _get_solarize_thresholds(dtype): def sample_inputs_solarize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) def reference_inputs_solarize_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, 
features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): for threshold in _get_solarize_thresholds(image_loader.dtype): yield ArgsKwargs(image_loader, threshold=threshold) @@ -1741,14 +1741,14 @@ def sample_inputs_solarize_video(): def sample_inputs_autocontrast_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader) def reference_inputs_autocontrast_image_tensor(): for image_loader in make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ): yield ArgsKwargs(image_loader) @@ -1785,7 +1785,7 @@ def sample_inputs_autocontrast_video(): def sample_inputs_adjust_sharpness_image_tensor(): for image_loader in make_image_loaders( sizes=["random", (2, 2)], - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), ): yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) @@ -1793,7 +1793,7 @@ def sample_inputs_adjust_sharpness_image_tensor(): def reference_inputs_adjust_sharpness_image_tensor(): for image_loader, sharpness_factor in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_SHARPNESS_FACTORS, ): @@ -1859,7 +1859,7 @@ def sample_inputs_erase_video(): def sample_inputs_adjust_brightness_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) @@ -1867,7 +1867,7 @@ def sample_inputs_adjust_brightness_image_tensor(): def reference_inputs_adjust_brightness_image_tensor(): for image_loader, brightness_factor in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_BRIGHTNESS_FACTORS, ): @@ -1903,7 +1903,7 @@ def sample_inputs_adjust_brightness_video(): def sample_inputs_adjust_contrast_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) @@ -1911,7 +1911,7 @@ def sample_inputs_adjust_contrast_image_tensor(): def reference_inputs_adjust_contrast_image_tensor(): for image_loader, contrast_factor in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), 
_ADJUST_CONTRAST_FACTORS, ): @@ -1953,7 +1953,7 @@ def sample_inputs_adjust_contrast_video(): def sample_inputs_adjust_gamma_image_tensor(): gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) @@ -1961,7 +1961,7 @@ def sample_inputs_adjust_gamma_image_tensor(): def reference_inputs_adjust_gamma_image_tensor(): for image_loader, (gamma, gain) in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_GAMMA_GAMMAS_GAINS, ): @@ -2001,7 +2001,7 @@ def sample_inputs_adjust_gamma_video(): def sample_inputs_adjust_hue_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) @@ -2009,7 +2009,7 @@ def sample_inputs_adjust_hue_image_tensor(): def reference_inputs_adjust_hue_image_tensor(): for image_loader, hue_factor in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_HUE_FACTORS, ): @@ -2047,7 +2047,7 @@ def sample_inputs_adjust_hue_video(): def sample_inputs_adjust_saturation_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) @@ -2055,7 +2055,7 @@ def sample_inputs_adjust_saturation_image_tensor(): def reference_inputs_adjust_saturation_image_tensor(): for image_loader, saturation_factor in itertools.product( make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] + color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] ), _ADJUST_SATURATION_FACTORS, ): @@ -2120,7 +2120,7 @@ def sample_inputs_five_crop_image_tensor(): for size in _FIVE_TEN_CROP_SIZES: for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[features.ColorSpace.RGB], + color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size) @@ -2144,7 +2144,7 @@ def sample_inputs_ten_crop_image_tensor(): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for image_loader in make_image_loaders( sizes=[_get_five_ten_crop_spatial_size(size)], - color_spaces=[features.ColorSpace.RGB], + color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32], ): yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) @@ -2218,7 +2218,7 @@ def wrapper(input_tensor, *other_args, **kwargs): def sample_inputs_normalize_image_tensor(): for image_loader, (mean, std) in itertools.product( - make_image_loaders(sizes=["random"], 
color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), + make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), _NORMALIZE_MEANS_STDS, ): yield ArgsKwargs(image_loader, mean=mean, std=std) @@ -2227,7 +2227,7 @@ def sample_inputs_normalize_image_tensor(): def sample_inputs_normalize_video(): mean, std = _NORMALIZE_MEANS_STDS[0] for video_loader in make_video_loaders( - sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] + sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] ): yield ArgsKwargs(video_loader, mean=mean, std=std) @@ -2260,7 +2260,7 @@ def sample_inputs_convert_dtype_image_tensor(): continue for image_loader in make_image_loaders( - sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype] + sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype] ): yield ArgsKwargs(image_loader, dtype=output_dtype) @@ -2388,7 +2388,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): def reference_inputs_uniform_temporal_subsample_video(): - for video_loader in make_video_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=[10]): + for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]): for num_samples in range(1, video_loader.shape[-4] + 1): yield ArgsKwargs(video_loader, num_samples) diff --git a/test/test_prototype_features.py b/test/test_prototype_datapoints.py similarity index 66% rename from test/test_prototype_features.py rename to test/test_prototype_datapoints.py index d2b0d2e632c..d036b5db1de 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_datapoints.py @@ -1,36 +1,36 @@ import pytest import torch -from torchvision.prototype import features +from torchvision.prototype import datapoints def test_isinstance(): assert isinstance( - features.Label([0, 1, 0], categories=["foo", "bar"]), + datapoints.Label([0, 1, 0], categories=["foo", "bar"]), torch.Tensor, ) def test_wrapping_no_copy(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) assert label.data_ptr() == tensor.data_ptr() def test_to_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) label_to = label.to(torch.int32) - assert type(label_to) is features.Label + assert type(label_to) is datapoints.Label assert label_to.dtype is torch.int32 assert label_to.categories is label.categories def test_to_feature_reference(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32) tensor_to = tensor.to(label) @@ -40,31 +40,31 @@ def test_to_feature_reference(): def test_clone_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) label_clone = label.clone() - assert type(label_clone) is features.Label + assert type(label_clone) is datapoints.Label assert label_clone.data_ptr() != label.data_ptr() assert label_clone.categories is label.categories def 
test_requires_grad__wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.float32) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) assert not label.requires_grad label_requires_grad = label.requires_grad_(True) - assert type(label_requires_grad) is features.Label + assert type(label_requires_grad) is datapoints.Label assert label.requires_grad assert label_requires_grad.requires_grad def test_other_op_no_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 @@ -82,32 +82,32 @@ def test_other_op_no_wrapping(): ) def test_no_tensor_output_op_no_wrapping(op): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) output = op(label) - assert type(output) is not features.Label + assert type(output) is not datapoints.Label def test_inplace_op_no_wrapping(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) output = label.add_(0) assert type(output) is torch.Tensor - assert type(label) is features.Label + assert type(label) is datapoints.Label def test_wrap_like(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label = datapoints.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 - label_new = features.Label.wrap_like(label, output) + label_new = datapoints.Label.wrap_like(label, output) - assert type(label_new) is features.Label + assert type(label_new) is datapoints.Label assert label_new.data_ptr() == output.data_ptr() assert label_new.categories is label.categories diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py index 7bea05fcef5..25ceaa490e2 100644 --- a/test/test_prototype_datasets_builtin.py +++ b/test/test_prototype_datasets_builtin.py @@ -6,6 +6,8 @@ import pytest import torch + +import torchvision.prototype.transforms.utils from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.utils.data import DataLoader @@ -14,7 +16,7 @@ from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchdata.datapipes.utils import StreamWrapper from torchvision._utils import sequence_to_str -from torchvision.prototype import datasets, features, transforms +from torchvision.prototype import datapoints, datasets, transforms from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE @@ -130,7 +132,11 @@ def make_msg_and_close(head): def test_no_simple_tensors(self, dataset_mock, config): dataset, _ = dataset_mock.load(config) - simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)} + simple_tensors = { + key + for key, value in next_consume(iter(dataset)).items() + if torchvision.prototype.transforms.utils.is_simple_tensor(value) + } if simple_tensors: raise AssertionError( f"The values of key(s) " @@ -258,7 +264,7 @@ def test_sample_content(self, dataset_mock, config): assert "image" in sample assert "label" in sample - 
assert isinstance(sample["image"], features.Image) - assert isinstance(sample["label"], features.Label) + assert isinstance(sample["image"], datapoints.Image) + assert isinstance(sample["label"], datapoints.Label) assert sample["image"].shape == (1, 16, 16) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 2544bf29fe6..44474e88887 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -6,6 +6,8 @@ import pytest import torch + +import torchvision.prototype.transforms.utils from common_utils import assert_equal, cpu_and_gpu from prototype_common_utils import ( DEFAULT_EXTRA_DIMS, @@ -22,7 +24,7 @@ make_videos, ) from torchvision.ops.boxes import box_iou -from torchvision.prototype import features, transforms +from torchvision.prototype import datapoints, transforms from torchvision.prototype.transforms.utils import check_type from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image @@ -159,8 +161,8 @@ def test_mixup_cutmix(self, transform, input): itertools.chain.from_iterable( fn( color_spaces=[ - features.ColorSpace.GRAY, - features.ColorSpace.RGB, + datapoints.ColorSpace.GRAY, + datapoints.ColorSpace.RGB, ], dtypes=[torch.uint8], extra_dims=[(), (4,)], @@ -190,7 +192,7 @@ def test_auto_augment(self, transform, input): ( transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), itertools.chain.from_iterable( - fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]) + fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]) for fn in [ make_images, make_vanilla_tensor_images, @@ -237,10 +239,10 @@ def test_random_resized_crop(self, transform, input): ) for old_color_space, new_color_space in itertools.product( [ - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, + datapoints.ColorSpace.GRAY, + datapoints.ColorSpace.GRAY_ALPHA, + datapoints.ColorSpace.RGB, + datapoints.ColorSpace.RGB_ALPHA, ], repeat=2, ) @@ -251,7 +253,7 @@ def test_convert_color_space(self, transform, input): def test_convert_color_space_unsupported_types(self): transform = transforms.ConvertColorSpace( - color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY + color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY ) for inpt in [make_bounding_box(format="XYXY"), make_masks()]: @@ -287,26 +289,26 @@ def test_features_image(self, p): input, expected = self.input_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(datapoints.Image(input)) - assert_equal(features.Image(expected), actual) + assert_equal(datapoints.Image(expected), actual) def test_features_mask(self, p): input, expected = self.input_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.Mask(input)) + actual = transform(datapoints.Mask(input)) - assert_equal(features.Mask(expected), actual) + assert_equal(datapoints.Mask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) + input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) transform = transforms.RandomHorizontalFlip(p=p) actual = transform(input) expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input - expected = 
features.BoundingBox.wrap_like(input, expected_image_tensor) + expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.spatial_size == expected.spatial_size @@ -340,26 +342,26 @@ def test_features_image(self, p): input, expected = self.input_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(datapoints.Image(input)) - assert_equal(features.Image(expected), actual) + assert_equal(datapoints.Image(expected), actual) def test_features_mask(self, p): input, expected = self.input_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.Mask(input)) + actual = transform(datapoints.Mask(input)) - assert_equal(features.Mask(expected), actual) + assert_equal(datapoints.Mask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) + input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) transform = transforms.RandomVerticalFlip(p=p) actual = transform(input) expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input - expected = features.BoundingBox.wrap_like(input, expected_image_tensor) + expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.spatial_size == expected.spatial_size @@ -386,7 +388,7 @@ def test__transform(self, padding, fill, padding_mode, mocker): transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) _ = transform(inpt) fill = transforms._utils._convert_fill_arg(fill) @@ -394,13 +396,13 @@ def test__transform(self, padding, fill, padding_mode, mocker): padding = list(padding) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) + @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}]) def test__transform_image_mask(self, fill, mocker): transform = transforms.Pad(1, fill=fill, padding_mode="constant") fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - image = features.Image(torch.rand(3, 32, 32)) - mask = features.Mask(torch.randint(0, 5, size=(32, 32))) + image = datapoints.Image(torch.rand(3, 32, 32)) + mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) inpt = [image, mask] _ = transform(inpt) @@ -436,7 +438,7 @@ def test_assertions(self): def test__get_params(self, fill, side_range, mocker): transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) h, w = image.spatial_size = (24, 32) params = transform._get_params([image]) @@ -450,7 +452,7 @@ def test__get_params(self, fill, side_range, mocker): @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) def test__transform(self, fill, side_range, mocker): - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 
32) @@ -469,13 +471,13 @@ def test__transform(self, fill, side_range, mocker): fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, fill=fill) - @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) + @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}]) def test__transform_image_mask(self, fill, mocker): transform = transforms.RandomZoomOut(fill=fill, p=1.0) fn = mocker.patch("torchvision.prototype.transforms.functional.pad") - image = features.Image(torch.rand(3, 32, 32)) - mask = features.Mask(torch.randint(0, 5, size=(32, 32))) + image = datapoints.Image(torch.rand(3, 32, 32)) + mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) inpt = [image, mask] torch.manual_seed(12) @@ -547,7 +549,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): assert transform.degrees == [float(-degrees), float(degrees)] fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users # Otherwise, we can mock transform._get_params @@ -563,10 +565,10 @@ def test__transform(self, degrees, expand, fill, center, mocker): @pytest.mark.parametrize("expand", [False, True]) def test_boundingbox_spatial_size(self, angle, expand): # Specific test for BoundingBox.rotate - bbox = features.BoundingBox( - torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, spatial_size=(32, 32) + bbox = datapoints.BoundingBox( + torch.tensor([1, 2, 3, 4]), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(32, 32) ) - img = features.Image(torch.rand(1, 3, 32, 32)) + img = datapoints.Image(torch.rand(1, 3, 32, 32)) out_img = img.rotate(angle, expand=expand) out_bbox = bbox.rotate(angle, expand=expand) @@ -619,7 +621,7 @@ def test_assertions(self): @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) def test__get_params(self, degrees, translate, scale, shear, mocker): - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) h, w = image.spatial_size @@ -682,7 +684,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker assert transform.degrees == [float(-degrees), float(degrees)] fn = mocker.patch("torchvision.prototype.transforms.functional.affine") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -718,7 +720,7 @@ def test_assertions(self): @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)]) def test__get_params(self, padding, pad_if_needed, size, mocker): - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) h, w = image.spatial_size @@ -771,11 +773,11 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode ) - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (32, 32) - expected = 
mocker.MagicMock(spec=features.Image) + expected = mocker.MagicMock(spec=datapoints.Image) expected.num_channels = 3 if isinstance(padding, int): expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding) @@ -859,7 +861,7 @@ def test__transform(self, kernel_size, sigma, mocker): assert transform.sigma == [sigma, sigma] fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -891,7 +893,7 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): transform = transform_cls(p=p, **kwargs) fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) _ = transform(inpt) if p > 0.0: fn.assert_called_once_with(inpt, **kwargs) @@ -910,7 +912,7 @@ def test_assertions(self): def test__get_params(self, mocker): dscale = 0.5 transform = transforms.RandomPerspective(dscale) - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) @@ -927,7 +929,7 @@ def test__transform(self, distortion_scale, mocker): transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute @@ -971,7 +973,7 @@ def test__get_params(self, mocker): alpha = 2.0 sigma = 3.0 transform = transforms.ElasticTransform(alpha, sigma) - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) @@ -1001,7 +1003,7 @@ def test__transform(self, alpha, sigma, mocker): assert transform.sigma == sigma fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") - inpt = mocker.MagicMock(spec=features.Image) + inpt = mocker.MagicMock(spec=datapoints.Image) inpt.num_channels = 3 inpt.spatial_size = (24, 32) @@ -1030,7 +1032,7 @@ def test_assertions(self, mocker): with pytest.raises(ValueError, match="Scale should be between 0 and 1"): transforms.RandomErasing(scale=[-1, 2]) - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) @@ -1041,7 +1043,7 @@ def test_assertions(self, mocker): @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) def test__get_params(self, value, mocker): - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) @@ -1100,7 +1102,7 @@ def test__transform(self, mocker, p): class TestTransform: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test_check_transformed_types(self, inpt_type, mocker): # This test ensures that we correctly handle which types to transform and which to bypass @@ -1118,7 +1120,7 @@ def test_check_transformed_types(self, inpt_type, mocker): class TestToImageTensor: 
@pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch( @@ -1129,7 +1131,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToImageTensor() transform(inpt) - if inpt_type in (features.BoundingBox, features.Image, str, int): + if inpt_type in (datapoints.BoundingBox, datapoints.Image, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt) @@ -1138,7 +1140,7 @@ def test__transform(self, inpt_type, mocker): class TestToImagePIL: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") @@ -1146,7 +1148,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToImagePIL() transform(inpt) - if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int): + if inpt_type in (datapoints.BoundingBox, PIL.Image.Image, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt, mode=transform.mode) @@ -1155,7 +1157,7 @@ def test__transform(self, inpt_type, mocker): class TestToPILImage: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") @@ -1163,7 +1165,7 @@ def test__transform(self, inpt_type, mocker): inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToPILImage() transform(inpt) - if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int): + if inpt_type in (PIL.Image.Image, datapoints.BoundingBox, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt, mode=transform.mode) @@ -1172,7 +1174,7 @@ def test__transform(self, inpt_type, mocker): class TestToTensor: @pytest.mark.parametrize( "inpt_type", - [torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], + [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch("torchvision.transforms.functional.to_tensor") @@ -1181,7 +1183,7 @@ def test__transform(self, inpt_type, mocker): with pytest.warns(UserWarning, match="deprecated and will be removed"): transform = transforms.ToTensor() transform(inpt) - if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int): + if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBox, str, int): assert fn.call_count == 0 else: fn.assert_called_once_with(inpt) @@ -1223,10 +1225,10 @@ class TestRandomIoUCrop: @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) def test__get_params(self, device, options, mocker): - image = mocker.MagicMock(spec=features.Image) + image = mocker.MagicMock(spec=datapoints.Image) image.num_channels = 3 image.spatial_size = (24, 32) 
- bboxes = features.BoundingBox( + bboxes = datapoints.BoundingBox( torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), format="XYXY", spatial_size=image.spatial_size, @@ -1263,9 +1265,9 @@ def test__get_params(self, device, options, mocker): def test__transform_empty_params(self, mocker): transform = transforms.RandomIoUCrop(sampler_options=[2.0]) - image = features.Image(torch.rand(1, 3, 4, 4)) - bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) - label = features.Label(torch.tensor([1])) + image = datapoints.Image(torch.rand(1, 3, 4, 4)) + bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) + label = datapoints.Label(torch.tensor([1])) sample = [image, bboxes, label] # Let's mock transform._get_params to control the output: transform._get_params = mocker.MagicMock(return_value={}) @@ -1283,10 +1285,10 @@ def test_forward_assertion(self): def test__transform(self, mocker): transform = transforms.RandomIoUCrop() - image = features.Image(torch.rand(3, 32, 24)) + image = datapoints.Image(torch.rand(3, 32, 24)) bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,)) - label = features.Label(torch.randint(0, 10, size=(6,))) - ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) + label = datapoints.Label(torch.randint(0, 10, size=(6,))) + ohe_label = datapoints.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) masks = make_detection_mask((32, 24), num_objects=6) sample = [image, bboxes, label, ohe_label, masks] @@ -1312,21 +1314,21 @@ def test__transform(self, mocker): # check number of bboxes vs number of labels: output_bboxes = output[1] - assert isinstance(output_bboxes, features.BoundingBox) + assert isinstance(output_bboxes, datapoints.BoundingBox) assert len(output_bboxes) == expected_within_targets # check labels output_label = output[2] - assert isinstance(output_label, features.Label) + assert isinstance(output_label, datapoints.Label) assert len(output_label) == expected_within_targets torch.testing.assert_close(output_label, label[is_within_crop_area]) output_ohe_label = output[3] - assert isinstance(output_ohe_label, features.OneHotLabel) + assert isinstance(output_ohe_label, datapoints.OneHotLabel) torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area]) output_masks = output[4] - assert isinstance(output_masks, features.Mask) + assert isinstance(output_masks, datapoints.Mask) assert len(output_masks) == expected_within_targets @@ -1337,7 +1339,7 @@ def test__get_params(self, mocker): scale_range = (0.5, 1.5) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) - sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size) + sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size) n_samples = 5 for _ in range(n_samples): @@ -1387,7 +1389,7 @@ def test__get_params(self, min_size, max_size, mocker): transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) - sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size) + sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size) params = transform._get_params([sample]) assert "size" in params @@ -1439,21 +1441,21 @@ def test__extract_image_targets_assertion(self, mocker): flat_sample = [ # images, batch size = 2 - 
self.create_fake_image(mocker, features.Image), + self.create_fake_image(mocker, datapoints.Image), # labels, bboxes, masks - mocker.MagicMock(spec=features.Label), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), + mocker.MagicMock(spec=datapoints.Label), + mocker.MagicMock(spec=datapoints.BoundingBox), + mocker.MagicMock(spec=datapoints.Mask), # labels, bboxes, masks - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), + mocker.MagicMock(spec=datapoints.BoundingBox), + mocker.MagicMock(spec=datapoints.Mask), ] with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"): transform._extract_image_targets(flat_sample) - @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor]) - @pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) + @pytest.mark.parametrize("image_type", [datapoints.Image, PIL.Image.Image, torch.Tensor]) + @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) def test__extract_image_targets(self, image_type, label_type, mocker): transform = transforms.SimpleCopyPaste() @@ -1463,12 +1465,12 @@ def test__extract_image_targets(self, image_type, label_type, mocker): self.create_fake_image(mocker, image_type), # labels, bboxes, masks mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), + mocker.MagicMock(spec=datapoints.BoundingBox), + mocker.MagicMock(spec=datapoints.Mask), # labels, bboxes, masks mocker.MagicMock(spec=label_type), - mocker.MagicMock(spec=features.BoundingBox), - mocker.MagicMock(spec=features.Mask), + mocker.MagicMock(spec=datapoints.BoundingBox), + mocker.MagicMock(spec=datapoints.Mask), ] images, targets = transform._extract_image_targets(flat_sample) @@ -1483,15 +1485,15 @@ def test__extract_image_targets(self, image_type, label_type, mocker): for target in targets: for key, type_ in [ - ("boxes", features.BoundingBox), - ("masks", features.Mask), + ("boxes", datapoints.BoundingBox), + ("masks", datapoints.Mask), ("labels", label_type), ]: assert key in target assert isinstance(target[key], type_) assert target[key] in flat_sample - @pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) + @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel]) def test__copy_paste(self, label_type): image = 2 * torch.ones(3, 32, 32) masks = torch.zeros(2, 32, 32) @@ -1501,13 +1503,13 @@ def test__copy_paste(self, label_type): blending = True resize_interpolation = InterpolationMode.BILINEAR antialias = None - if label_type == features.OneHotLabel: + if label_type == datapoints.OneHotLabel: labels = torch.nn.functional.one_hot(labels, num_classes=5) target = { - "boxes": features.BoundingBox( + "boxes": datapoints.BoundingBox( torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32) ), - "masks": features.Mask(masks), + "masks": datapoints.Mask(masks), "labels": label_type(labels), } @@ -1516,13 +1518,13 @@ def test__copy_paste(self, label_type): paste_masks[0, 13:19, 12:18] = 1 paste_masks[1, 15:19, 1:8] = 1 paste_labels = torch.tensor([3, 4]) - if label_type == features.OneHotLabel: + if label_type == datapoints.OneHotLabel: paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_target = { - "boxes": features.BoundingBox( + "boxes": datapoints.BoundingBox( torch.tensor([[12.0, 13.0, 19.0, 18.0], 
[1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32) ), - "masks": features.Mask(paste_masks), + "masks": datapoints.Mask(paste_masks), "labels": label_type(paste_labels), } @@ -1538,7 +1540,7 @@ def test__copy_paste(self, label_type): torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) expected_labels = torch.tensor([1, 2, 3, 4]) - if label_type == features.OneHotLabel: + if label_type == datapoints.OneHotLabel: expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5) torch.testing.assert_close(output_target["labels"], label_type(expected_labels)) @@ -1556,9 +1558,9 @@ def test__get_params(self, mocker): transform = transforms.FixedSizeCrop(size=crop_size) flat_inputs = [ - make_image(size=spatial_size, color_space=features.ColorSpace.RGB), + make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB), make_bounding_box( - format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape + format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape ), ] params = transform._get_params(flat_inputs) @@ -1656,7 +1658,7 @@ def test__transform_culling(self, mocker): ) bounding_boxes = make_bounding_box( - format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) + format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) ) masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,)) labels = make_label(extra_dims=(batch_size,)) @@ -1695,7 +1697,7 @@ def test__transform_bounding_box_clamping(self, mocker): ) bounding_box = make_bounding_box( - format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) + format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) ) mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") @@ -1721,7 +1723,7 @@ def test_assertions(self): [ 122 * torch.ones(1, 3, 8, 8), 122.0 * torch.ones(1, 3, 8, 8), - features.Image(122 * torch.ones(1, 3, 8, 8)), + datapoints.Image(122 * torch.ones(1, 3, 8, 8)), PIL.Image.new("RGB", (8, 8), (122, 122, 122)), ], ) @@ -1744,10 +1746,10 @@ def test__transform(self, inpt): class TestLabelToOneHot: def test__transform(self): categories = ["apple", "pear", "pineapple"] - labels = features.Label(torch.tensor([0, 1, 2, 1]), categories=categories) + labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories) transform = transforms.LabelToOneHot() ohe_labels = transform(labels) - assert isinstance(ohe_labels, features.OneHotLabel) + assert isinstance(ohe_labels, datapoints.OneHotLabel) assert ohe_labels.shape == (4, 3) assert ohe_labels.categories == labels.categories == categories @@ -1797,11 +1799,11 @@ def test__transform(self, mocker): [ ( torch.float64, - {torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64}, + {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64}, ), ( - {torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64}, - {torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64}, + {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, ), ], ) @@ -1809,7 +1811,7 @@ def test_to_dtype(dtype, 
expected_dtypes): sample = dict( plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"), image=make_image(dtype=torch.uint8), - bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), str="str", int=0, ) @@ -1834,12 +1836,12 @@ def test_to_dtype(dtype, expected_dtypes): ("dims", "inverse_dims"), [ ( - {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None}, - {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None}, + {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None}, + {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None}, ), ( - {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)}, - {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)}, + {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, + {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, ), ], ) @@ -1847,7 +1849,7 @@ def test_permute_dimensions(dims, inverse_dims): sample = dict( plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), image=make_image(), - bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), video=make_video(), str="str", int=0, @@ -1860,7 +1862,9 @@ def test_permute_dimensions(dims, inverse_dims): value_type = type(value) transformed_value = transformed_sample[key] - if check_type(value, (features.Image, features.is_simple_tensor, features.Video)): + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): if transform.dims.get(value_type) is not None: assert transformed_value.permute(inverse_dims[value_type]).equal(value) assert type(transformed_value) == torch.Tensor @@ -1872,14 +1876,14 @@ def test_permute_dimensions(dims, inverse_dims): "dims", [ (-1, -2), - {torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None}, + {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None}, ], ) def test_transpose_dimensions(dims): sample = dict( plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), image=make_image(), - bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), video=make_video(), str="str", int=0, @@ -1893,7 +1897,9 @@ def test_transpose_dimensions(dims): transformed_value = transformed_sample[key] transposed_dims = transform.dims.get(value_type) - if check_type(value, (features.Image, features.is_simple_tensor, features.Video)): + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): if transposed_dims is not None: assert transformed_value.transpose(*transposed_dims).equal(value) assert type(transformed_value) == torch.Tensor @@ -1907,7 +1913,7 @@ class TestUniformTemporalSubsample: [ torch.zeros(10, 3, 8, 8), torch.zeros(1, 10, 3, 8, 8), - features.Video(torch.zeros(1, 10, 3, 8, 8)), + datapoints.Video(torch.zeros(1, 10, 3, 8, 8)), ], ) def test__transform(self, inpt): diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index f5738a36a38..f562649be91 100644 --- 
a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -24,13 +24,13 @@ ) from torchvision import transforms as legacy_transforms from torchvision._utils import sequence_to_str -from torchvision.prototype import features, transforms as prototype_transforms +from torchvision.prototype import datapoints, transforms as prototype_transforms from torchvision.prototype.transforms import functional as prototype_F from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.transforms import functional as legacy_F -DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) +DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)]) class ConsistencyConfig: @@ -138,7 +138,7 @@ def __init__( # Make sure that the product of the height, width and number of channels matches the number of elements in # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB] + DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB] ), supports_pil=False, ), @@ -150,7 +150,7 @@ def __init__( ArgsKwargs(num_output_channels=3), ], make_images_kwargs=dict( - DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY] + DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY] ), ), ConsistencyConfig( @@ -173,10 +173,10 @@ def __init__( [ArgsKwargs()], make_images_kwargs=dict( color_spaces=[ - features.ColorSpace.GRAY, - features.ColorSpace.GRAY_ALPHA, - features.ColorSpace.RGB, - features.ColorSpace.RGB_ALPHA, + datapoints.ColorSpace.GRAY, + datapoints.ColorSpace.GRAY_ALPHA, + datapoints.ColorSpace.RGB, + datapoints.ColorSpace.RGB_ALPHA, ], extra_dims=[()], ), @@ -733,7 +733,7 @@ class TestAATransforms: [ torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), + datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), ], ) @pytest.mark.parametrize( @@ -771,7 +771,7 @@ def test_randaug(self, inpt, interpolation, mocker): [ torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), + datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), ], ) @pytest.mark.parametrize( @@ -819,7 +819,7 @@ def test_trivial_aug(self, inpt, interpolation, mocker): [ torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), + datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), ], ) @pytest.mark.parametrize( @@ -868,7 +868,7 @@ def test_augmix(self, inpt, interpolation, mocker): [ torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), PIL.Image.new("RGB", (256, 256), 123), - features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), + datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), ], ) @pytest.mark.parametrize( @@ -902,7 +902,7 @@ def make_datapoints(self, 
with_mask=True): size = (600, 800) num_objects = 22 - pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB)) + pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -912,7 +912,7 @@ def make_datapoints(self, with_mask=True): yield (pil_image, target) - tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) + tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -922,7 +922,7 @@ def make_datapoints(self, with_mask=True): yield (tensor_image, target) - feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) + feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB) target = { "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -1006,7 +1006,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): conv_fns.extend([torch.Tensor, lambda x: x]) for conv_fn in conv_fns: - feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype) + feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype) feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) dp = (conv_fn(feature_image), feature_mask) @@ -1053,7 +1053,7 @@ def check(self, t, t_ref, data_kwargs=None): seg_transforms.RandomCrop(size=480), prototype_transforms.Compose( [ - PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), + PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), prototype_transforms.RandomCrop(size=480), ] ), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 8837164452a..7cd84fbcd61 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -10,12 +10,14 @@ import pytest import torch + +import torchvision.prototype.transforms.utils from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS from torch.utils._pytree import tree_map -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box @@ -147,18 +149,22 @@ def _unbatch(self, batch, *, data_dims): def test_batched_vs_single(self, test_id, info, args_kwargs, device): (batched_input, *other_args), kwargs = args_kwargs.load(device) - feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input) + feature_type = ( + datapoints.Image + if 
torchvision.prototype.transforms.utils.is_simple_tensor(batched_input) + else type(batched_input) + ) # This dictionary contains the number of rightmost dimensions that contain the actual data. # Everything to the left is considered a batch dimension. data_dims = { - features.Image: 3, - features.BoundingBox: 1, + datapoints.Image: 3, + datapoints.BoundingBox: 1, # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one # type, all kernels should also work without differentiating between the two. Thus, we go with 2 here as # common ground. - features.Mask: 2, - features.Video: 4, + datapoints.Mask: 2, + datapoints.Video: 4, }.get(feature_type) if data_dims is None: raise pytest.UsageError( @@ -281,8 +287,8 @@ def make_spy(fn, *, module=None, name=None): class TestDispatchers: image_sample_inputs = make_info_args_kwargs_parametrization( - [info for info in DISPATCHER_INFOS if features.Image in info.kernels], - args_kwargs_fn=lambda info: info.sample_inputs(features.Image), + [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels], + args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), ) @ignore_jit_warning_no_profile @@ -323,7 +329,7 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): (image_feature, *other_args), kwargs = args_kwargs.load() image_simple_tensor = torch.Tensor(image_feature) - kernel_info = info.kernel_infos[features.Image] + kernel_info = info.kernel_infos[datapoints.Image] spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id) info.dispatcher(image_simple_tensor, *other_args, **kwargs) @@ -332,7 +338,7 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): @make_info_args_kwargs_parametrization( [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], - args_kwargs_fn=lambda info: info.sample_inputs(features.Image), + args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), ) def test_dispatch_pil(self, info, args_kwargs, spy_on): (image_feature, *other_args), kwargs = args_kwargs.load() @@ -403,7 +409,7 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) def test_dispatcher_feature_signatures_consistency(self, info): try: - feature_method = getattr(features._Feature, info.id) + feature_method = getattr(datapoints._datapoint.Datapoint, info.id) except AttributeError: pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.") @@ -413,7 +419,7 @@ def test_dispatcher_feature_signatures_consistency(self, info): feature_signature = inspect.signature(feature_method) feature_params = list(feature_signature.parameters.values())[1:] - # Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined, + # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is defined, # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively # concrete dispatcher annotations. 
feature_annotations = get_type_hints(feature_method) @@ -505,8 +511,12 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], [1, 1, 5, 5], ] - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device + in_boxes = datapoints.BoundingBox( + in_boxes, + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + dtype=torch.float64, + device=device, ) # Tested parameters angle = 63 @@ -572,7 +582,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): height, width = bbox.spatial_size bbox_xyxy = convert_format_bounding_box( - bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY + bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY ) points = np.array( [ @@ -605,15 +615,15 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): height = int(height - 2 * tr_y) width = int(width - 2 * tr_x) - out_bbox = features.BoundingBox( + out_bbox = datapoints.BoundingBox( out_bbox, - format=features.BoundingBoxFormat.XYXY, + format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(height, width), dtype=bbox.dtype, device=bbox.device, ) return ( - convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format), + convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format), (height, width), ) @@ -641,7 +651,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_) expected_bboxes.append(expected_bbox) if len(expected_bboxes) > 1: @@ -664,8 +674,12 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2], [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], ] - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device + in_boxes = datapoints.BoundingBox( + in_boxes, + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + dtype=torch.float64, + device=device, ) # Tested parameters angle = 45 @@ -725,7 +739,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( "format", - [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], + [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH], ) @pytest.mark.parametrize( "top, left, height, width, expected_bboxes", @@ -755,9 +769,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, [50.0, 5.0, 70.0, 22.0], [45.0, 46.0, 56.0, 62.0], ] - in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=size, device=device) - if format != features.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) + 
in_boxes = datapoints.BoundingBox( + in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device + ) + if format != datapoints.BoundingBoxFormat.XYXY: + in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) output_boxes, output_spatial_size = F.crop_bounding_box( in_boxes, @@ -768,8 +784,8 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, size[1], ) - if format != features.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) + if format != datapoints.BoundingBoxFormat.XYXY: + output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_spatial_size, size) @@ -802,7 +818,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( "format", - [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], + [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH], ) @pytest.mark.parametrize( "top, left, height, width, size", @@ -831,16 +847,16 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size)) expected_bboxes = torch.tensor(expected_bboxes, device=device) - in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device + in_boxes = datapoints.BoundingBox( + in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device ) - if format != features.BoundingBoxFormat.XYXY: - in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) + if format != datapoints.BoundingBoxFormat.XYXY: + in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) - if format != features.BoundingBoxFormat.XYXY: - output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) + if format != datapoints.BoundingBoxFormat.XYXY: + output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_spatial_size, size) @@ -868,14 +884,14 @@ def _compute_expected_bbox(bbox, padding_): bbox_dtype = bbox.dtype bbox = ( bbox.clone() - if bbox_format == features.BoundingBoxFormat.XYXY - else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY) + if bbox_format == datapoints.BoundingBoxFormat.XYXY + else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY) ) bbox[0::2] += pad_left bbox[1::2] += pad_up - bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format) + bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format) if bbox.dtype != bbox_dtype: # Temporary cast to original dtype # e.g. 
float32 -> int @@ -903,7 +919,7 @@ def _compute_expected_spatial_size(bbox, padding_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) expected_bboxes.append(_compute_expected_bbox(bbox, padding)) if len(expected_bboxes) > 1: @@ -949,7 +965,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): ) bbox_xyxy = convert_format_bounding_box( - bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY + bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY ) points = np.array( [ @@ -968,14 +984,16 @@ def _compute_expected_bbox(bbox, pcoeffs_): np.max(transformed_points[:, 0]), np.max(transformed_points[:, 1]), ] - out_bbox = features.BoundingBox( + out_bbox = datapoints.BoundingBox( np.array(out_bbox), - format=features.BoundingBoxFormat.XYXY, + format=datapoints.BoundingBoxFormat.XYXY, spatial_size=bbox.spatial_size, dtype=bbox.dtype, device=bbox.device, ) - return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format) + return convert_format_bounding_box( + out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format + ) spatial_size = (32, 38) @@ -1000,7 +1018,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -1019,7 +1037,7 @@ def _compute_expected_bbox(bbox, output_size_): format_ = bbox.format spatial_size_ = bbox.spatial_size dtype = bbox.dtype - bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH) + bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) if len(output_size_) == 1: output_size_.append(output_size_[-1]) @@ -1033,7 +1051,7 @@ def _compute_expected_bbox(bbox, output_size_): bbox[3].item(), ] out_bbox = torch.tensor(out_bbox) - out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_) + out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_) return out_bbox.to(dtype=dtype, device=bbox.device) for bboxes in make_bounding_boxes(extra_dims=((4,),)): @@ -1050,7 +1068,7 @@ def _compute_expected_bbox(bbox, output_size_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) + bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) if len(expected_bboxes) > 1: @@ -1135,7 +1153,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) ) - image = features.Image(tensor) + image = datapoints.Image(tensor) out = fn(image, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") @@ -1147,7 +1165,7 @@ def test_normalize_output_type(): assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) - inpt = 
make_image(color_space=features.ColorSpace.RGB) + inpt = make_image(color_space=datapoints.ColorSpace.RGB) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) assert type(output) is torch.Tensor torch.testing.assert_close(inpt - 0.5, output) diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index 69b23bf12af..8774b3bb8c5 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -3,42 +3,51 @@ import torch +import torchvision.prototype.transforms.utils from prototype_common_utils import make_bounding_box, make_detection_mask, make_image -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.utils import has_all, has_any -IMAGE = make_image(color_space=features.ColorSpace.RGB) -BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) +IMAGE = make_image(color_space=datapoints.ColorSpace.RGB) +BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size) @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), - ((MASK,), (features.Image, features.BoundingBox), False), - ((BOUNDING_BOX,), (features.Image, features.Mask), False), - ((IMAGE,), (features.BoundingBox, features.Mask), False), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), + ((MASK,), (datapoints.Image, datapoints.BoundingBox), False), + ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False), + ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (features.Image, features.BoundingBox, features.Mask), + (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), True, ), - ((), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True), + ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), + ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), - ((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), - ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), - ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), + ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True), + ( + (torch.Tensor(IMAGE),), + 
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), + True, + ), + ( + (to_image_pil(IMAGE),), + (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), + True, + ), ], ) def test_has_any(sample, types, expected): @@ -48,31 +57,31 @@ def test_has_any(sample, types, expected): @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True), ( (IMAGE, BOUNDING_BOX, MASK), - (features.Image, features.BoundingBox, features.Mask), + (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), True, ), - ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False), - ((BOUNDING_BOX, MASK), (features.Image, features.Mask), False), - ((IMAGE, MASK), (features.BoundingBox, features.Mask), False), + ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False), + ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False), + ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (features.Image, features.BoundingBox, features.Mask), + (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), True, ), - ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False), - ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False), + ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), + ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), + ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False), ( (IMAGE, BOUNDING_BOX, MASK), - (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),), + (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),), True, ), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index 0edf8eb2e9f..200f5cd9552 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1 +1 @@ -from . import features, models, transforms, utils +from . 
import datapoints, models, transforms, utils diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/datapoints/__init__.py similarity index 76% rename from torchvision/prototype/features/__init__.py rename to torchvision/prototype/datapoints/__init__.py index e11e99a9bef..92f345e20bd 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,5 +1,5 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat -from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor +from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/datapoints/_bounding_box.py similarity index 98% rename from torchvision/prototype/features/_bounding_box.py rename to torchvision/prototype/datapoints/_bounding_box.py index a91a50ecb2b..398770cbf6a 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/datapoints/_bounding_box.py @@ -6,7 +6,7 @@ from torchvision._utils import StrEnum from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms -from ._feature import _Feature, FillTypeJIT +from ._datapoint import Datapoint, FillTypeJIT class BoundingBoxFormat(StrEnum): @@ -15,7 +15,7 @@ class BoundingBoxFormat(StrEnum): CXCYWH = StrEnum.auto() -class BoundingBox(_Feature): +class BoundingBox(Datapoint): format: BoundingBoxFormat spatial_size: Tuple[int, int] diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/datapoints/_datapoint.py similarity index 80% rename from torchvision/prototype/features/_feature.py rename to torchvision/prototype/datapoints/_datapoint.py index 3d76236451f..53d1b05fb3b 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -10,16 +10,12 @@ from torchvision.transforms import InterpolationMode -F = TypeVar("F", bound="_Feature") +D = TypeVar("D", bound="Datapoint") FillType = Union[int, float, Sequence[int], Sequence[float], None] FillTypeJIT = Union[int, float, List[float], None] -def is_simple_tensor(inpt: Any) -> bool: - return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature) - - -class _Feature(torch.Tensor): +class Datapoint(torch.Tensor): __F: Optional[ModuleType] = None @staticmethod @@ -31,22 +27,22 @@ def _to_tensor( ) -> torch.Tensor: return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - # FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a + # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be - # interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one - # public again. + # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this + # one public again. 
def __new__( cls, data: Any, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, - ) -> _Feature: + ) -> Datapoint: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return tensor.as_subclass(_Feature) + return tensor.as_subclass(Datapoint) @classmethod - def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F: + def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved, # this method should be made abstract # raise NotImplementedError @@ -75,15 +71,15 @@ def __torch_function__( ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the ``args`` and ``kwargs`` of the original call. - The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature` + The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint` use case, this has two downsides: - 1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. + 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e. ``return cls(func(*args, **kwargs))``, will fail for them. 2. For most operations, there is no way of knowing if the input type is still valid for the output. For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are - listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS` + listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` """ # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we # need to reimplement the functionality. @@ -98,9 +94,9 @@ def __torch_function__( # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, - # `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with - # `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would - # be wrapped into a `features.Image`. + # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with + # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would + # be wrapped into a `datapoints.Image`. if wrapper and isinstance(args[0], cls): return wrapper(cls, args[0], output) # type: ignore[no-any-return] @@ -123,11 +119,11 @@ def _F(self) -> ModuleType: # until the first time we need reference to the functional module and it's shared across all instances of # the class. 
This approach avoids the DataLoader issue described at # https://github.com/pytorch/vision/pull/6476#discussion_r953588621 - if _Feature.__F is None: + if Datapoint.__F is None: from ..transforms import functional - _Feature.__F = functional - return _Feature.__F + Datapoint.__F = functional + return Datapoint.__F # Add properties for common attributes like shape, dtype, device, ndim etc # this way we return the result without passing into __torch_function__ @@ -151,10 +147,10 @@ def dtype(self) -> _dtype: # type: ignore[override] with DisableTorchFunction(): return super().dtype - def horizontal_flip(self) -> _Feature: + def horizontal_flip(self) -> Datapoint: return self - def vertical_flip(self) -> _Feature: + def vertical_flip(self) -> Datapoint: return self # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize @@ -165,13 +161,13 @@ def resize( # type: ignore[override] interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, - ) -> _Feature: + ) -> Datapoint: return self - def crop(self, top: int, left: int, height: int, width: int) -> _Feature: + def crop(self, top: int, left: int, height: int, width: int) -> Datapoint: return self - def center_crop(self, output_size: List[int]) -> _Feature: + def center_crop(self, output_size: List[int]) -> Datapoint: return self def resized_crop( @@ -183,7 +179,7 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, - ) -> _Feature: + ) -> Datapoint: return self def pad( @@ -191,7 +187,7 @@ def pad( padding: Union[int, List[int]], fill: FillTypeJIT = None, padding_mode: str = "constant", - ) -> _Feature: + ) -> Datapoint: return self def rotate( @@ -201,7 +197,7 @@ def rotate( expand: bool = False, center: Optional[List[float]] = None, fill: FillTypeJIT = None, - ) -> _Feature: + ) -> Datapoint: return self def affine( @@ -213,7 +209,7 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: FillTypeJIT = None, center: Optional[List[float]] = None, - ) -> _Feature: + ) -> Datapoint: return self def perspective( @@ -223,7 +219,7 @@ def perspective( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: FillTypeJIT = None, coefficients: Optional[List[float]] = None, - ) -> _Feature: + ) -> Datapoint: return self def elastic( @@ -231,45 +227,45 @@ def elastic( displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: FillTypeJIT = None, - ) -> _Feature: + ) -> Datapoint: return self - def adjust_brightness(self, brightness_factor: float) -> _Feature: + def adjust_brightness(self, brightness_factor: float) -> Datapoint: return self - def adjust_saturation(self, saturation_factor: float) -> _Feature: + def adjust_saturation(self, saturation_factor: float) -> Datapoint: return self - def adjust_contrast(self, contrast_factor: float) -> _Feature: + def adjust_contrast(self, contrast_factor: float) -> Datapoint: return self - def adjust_sharpness(self, sharpness_factor: float) -> _Feature: + def adjust_sharpness(self, sharpness_factor: float) -> Datapoint: return self - def adjust_hue(self, hue_factor: float) -> _Feature: + def adjust_hue(self, hue_factor: float) -> Datapoint: return self - def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature: + def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint: return self - def posterize(self, 
bits: int) -> _Feature: + def posterize(self, bits: int) -> Datapoint: return self - def solarize(self, threshold: float) -> _Feature: + def solarize(self, threshold: float) -> Datapoint: return self - def autocontrast(self) -> _Feature: + def autocontrast(self) -> Datapoint: return self - def equalize(self) -> _Feature: + def equalize(self) -> Datapoint: return self - def invert(self) -> _Feature: + def invert(self) -> Datapoint: return self - def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: + def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint: return self -InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] +InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] InputTypeJIT = torch.Tensor diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/datapoints/_image.py similarity index 99% rename from torchvision/prototype/features/_image.py rename to torchvision/prototype/datapoints/_image.py index fd04e89393c..fc20691100f 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -8,7 +8,7 @@ from torchvision._utils import StrEnum from torchvision.transforms.functional import InterpolationMode -from ._feature import _Feature, FillTypeJIT +from ._datapoint import Datapoint, FillTypeJIT class ColorSpace(StrEnum): @@ -57,7 +57,7 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: return ColorSpace.OTHER -class Image(_Feature): +class Image(Datapoint): color_space: ColorSpace @classmethod diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/datapoints/_label.py similarity index 97% rename from torchvision/prototype/features/_label.py rename to torchvision/prototype/datapoints/_label.py index 9c2bcfc0fb1..54915493390 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -5,13 +5,13 @@ import torch from torch.utils._pytree import tree_map -from ._feature import _Feature +from ._datapoint import Datapoint L = TypeVar("L", bound="_LabelBase") -class _LabelBase(_Feature): +class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/datapoints/_mask.py similarity index 98% rename from torchvision/prototype/features/_mask.py rename to torchvision/prototype/datapoints/_mask.py index eb823f82491..ca4aba87d2e 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/datapoints/_mask.py @@ -5,10 +5,10 @@ import torch from torchvision.transforms import InterpolationMode -from ._feature import _Feature, FillTypeJIT +from ._datapoint import Datapoint, FillTypeJIT -class Mask(_Feature): +class Mask(Datapoint): @classmethod def _wrap(cls, tensor: torch.Tensor) -> Mask: return tensor.as_subclass(cls) diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/datapoints/_video.py similarity index 99% rename from torchvision/prototype/features/_video.py rename to torchvision/prototype/datapoints/_video.py index 042f643e5f8..5c55d23a149 100644 --- a/torchvision/prototype/features/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -6,11 +6,11 @@ import torch from torchvision.transforms.functional import InterpolationMode -from ._feature import _Feature, FillTypeJIT +from ._datapoint import Datapoint, FillTypeJIT from ._image import ColorSpace -class Video(_Feature): +class Video(Datapoint): color_space: 
ColorSpace @classmethod diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index eadc2a019f6..55a77c1a920 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -4,6 +4,8 @@ import numpy as np from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper +from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -12,7 +14,6 @@ read_categories_file, read_mat, ) -from torchvision.prototype.features import _Feature, BoundingBox, Label from .._api import register_dataset, register_info @@ -114,7 +115,7 @@ def _prepare_sample( format="xyxy", spatial_size=image.spatial_size, ), - contour=_Feature(ann["obj_contour"].T), + contour=Datapoint(ann["obj_contour"].T), ) def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 12771a11efa..9050cf0b596 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -3,6 +3,8 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper +from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -11,7 +13,6 @@ INFINITE_BUFFER_SIZE, path_accessor, ) -from torchvision.prototype.features import _Feature, BoundingBox, Label from .._api import register_dataset, register_info @@ -148,7 +149,7 @@ def _prepare_sample( spatial_size=image.spatial_size, ), landmarks={ - landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) + landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) for landmark in {key[:-2] for key in landmarks.keys()} }, ) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 0fff2e6a136..de87f46c8b1 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -6,6 +6,7 @@ import numpy as np from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -13,7 +14,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 753a28363c6..e282635684e 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -2,6 +2,7 @@ from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from 
torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -11,7 +12,6 @@ path_accessor, path_comparator, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 4ec4580e780..fa68bf4dc6f 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -14,6 +14,8 @@ Mapper, UnBatcher, ) +from torchvision.prototype.datapoints import BoundingBox, Label, Mask +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -24,7 +26,6 @@ path_accessor, read_categories_file, ) -from torchvision.prototype.features import _Feature, BoundingBox, Label from .._api import register_dataset, register_info @@ -113,8 +114,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st spatial_size = (image_meta["height"], image_meta["width"]) labels = [ann["category_id"] for ann in anns] return dict( - # TODO: create a segmentation feature - segmentations=_Feature( + segmentations=Mask( torch.stack( [ self._segmentation_to_mask( @@ -124,8 +124,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st ] ) ), - areas=_Feature([ann["area"] for ann in anns]), - crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), + areas=Datapoint([ann["area"] for ann in anns]), + crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool), bounding_boxes=BoundingBox( [ann["bbox"] for ann in anns], format="xywh", diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index c006e445491..0f4b3d769dc 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -9,7 +10,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 2a88f703014..ea192baf650 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -14,6 +14,8 @@ Mapper, ) from torchdata.datapipes.map import IterToMapConverter +from torchvision.prototype.datapoints import BoundingBox, Label +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -25,7 +27,6 @@ read_categories_file, read_mat, ) 
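# --- Editorial aside, not part of the patch: a minimal sketch of how the renamed
# classes in the dataset hunks above are constructed, assuming a torchvision build
# that ships this prototype `datapoints` namespace. The classes and keyword
# arguments come from the hunks in this diff; the concrete values are made up
# purely for illustration.
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.datapoints._datapoint import Datapoint

# Unstructured annotations keep using the plain base class, exactly like the
# `areas`/`crowds` fields in the COCO hunk above.
areas = Datapoint([12.0, 34.5])
crowds = Datapoint([0, 1], dtype=torch.bool)

# Structured annotations use the dedicated subclasses and carry their metadata.
boxes = datapoints.BoundingBox(
    [[10, 10, 20, 20]],
    format="xyxy",  # strings are accepted here, as the caltech/coco hunks show
    spatial_size=(64, 64),
)
label = datapoints.Label(3, categories=["cat", "dog", "bird", "fish"])
# --- end of aside; the patch continues below.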
-from torchvision.prototype.features import _Feature, BoundingBox, Label from .._api import register_dataset, register_info @@ -161,7 +162,7 @@ def _2010_prepare_ann( format="xyxy", spatial_size=spatial_size, ), - segmentation=_Feature(content["seg"]), + segmentation=Datapoint(content["seg"]), ) def _prepare_sample( diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index ebd5eaec571..6ddab2af79d 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -3,6 +3,7 @@ from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -12,7 +13,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 12a379b47b5..463eed79d70 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index b2693aa96c0..73c6184b6e7 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -3,9 +3,9 @@ import torch from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 122962599af..f3054d8fb13 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -2,6 +2,7 @@ from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -11,7 +12,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py 
b/torchvision/prototype/datasets/_builtin/gtsrb.py index 73eee8435e3..adcc31b277a 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper +from torchvision.prototype.datapoints import BoundingBox, Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -9,7 +10,6 @@ INFINITE_BUFFER_SIZE, path_comparator, ) -from torchvision.prototype.features import BoundingBox, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 8388285e541..5e2db41e1d0 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -15,6 +15,7 @@ TarArchiveLoader, ) from torchdata.datapipes.map import IterToMapConverter +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -25,7 +26,6 @@ read_categories_file, read_mat, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 97d729d530b..9364aa3ade9 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,9 +7,9 @@ import torch from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE -from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 7621e7b1163..fbc7d30c292 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -3,6 +3,7 @@ from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -13,7 +14,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index f533ba18084..9de224b95f0 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -4,10 +4,9 @@ from typing import Any, Dict, Iterator, List, 
Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper -from torchvision.prototype import features +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label from .._api import register_dataset, register_info @@ -109,7 +108,7 @@ def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: image, target = data # They're both numpy arrays at this point return { - "image": features.Image(image.transpose(2, 0, 1)), + "image": Image(image.transpose(2, 0, 1)), "label": Label(target.item(), categories=self._categories), } diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 01dd1d888f5..c9f054b2c9e 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -4,6 +4,7 @@ import numpy as np from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -15,7 +16,6 @@ read_categories_file, read_mat, ) -from torchvision.prototype.features import _Feature from .._api import register_dataset, register_info @@ -92,8 +92,8 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st image=EncodedImage.from_file(image_buffer), ann_path=ann_path, # the boundaries are stored in sparse CSC format, which is not supported by PyTorch - boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])), - segmentation=_Feature(anns["Segmentation"].item()), + boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])), + segmentation=Datapoint(anns["Segmentation"].item()), ) def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index 8107f6565e4..9ae2c17ab5d 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -3,9 +3,9 @@ import torch from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper +from torchvision.prototype.datapoints import Image, OneHotLabel from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, OneHotLabel from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 82aec31295e..02db37169c1 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,6 +2,7 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper +from torchvision.prototype.datapoints import BoundingBox, Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, 
HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -10,7 +11,6 @@ read_categories_file, read_mat, ) -from torchvision.prototype.features import BoundingBox, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 6dd55a77c99..d276298ca02 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -3,9 +3,9 @@ import numpy as np from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat -from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index e5ca58f8428..7d1fed04e07 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -3,9 +3,9 @@ import torch from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper +from torchvision.prototype.datapoints import Image, Label from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 901f8eeb1cd..d14189132be 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -6,6 +6,7 @@ from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.datasets import VOCDetection +from torchvision.prototype.datapoints import BoundingBox, Label from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( getitem, @@ -16,7 +17,6 @@ path_comparator, read_categories_file, ) -from torchvision.prototype.features import BoundingBox, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 01a93a52af1..0a37df03add 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -5,9 +5,9 @@ from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper +from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import EncodedData, EncodedImage from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label __all__ = ["from_data_folder", "from_image_folder"] diff --git a/torchvision/prototype/datasets/utils/_encoded.py b/torchvision/prototype/datasets/utils/_encoded.py index 1e06878ba74..64cd9f7b951 100644 --- a/torchvision/prototype/datasets/utils/_encoded.py +++ b/torchvision/prototype/datasets/utils/_encoded.py @@ -7,13 +7,13 @@ import 
PIL.Image import torch -from torchvision.prototype.features._feature import _Feature +from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer D = TypeVar("D", bound="EncodedData") -class EncodedData(_Feature): +class EncodedData(Datapoint): @classmethod def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 8ec7929cdc5..23238c7a5fa 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -6,16 +6,17 @@ import PIL.Image import torch from torch.utils._pytree import tree_flatten, tree_unflatten + from torchvision.ops import masks_to_boxes -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, InterpolationMode from ._transform import _RandomApplyTransform -from .utils import has_any, query_chw, query_spatial_size +from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size class RandomErasing(_RandomApplyTransform): - _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) def __init__( self, @@ -91,8 +92,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) @@ -107,20 +108,20 @@ def __init__(self, alpha: float, p: float = 0.5) -> None: def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( - has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor) - and has_any(flat_inputs, features.OneHotLabel) + has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) + and has_any(flat_inputs, datapoints.OneHotLabel) ): raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") - if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label): + if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Label): raise TypeError( f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." 
) - def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: + def _mixup_onehotlabel(self, inpt: datapoints.OneHotLabel, lam: float) -> datapoints.OneHotLabel: if inpt.ndim < 2: raise ValueError("Need a batch of one hot labels") output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - return features.OneHotLabel.wrap_like(inpt, output) + return datapoints.OneHotLabel.wrap_like(inpt, output) class RandomMixup(_BaseMixupCutmix): @@ -129,17 +130,17 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: lam = params["lam"] - if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt): - expected_ndim = 5 if isinstance(inpt, features.Video) else 4 + if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 if inpt.ndim < expected_ndim: raise ValueError("The transform expects a batched input") output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - if isinstance(inpt, (features.Image, features.Video)): + if isinstance(inpt, (datapoints.Image, datapoints.Video)): output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] return output - elif isinstance(inpt, features.OneHotLabel): + elif isinstance(inpt, datapoints.OneHotLabel): return self._mixup_onehotlabel(inpt, lam) else: return inpt @@ -169,9 +170,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt): + if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): box = params["box"] - expected_ndim = 5 if isinstance(inpt, features.Video) else 4 + expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 if inpt.ndim < expected_ndim: raise ValueError("The transform expects a batched input") x1, y1, x2, y2 = box @@ -179,11 +180,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = inpt.clone() output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] - if isinstance(inpt, (features.Image, features.Video)): + if isinstance(inpt, (datapoints.Image, datapoints.Video)): output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output - elif isinstance(inpt, features.OneHotLabel): + elif isinstance(inpt, datapoints.OneHotLabel): lam_adjusted = params["lam_adjusted"] return self._mixup_onehotlabel(inpt, lam_adjusted) else: @@ -205,15 +206,15 @@ def __init__( def _copy_paste( self, - image: features.TensorImageType, + image: datapoints.TensorImageType, target: Dict[str, Any], - paste_image: features.TensorImageType, + paste_image: datapoints.TensorImageType, paste_target: Dict[str, Any], random_selection: torch.Tensor, blending: bool, resize_interpolation: F.InterpolationMode, antialias: Optional[bool], - ) -> Tuple[features.TensorImageType, Dict[str, Any]]: + ) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]: paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) @@ -262,7 +263,7 @@ def _copy_paste( # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422 
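# --- Editorial aside, not part of the patch: the dispatch pattern the transforms
# above rely on after the rename. `is_simple_tensor` is now imported from
# `torchvision.prototype.transforms.utils` instead of living on the old
# `features` namespace; it flags plain tensors that are not Datapoint subclasses.
# A hedged sketch only, assuming the prototype API exactly as renamed in this PR.
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import is_simple_tensor

image = datapoints.Image(torch.rand(3, 8, 8))  # wrapped datapoint, not a "simple" tensor
raw = torch.rand(3, 8, 8)                      # plain tensor, dispatched via is_simple_tensor
assert not is_simple_tensor(image) and is_simple_tensor(raw)

# The tuple RandomErasing (and similar transforms) now advertise:
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
# --- end of aside; the patch continues below.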
xyxy_boxes[:, 2:] += 1 boxes = F.convert_format_bounding_box( - xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True + xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True ) out_target["boxes"] = torch.cat([boxes, paste_boxes]) @@ -271,7 +272,7 @@ def _copy_paste( # Check for degenerated boxes and remove them boxes = F.convert_format_bounding_box( - out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY + out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY ) degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] if degenerate_boxes.any(): @@ -285,20 +286,20 @@ def _copy_paste( def _extract_image_targets( self, flat_sample: List[Any] - ) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]: + ) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input # with List[image], List[BoundingBox], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] for obj in flat_sample: - if isinstance(obj, features.Image) or features.is_simple_tensor(obj): + if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): images.append(obj) elif isinstance(obj, PIL.Image.Image): images.append(F.to_image_tensor(obj)) - elif isinstance(obj, features.BoundingBox): + elif isinstance(obj, datapoints.BoundingBox): bboxes.append(obj) - elif isinstance(obj, features.Mask): + elif isinstance(obj, datapoints.Mask): masks.append(obj) - elif isinstance(obj, (features.Label, features.OneHotLabel)): + elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): labels.append(obj) if not (len(images) == len(bboxes) == len(masks) == len(labels)): @@ -316,27 +317,27 @@ def _extract_image_targets( def _insert_outputs( self, flat_sample: List[Any], - output_images: List[features.TensorImageType], + output_images: List[datapoints.TensorImageType], output_targets: List[Dict[str, Any]], ) -> None: c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): - if isinstance(obj, features.Image): - flat_sample[i] = features.Image.wrap_like(obj, output_images[c0]) + if isinstance(obj, datapoints.Image): + flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0]) c0 += 1 elif isinstance(obj, PIL.Image.Image): flat_sample[i] = F.to_image_pil(output_images[c0]) c0 += 1 - elif features.is_simple_tensor(obj): + elif is_simple_tensor(obj): flat_sample[i] = output_images[c0] c0 += 1 - elif isinstance(obj, features.BoundingBox): - flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) + elif isinstance(obj, datapoints.BoundingBox): + flat_sample[i] = datapoints.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) c1 += 1 - elif isinstance(obj, features.Mask): - flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"]) + elif isinstance(obj, datapoints.Mask): + flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) c2 += 1 - elif isinstance(obj, (features.Label, features.OneHotLabel)): + elif isinstance(obj, (datapoints.Label, datapoints.OneHotLabel)): flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] c3 += 1 diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 28029db8215..d4f2ca2143b 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ 
-5,13 +5,14 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec -from torchvision.prototype import features + +from torchvision.prototype import datapoints from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_spatial_size from torchvision.transforms import functional_tensor as _FT from ._utils import _setup_fill_arg -from .utils import check_type +from .utils import check_type, is_simple_tensor class _AutoAugmentBase(Transform): @@ -19,7 +20,7 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[features.FillType, Dict[Type, features.FillType]] = None, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, ) -> None: super().__init__() self.interpolation = interpolation @@ -33,13 +34,21 @@ def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, def _flatten_and_extract_image_or_video( self, inputs: Any, - unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask), - ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]: + unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask), + ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]: flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) image_or_videos = [] for idx, inpt in enumerate(flat_inputs): - if check_type(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)): + if check_type( + inpt, + ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ), + ): image_or_videos.append((idx, inpt)) elif isinstance(inpt, unsupported_types): raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") @@ -58,7 +67,7 @@ def _flatten_and_extract_image_or_video( def _unflatten_and_insert_image_or_video( self, flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int], - image_or_video: Union[features.ImageType, features.VideoType], + image_or_video: Union[datapoints.ImageType, datapoints.VideoType], ) -> Any: flat_inputs, spec, idx = flat_inputs_with_spec flat_inputs[idx] = image_or_video @@ -66,12 +75,12 @@ def _unflatten_and_insert_image_or_video( def _apply_image_or_video_transform( self, - image: Union[features.ImageType, features.VideoType], + image: Union[datapoints.ImageType, datapoints.VideoType], transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Dict[Type, features.FillTypeJIT], - ) -> Union[features.ImageType, features.VideoType]: + fill: Dict[Type, datapoints.FillTypeJIT], + ) -> Union[datapoints.ImageType, datapoints.VideoType]: fill_ = fill[type(image)] if transform_id == "Identity": @@ -182,7 +191,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[features.FillType, Dict[Type, features.FillType]] = None, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -338,7 +347,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[features.FillType, Dict[Type, features.FillType]] = None, + fill: Union[datapoints.FillType, Dict[Type, 
datapoints.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -390,7 +399,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[features.FillType, Dict[Type, features.FillType]] = None, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -446,7 +455,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Union[features.FillType, Dict[Type, features.FillType]] = None, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 @@ -474,7 +483,7 @@ def forward(self, *inputs: Any) -> Any: augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE orig_dims = list(image_or_video.shape) - expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4 + expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4 batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims) batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) @@ -511,7 +520,7 @@ def forward(self, *inputs: Any) -> Any: mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) - if isinstance(orig_image_or_video, (features.Image, features.Video)): + if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type] elif isinstance(orig_image_or_video, PIL.Image.Image): mix = F.to_image_pil(mix) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 49b2c098763..0254dd7c225 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -3,11 +3,12 @@ import PIL.Image import torch -from torchvision.prototype import features + +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform from ._transform import _RandomApplyTransform -from .utils import query_chw +from .utils import is_simple_tensor, query_chw class ColorJitter(Transform): @@ -82,7 +83,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomPhotometricDistort(Transform): - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) def __init__( self, @@ -111,15 +117,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: ) def _permute_channels( - self, inpt: Union[features.ImageType, features.VideoType], permutation: torch.Tensor - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor + ) -> Union[datapoints.ImageType, datapoints.VideoType]: if isinstance(inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) output = inpt[..., permutation, :, :] - if isinstance(inpt, (features.Image, features.Video)): - output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type] + if isinstance(inpt, (datapoints.Image, 
datapoints.Video)): + output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type] elif isinstance(inpt, PIL.Image.Image): output = F.to_image_pil(output) @@ -127,8 +133,8 @@ def _permute_channels( return output def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: if params["brightness"]: inpt = F.adjust_brightness( inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 593eb8895db..3247a8051a3 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -5,13 +5,13 @@ import PIL.Image import torch -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import Transform from torchvision.transforms import functional as _F from typing_extensions import Literal from ._transform import _RandomApplyTransform -from .utils import query_chw +from .utils import is_simple_tensor, query_chw class ToTensor(Transform): @@ -29,7 +29,12 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, class Grayscale(Transform): - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: deprecation_msg = ( @@ -53,16 +58,21 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: self.num_output_channels = num_output_channels def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) - if isinstance(inpt, (features.Image, features.Video)): - output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type] + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] return output class RandomGrayscale(_RandomApplyTransform): - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) def __init__(self, p: float = 0.1) -> None: warnings.warn( @@ -84,9 +94,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(num_input_channels=num_input_channels) def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) - if isinstance(inpt, (features.Image, 
features.Video)): - output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type] + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] return output diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index c5313c2655c..1cbf02d5ae2 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -5,8 +5,9 @@ import PIL.Image import torch + from torchvision.ops.boxes import box_iou -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform from torchvision.transforms.functional import _get_perspective_coeffs @@ -22,7 +23,7 @@ _setup_float_or_seq, _setup_size, ) -from .utils import has_all, has_any, query_bounding_box, query_spatial_size +from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size class RandomHorizontalFlip(_RandomApplyTransform): @@ -145,23 +146,23 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) -ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT] +ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] class FiveCrop(Transform): """ Example: >>> class BatchMultiCrop(transforms.Transform): - ... def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]): + ... def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], datapoints.Label]): ... images_or_videos, labels = sample ... batch_size = len(images_or_videos) ... image_or_video = images_or_videos[0] ... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos)) - ... labels = features.Label.wrap_like(labels, labels.repeat(batch_size)) + ... labels = datapoints.Label.wrap_like(labels, labels.repeat(batch_size)) ... return images_or_videos, labels ... - >>> image = features.Image(torch.rand(3, 256, 256)) - >>> label = features.Label(0) + >>> image = datapoints.Image(torch.rand(3, 256, 256)) + >>> label = datapoints.Label(0) >>> transform = transforms.Compose([transforms.FiveCrop(), BatchMultiCrop()]) >>> images, labels = transform(image, label) >>> images.shape @@ -170,7 +171,12 @@ class FiveCrop(Transform): torch.Size([5]) """ - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() @@ -182,7 +188,7 @@ def _transform( return F.five_crop(inpt, self.size) def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, features.BoundingBox, features.Mask): + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") @@ -191,7 +197,12 @@ class TenCrop(Transform): See :class:`~torchvision.prototype.transforms.FiveCrop` for an example. 
""" - _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video) + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: super().__init__() @@ -199,12 +210,12 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.vertical_flip = vertical_flip def _check_inputs(self, flat_inputs: List[Any]) -> None: - if has_any(flat_inputs, features.BoundingBox, features.Mask): + if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask): raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]: return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) @@ -212,7 +223,7 @@ class Pad(Transform): def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -235,7 +246,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomZoomOut(_RandomApplyTransform): def __init__( self, - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: @@ -276,7 +287,7 @@ def __init__( degrees: Union[numbers.Number, Sequence], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -315,7 +326,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -390,7 +401,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -480,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform): def __init__( self, distortion_scale: float = 0.5, - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, p: float = 0.5, ) -> None: @@ -540,7 +551,7 @@ def __init__( self, alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, - fill: Union[features.FillType, Dict[Type, 
features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() @@ -606,9 +617,9 @@ def __init__( def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( - has_all(flat_inputs, features.BoundingBox) - and has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor) - and has_any(flat_inputs, features.Label, features.OneHotLabel) + has_all(flat_inputs, datapoints.BoundingBox) + and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) + and has_any(flat_inputs, datapoints.Label, datapoints.OneHotLabel) ): raise TypeError( f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " @@ -646,7 +657,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # check for any valid boxes with centers within the crop area xyxy_bboxes = F.convert_format_bounding_box( - bboxes.as_subclass(torch.Tensor), bboxes.format, features.BoundingBoxFormat.XYXY + bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY ) cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) @@ -671,19 +682,19 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: is_within_crop_area = params["is_within_crop_area"] - if isinstance(inpt, (features.Label, features.OneHotLabel)): + if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel)): return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) - if isinstance(output, features.BoundingBox): + if isinstance(output, datapoints.BoundingBox): bboxes = output[is_within_crop_area] bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) - output = features.BoundingBox.wrap_like(output, bboxes) - elif isinstance(output, features.Mask): + output = datapoints.BoundingBox.wrap_like(output, bboxes) + elif isinstance(output, datapoints.Mask): # apply is_within_crop_area if mask is one-hot encoded masks = output[is_within_crop_area] - output = features.Mask.wrap_like(output, masks) + output = datapoints.Mask.wrap_like(output, masks) return output @@ -751,7 +762,7 @@ class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, + fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() @@ -764,13 +775,19 @@ def __init__( self.padding_mode = padding_mode def _check_inputs(self, flat_inputs: List[Any]) -> None: - if not has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video): + if not has_any( + flat_inputs, + PIL.Image.Image, + datapoints.Image, + is_simple_tensor, + datapoints.Video, + ): raise TypeError( f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video." 
) - if has_any(flat_inputs, features.BoundingBox) and not has_any( - flat_inputs, features.Label, features.OneHotLabel + if has_any(flat_inputs, datapoints.BoundingBox) and not has_any( + flat_inputs, datapoints.Label, datapoints.OneHotLabel ): raise TypeError( f"If a BoundingBox is contained in the input sample, " @@ -809,7 +826,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: ) bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size) height_and_width = F.convert_format_bounding_box( - bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH + bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH )[..., 2:] is_valid = torch.all(height_and_width > 0, dim=-1) else: @@ -842,10 +859,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["is_valid"] is not None: - if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)): + if isinstance(inpt, (datapoints.Label, datapoints.OneHotLabel, datapoints.Mask)): inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] - elif isinstance(inpt, features.BoundingBox): - inpt = features.BoundingBox.wrap_like( + elif isinstance(inpt, datapoints.BoundingBox): + inpt = datapoints.BoundingBox.wrap_like( inpt, F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size), ) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 4a85175e901..6ad9e041098 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -3,38 +3,41 @@ import PIL.Image import torch -from torchvision.prototype import features + +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform +from .utils import is_simple_tensor + class ConvertBoundingBoxFormat(Transform): - _transformed_types = (features.BoundingBox,) + _transformed_types = (datapoints.BoundingBox,) - def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: + def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None: super().__init__() if isinstance(format, str): - format = features.BoundingBoxFormat[format] + format = datapoints.BoundingBoxFormat[format] self.format = format - def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox: + def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: # We need to unwrap here to avoid unnecessary `__torch_function__` calls, # since `convert_format_bounding_box` does not have a dispatcher function that would do that for us output = F.convert_format_bounding_box( inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"] ) - return features.BoundingBox.wrap_like(inpt, output, format=params["format"]) + return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"]) class ConvertDtype(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype def _transform( - self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] - ) -> Union[features.TensorImageType, features.TensorVideoType]: + self, inpt: 
Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] + ) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]: return F.convert_dtype(inpt, self.dtype) @@ -44,36 +47,41 @@ def _transform( class ConvertColorSpace(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video) + _transformed_types = ( + is_simple_tensor, + datapoints.Image, + PIL.Image.Image, + datapoints.Video, + ) def __init__( self, - color_space: Union[str, features.ColorSpace], - old_color_space: Optional[Union[str, features.ColorSpace]] = None, + color_space: Union[str, datapoints.ColorSpace], + old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None, ) -> None: super().__init__() if isinstance(color_space, str): - color_space = features.ColorSpace.from_str(color_space) + color_space = datapoints.ColorSpace.from_str(color_space) self.color_space = color_space if isinstance(old_color_space, str): - old_color_space = features.ColorSpace.from_str(old_color_space) + old_color_space = datapoints.ColorSpace.from_str(old_color_space) self.old_color_space = old_color_space def _transform( - self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any] - ) -> Union[features.ImageType, features.VideoType]: + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space) class ClampBoundingBoxes(Transform): - _transformed_types = (features.BoundingBox,) + _transformed_types = (datapoints.BoundingBox,) - def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox: + def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: # We need to unwrap here to avoid unnecessary `__torch_function__` calls, # since `clamp_bounding_box` does not have a dispatcher function that would do that for us output = F.clamp_bounding_box( inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size ) - return features.BoundingBox.wrap_like(inpt, output) + return datapoints.BoundingBox.wrap_like(inpt, output) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index e50d9cff0dc..70a695199fc 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -3,12 +3,13 @@ import PIL.Image import torch + from torchvision.ops import remove_small_boxes -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size -from .utils import has_any, query_bounding_box +from .utils import has_any, is_simple_tensor, query_bounding_box class Identity(Transform): @@ -38,7 +39,7 @@ def extra_repr(self) -> str: class LinearTransformation(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): super().__init__() @@ -67,7 +68,7 @@ def _check_inputs(self, sample: Any) -> Any: raise TypeError("LinearTransformation does not work on PIL Images") def _transform( - self, inpt: Union[features.TensorImageType, 
features.TensorVideoType], params: Dict[str, Any] + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] ) -> torch.Tensor: # Image instance after linear transformation is not Image anymore due to unknown data range # Thus we will return Tensor for input Image @@ -93,7 +94,7 @@ def _transform( class Normalize(Transform): - _transformed_types = (features.Image, features.is_simple_tensor, features.Video) + _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): super().__init__() @@ -106,7 +107,7 @@ def _check_inputs(self, sample: Any) -> Any: raise TypeError(f"{type(self).__name__}() does not support PIL images.") def _transform( - self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] ) -> torch.Tensor: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) @@ -158,7 +159,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class PermuteDimensions(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: super().__init__() @@ -167,7 +168,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] self.dims = dims def _transform( - self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: @@ -176,7 +177,7 @@ def _transform( class TransposeDimensions(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: super().__init__() @@ -185,7 +186,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i self.dims = dims def _transform( - self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] + self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: @@ -194,7 +195,7 @@ def _transform( class RemoveSmallBoundingBoxes(Transform): - _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel) + _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel) def __init__(self, min_size: float = 1.0) -> None: super().__init__() @@ -210,7 +211,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bounding_box = F.convert_format_bounding_box( bounding_box.as_subclass(torch.Tensor), old_format=bounding_box.format, - new_format=features.BoundingBoxFormat.XYXY, + new_format=datapoints.BoundingBoxFormat.XYXY, ) valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size) diff --git a/torchvision/prototype/transforms/_temporal.py b/torchvision/prototype/transforms/_temporal.py index 46293c25131..62fe7f4edf5 100644 --- a/torchvision/prototype/transforms/_temporal.py +++ 
b/torchvision/prototype/transforms/_temporal.py @@ -1,16 +1,18 @@ from typing import Any, Dict -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform +from torchvision.prototype.transforms.utils import is_simple_tensor + class UniformTemporalSubsample(Transform): - _transformed_types = (features.is_simple_tensor, features.Video) + _transformed_types = (is_simple_tensor, datapoints.Video) def __init__(self, num_samples: int, temporal_dim: int = -4): super().__init__() self.num_samples = num_samples self.temporal_dim = temporal_dim - def _transform(self, inpt: features.VideoType, params: Dict[str, Any]) -> features.VideoType: + def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType: return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 30b92259b93..01908650fb4 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -5,23 +5,26 @@ import torch from torch.nn.functional import one_hot -from torchvision.prototype import features + +from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform +from torchvision.prototype.transforms.utils import is_simple_tensor + class LabelToOneHot(Transform): - _transformed_types = (features.Label,) + _transformed_types = (datapoints.Label,) def __init__(self, num_categories: int = -1): super().__init__() self.num_categories = num_categories - def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.OneHotLabel: + def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel: num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories) output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) - return features.OneHotLabel(output, categories=inpt.categories) + return datapoints.OneHotLabel(output, categories=inpt.categories) def extra_repr(self) -> str: if self.num_categories == -1: @@ -38,16 +41,16 @@ def _transform(self, inpt: Union[PIL.Image.Image], params: Dict[str, Any]) -> to class ToImageTensor(Transform): - _transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray) + _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] - ) -> features.Image: + ) -> datapoints.Image: return F.to_image_tensor(inpt) # type: ignore[no-any-return] class ToImagePIL(Transform): - _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray) + _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) def __init__(self, mode: Optional[str] = None) -> None: super().__init__() diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 60f64898624..cbf8992300e 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -3,8 +3,8 @@ from collections import defaultdict from typing import Any, Dict, Sequence, Type, TypeVar, Union -from torchvision.prototype import features -from torchvision.prototype.features._feature import FillType, FillTypeJIT +from 
torchvision.prototype import datapoints +from torchvision.prototype.datapoints._datapoint import FillType, FillTypeJIT from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 @@ -54,7 +54,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: return defaultdict(functools.partial(_default_arg, default)) -def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: +def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index c6d48d381f5..9f4a248089d 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -3,7 +3,7 @@ import PIL.Image import torch -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.transforms.functional import pil_to_tensor, to_pil_image @@ -33,28 +33,28 @@ def erase_video( def erase( - inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: +) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) + torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) ): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, features.Image): + elif isinstance(inpt, datapoints.Image): output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) - return features.Image.wrap_like(inpt, output) - elif isinstance(inpt, features.Video): + return datapoints.Image.wrap_like(inpt, output) + elif isinstance(inpt, datapoints.Video): output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) - return features.Video.wrap_like(inpt, output) + return datapoints.Video.wrap_like(inpt, output) elif isinstance(inpt, PIL.Image.Image): return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) else: raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
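Illustrative sketch (not part of the patch): the erase() dispatcher above unwraps an Image/Video datapoint to a plain tensor, runs the tensor kernel, and re-wraps the result with wrap_like, so the subclass survives the call. A minimal usage sketch, assuming the prototype datapoints.Image constructor accepts a bare (C, H, W) tensor; the shapes are arbitrary.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# assumed: Image(tensor) infers its metadata from the shape
img = datapoints.Image(torch.rand(3, 32, 32))
patch = torch.zeros(3, 8, 8)  # value written into the erased region

out = F.erase(img, i=4, j=4, h=8, w=8, v=patch)
# per the hunk above, the output is re-wrapped, so it should still be a datapoints.Image
print(type(out))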
) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index dff640586a6..618968cbb48 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,7 +1,7 @@ import PIL.Image import torch from torch.nn.functional import conv2d -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional_tensor import _max_value @@ -37,16 +37,18 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) -def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_brightness(brightness_factor=brightness_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -76,16 +78,18 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) -def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_saturation(saturation_factor=saturation_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -115,16 +119,18 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. 
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) -def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_contrast(contrast_factor=contrast_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -188,16 +194,18 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) -def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -300,16 +308,18 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image_tensor(video, hue_factor=hue_factor) -def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_hue(hue_factor=hue_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_hue_image_pil(inpt, hue_factor=hue_factor) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -340,16 +350,18 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) -def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.adjust_gamma(gamma=gamma, gain=gain) elif isinstance(inpt, PIL.Image.Image): return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -374,16 +386,18 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: return posterize_image_tensor(video, bits=bits) -def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return posterize_image_tensor(inpt, bits=bits) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.posterize(bits=bits) elif isinstance(inpt, PIL.Image.Image): return posterize_image_pil(inpt, bits=bits) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -402,16 +416,18 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: return solarize_image_tensor(video, threshold=threshold) -def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return solarize_image_tensor(inpt, threshold=threshold) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.solarize(threshold=threshold) elif isinstance(inpt, PIL.Image.Image): return solarize_image_pil(inpt, threshold=threshold) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -452,16 +468,18 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: return autocontrast_image_tensor(video) -def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return autocontrast_image_tensor(inpt) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.autocontrast() elif isinstance(inpt, PIL.Image.Image): return autocontrast_image_pil(inpt) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -542,16 +560,18 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: return equalize_image_tensor(video) -def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return equalize_image_tensor(inpt) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.equalize() elif isinstance(inpt, PIL.Image.Image): return equalize_image_pil(inpt) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -573,15 +593,17 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image_tensor(video) -def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return invert_image_tensor(inpt) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.invert() elif isinstance(inpt, PIL.Image.Image): return invert_image_pil(inpt) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
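Illustrative sketch (not part of the patch): every color op in this file keeps the same three-way dispatch after the rename, now keyed on datapoints._datapoint.Datapoint instead of features._Feature: plain tensors (or anything under torch.jit.is_scripting()) go to the tensor kernel, datapoints route through their own method, and PIL images go to the PIL kernel. A hedged usage sketch, assuming the prototype Video constructor accepts a bare (T, C, H, W) tensor.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

plain = torch.rand(3, 16, 16)                       # plain-tensor branch -> adjust_brightness_image_tensor
video = datapoints.Video(torch.rand(4, 3, 16, 16))  # datapoint branch -> video.adjust_brightness(...)

F.adjust_brightness(plain, brightness_factor=1.5)
F.adjust_brightness(video, brightness_factor=1.5)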
) diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index e28bc45654c..25b54917b33 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -4,16 +4,16 @@ import PIL.Image import torch -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.transforms import functional as _F @torch.jit.unused def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: call = ", num_output_channels=3" if num_output_channels == 3 else "" - replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)" + replacement = "convert_color_space(..., color_space=datapoints.ColorSpace.GRAY)" if num_output_channels == 3: - replacement = f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB)" + replacement = f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB)" warnings.warn( f"The function `to_grayscale(...{call})` is deprecated in will be removed in a future release. " f"Instead, please use `{replacement}`.", @@ -23,25 +23,25 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima def rgb_to_grayscale( - inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], num_output_channels: int = 1 -) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: - if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)): + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 +) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: + if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)): inpt = inpt.as_subclass(torch.Tensor) old_color_space = None elif isinstance(inpt, torch.Tensor): - old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] + old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] else: old_color_space = None call = ", num_output_channels=3" if num_output_channels == 3 else "" replacement = ( - f"convert_color_space(..., color_space=features.ColorSpace.GRAY" - f"{f', old_color_space=features.ColorSpace.{old_color_space}' if old_color_space is not None else ''})" + f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY" + f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})" ) if num_output_channels == 3: replacement = ( - f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB" - f"{f', old_color_space=features.ColorSpace.GRAY' if old_color_space is not None else ''})" + f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB" + f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})" ) warnings.warn( f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. " @@ -60,7 +60,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: return _F.to_tensor(inpt) -def get_image_size(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]: +def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." 
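Illustrative sketch (not part of the patch): the deprecation warnings above now steer users toward convert_color_space with the datapoints.ColorSpace enum. A hedged sketch of the replacement they suggest; for a plain tensor the old color space must be passed explicitly, since unlike an Image datapoint it carries no color-space metadata.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

rgb = torch.rand(3, 24, 24)  # plain RGB tensor, no attached metadata
gray = F.convert_color_space(
    rgb,
    color_space=datapoints.ColorSpace.GRAY,
    old_color_space=datapoints.ColorSpace.RGB,
)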
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 60f931d5fb5..cef68d66ee9 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -7,7 +7,7 @@ import torch from torch.nn.functional import grid_sample, interpolate, pad as torch_pad -from torchvision.prototype import features +from torchvision.prototype import datapoints from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional import ( _compute_resized_output_size as __compute_resized_output_size, @@ -34,17 +34,17 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int] + bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: shape = bounding_box.shape bounding_box = bounding_box.clone().reshape(-1, 4) - if format == features.BoundingBoxFormat.XYXY: + if format == datapoints.BoundingBoxFormat.XYXY: bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_() - elif format == features.BoundingBoxFormat.XYWH: + elif format == datapoints.BoundingBoxFormat.XYWH: bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_() - else: # format == features.BoundingBoxFormat.CXCYWH: + else: # format == datapoints.BoundingBoxFormat.CXCYWH: bounding_box[:, 0].sub_(spatial_size[1]).neg_() return bounding_box.reshape(shape) @@ -54,16 +54,18 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: return horizontal_flip_image_tensor(video) -def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return horizontal_flip_image_tensor(inpt) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.horizontal_flip() elif isinstance(inpt, PIL.Image.Image): return horizontal_flip_image_pil(inpt) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -80,17 +82,17 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: def vertical_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int] + bounding_box: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: shape = bounding_box.shape bounding_box = bounding_box.clone().reshape(-1, 4) - if format == features.BoundingBoxFormat.XYXY: + if format == datapoints.BoundingBoxFormat.XYXY: bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_() - elif format == features.BoundingBoxFormat.XYWH: + elif format == datapoints.BoundingBoxFormat.XYWH: bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_() - else: # format == features.BoundingBoxFormat.CXCYWH: + else: # format == datapoints.BoundingBoxFormat.CXCYWH: bounding_box[:, 1].sub_(spatial_size[0]).neg_() return bounding_box.reshape(shape) @@ -100,16 +102,18 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: return vertical_flip_image_tensor(video) -def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return vertical_flip_image_tensor(inpt) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.vertical_flip() elif isinstance(inpt, PIL.Image.Image): return vertical_flip_image_pil(inpt) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -221,15 +225,17 @@ def resize_video( def resize( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) elif isinstance(inpt, PIL.Image.Image): if antialias is not None and not antialias: @@ -237,7 +243,7 @@ def resize( return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -392,7 +398,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in def _apply_grid_transform( - float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: features.FillTypeJIT + float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT ) -> torch.Tensor: shape = float_img.shape @@ -428,7 +434,7 @@ def _assert_grid_transform_inputs( image: torch.Tensor, matrix: Optional[List[float]], interpolation: str, - fill: features.FillTypeJIT, + fill: datapoints.FillTypeJIT, supported_interpolation_modes: List[str], coeffs: Optional[List[float]] = None, ) -> None: @@ -491,7 +497,7 @@ def affine_image_tensor( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if image.numel() == 0: @@ -545,7 +551,7 @@ def affine_image_pil( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) @@ -637,7 +643,7 @@ def _affine_bounding_box_xyxy( def affine_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int], angle: Union[int, float], translate: List[float], @@ -648,7 +654,7 @@ def affine_bounding_box( original_shape = bounding_box.shape bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY) + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center) @@ -656,7 +662,7 @@ def affine_bounding_box( # out_bboxes should be of shape [N boxes, 4] return convert_format_bounding_box( - out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) @@ -666,7 +672,7 @@ def affine_mask( translate: List[float], scale: float, shear: List[float], - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -699,7 +705,7 @@ def affine_video( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: return affine_image_tensor( @@ -715,17 +721,19 @@ def affine_video( def affine( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, angle: Union[int, float], translate: List[float], scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, center: Optional[List[float]] = None, -) -> features.InputTypeJIT: +) -> datapoints.InputTypeJIT: # TODO: consider deprecating integers from angle and shear on the future - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): + if isinstance(inpt, torch.Tensor) and ( + 
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return affine_image_tensor( inpt, angle, @@ -736,7 +744,7 @@ def affine( fill=fill, center=center, ) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.affine( angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center ) @@ -753,7 +761,7 @@ def affine( ) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -764,7 +772,7 @@ def rotate_image_tensor( interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: shape = image.shape num_channels, height, width = shape[-3:] @@ -811,7 +819,7 @@ def rotate_image_pil( interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> PIL.Image.Image: if center is not None and expand: warnings.warn("The provided center argument has no effect on the result if expand is True") @@ -824,7 +832,7 @@ def rotate_image_pil( def rotate_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int], angle: float, expand: bool = False, @@ -836,7 +844,7 @@ def rotate_bounding_box( original_shape = bounding_box.shape bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY) + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) out_bboxes, spatial_size = _affine_bounding_box_xyxy( @@ -852,7 +860,7 @@ def rotate_bounding_box( return ( convert_format_bounding_box( - out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape), spatial_size, ) @@ -863,7 +871,7 @@ def rotate_mask( angle: float, expand: bool = False, center: Optional[List[float]] = None, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -892,28 +900,30 @@ def rotate_video( interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) def rotate( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: features.FillTypeJIT = None, -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): + fill: datapoints.FillTypeJIT = None, +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + 
): return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, PIL.Image.Image): return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -945,7 +955,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: def pad_image_tensor( image: torch.Tensor, padding: Union[int, List[int]], - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, padding_mode: str = "constant", ) -> torch.Tensor: # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses @@ -1047,7 +1057,7 @@ def pad_mask( mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant", - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: if fill is None: fill = 0 @@ -1071,7 +1081,7 @@ def pad_mask( def pad_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int], padding: Union[int, List[int]], padding_mode: str = "constant", @@ -1082,7 +1092,7 @@ def pad_bounding_box( left, right, top, bottom = _parse_pad_padding(padding) - if format == features.BoundingBoxFormat.XYXY: + if format == datapoints.BoundingBoxFormat.XYXY: pad = [left, top, left, top] else: pad = [left, top, 0, 0] @@ -1098,28 +1108,30 @@ def pad_bounding_box( def pad_video( video: torch.Tensor, padding: Union[int, List[int]], - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, padding_mode: str = "constant", ) -> torch.Tensor: return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) def pad( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, padding: Union[int, List[int]], - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, padding_mode: str = "constant", -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.pad(padding, fill=fill, padding_mode=padding_mode) elif isinstance(inpt, PIL.Image.Image): return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
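Illustrative sketch (not part of the patch): the bounding-box kernels in the hunks above (affine_bounding_box, rotate_bounding_box, pad_bounding_box) share one round trip: convert the incoming boxes to BoundingBoxFormat.XYXY, apply the geometric op, then convert back to the caller's format. A hedged sketch of that conversion step in isolation, on the plain-tensor form the kernels operate on; the box values are arbitrary.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# one box in XYWH as a plain tensor (the kernels unwrap datapoints before this step)
xywh = torch.tensor([[10.0, 10.0, 20.0, 20.0]])

xyxy = F.convert_format_bounding_box(
    xywh,
    old_format=datapoints.BoundingBoxFormat.XYWH,
    new_format=datapoints.BoundingBoxFormat.XYXY,
)
# ... geometric op runs on the XYXY representation ...
back = F.convert_format_bounding_box(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.XYWH,
)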
) @@ -1147,7 +1159,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid def crop_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, top: int, left: int, height: int, @@ -1155,7 +1167,7 @@ def crop_bounding_box( ) -> Tuple[torch.Tensor, Tuple[int, int]]: # Crop or implicit pad if left and/or top have negative values: - if format == features.BoundingBoxFormat.XYXY: + if format == datapoints.BoundingBoxFormat.XYXY: sub = [left, top, left, top] else: sub = [left, top, 0, 0] @@ -1184,16 +1196,18 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int return crop_image_tensor(video, top, left, height, width) -def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return crop_image_tensor(inpt, top, left, height, width) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.crop(top, left, height, width) elif isinstance(inpt, PIL.Image.Image): return crop_image_pil(inpt, top, left, height, width) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -1250,7 +1264,7 @@ def perspective_image_tensor( startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> torch.Tensor: perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) @@ -1299,7 +1313,7 @@ def perspective_image_pil( startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], interpolation: InterpolationMode = InterpolationMode.BICUBIC, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> PIL.Image.Image: perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) @@ -1308,7 +1322,7 @@ def perspective_image_pil( def perspective_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], coefficients: Optional[List[float]] = None, @@ -1320,7 +1334,7 @@ def perspective_bounding_box( original_shape = bounding_box.shape bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY) + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 @@ -1390,7 +1404,7 @@ def perspective_bounding_box( # out_bboxes should be of shape [N boxes, 4] return convert_format_bounding_box( - out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True + 
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) @@ -1398,7 +1412,7 @@ def perspective_mask( mask: torch.Tensor, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -1422,7 +1436,7 @@ def perspective_video( startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> torch.Tensor: return perspective_image_tensor( @@ -1431,18 +1445,20 @@ def perspective_video( def perspective( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, coefficients: Optional[List[float]] = None, -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return perspective_image_tensor( inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients ) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.perspective( startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients ) @@ -1452,7 +1468,7 @@ def perspective( ) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -1461,7 +1477,7 @@ def elastic_image_tensor( image: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: if image.numel() == 0: return image @@ -1498,7 +1514,7 @@ def elastic_image_pil( image: PIL.Image.Image, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> PIL.Image.Image: t_img = pil_to_tensor(image) output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) @@ -1519,7 +1535,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch. 
def elastic_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, displacement: torch.Tensor, ) -> torch.Tensor: if bounding_box.numel() == 0: @@ -1530,7 +1546,7 @@ def elastic_bounding_box( original_shape = bounding_box.shape bounding_box = ( - convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY) + convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ).reshape(-1, 4) # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it @@ -1558,14 +1574,14 @@ def elastic_bounding_box( out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype) return convert_format_bounding_box( - out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True + out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True ).reshape(original_shape) def elastic_mask( mask: torch.Tensor, displacement: torch.Tensor, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -1585,26 +1601,28 @@ def elastic_video( video: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, + fill: datapoints.FillTypeJIT = None, ) -> torch.Tensor: return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) def elastic( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: features.FillTypeJIT = None, -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): + fill: datapoints.FillTypeJIT = None, +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.elastic(displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, PIL.Image.Image): return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -1677,7 +1695,7 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL def center_crop_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int], output_size: List[int], ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -1705,16 +1723,18 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens return center_crop_image_tensor(video, output_size) -def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return center_crop_image_tensor(inpt, output_size) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.center_crop(output_size) elif isinstance(inpt, PIL.Image.Image): return center_crop_image_pil(inpt, output_size) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) @@ -1749,7 +1769,7 @@ def resized_crop_image_pil( def resized_crop_bounding_box( bounding_box: torch.Tensor, - format: features.BoundingBoxFormat, + format: datapoints.BoundingBoxFormat, top: int, left: int, height: int, @@ -1788,7 +1808,7 @@ def resized_crop_video( def resized_crop( - inpt: features.InputTypeJIT, + inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, @@ -1796,18 +1816,20 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, -) -> features.InputTypeJIT: - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): +) -> datapoints.InputTypeJIT: + if isinstance(inpt, torch.Tensor) and ( + torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) + ): return resized_crop_image_tensor( inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation ) - elif isinstance(inpt, features._Feature): + elif isinstance(inpt, datapoints._datapoint.Datapoint): return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) elif isinstance(inpt, PIL.Image.Image): return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) else: raise TypeError( - f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
) @@ -1869,28 +1891,29 @@ def five_crop_video( return five_crop_image_tensor(video, size) -ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT] +ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] def five_crop( inpt: ImageOrVideoTypeJIT, size: List[int] ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: - # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop` + # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with + # `ten_crop` if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)) + torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) ): return five_crop_image_tensor(inpt, size) - elif isinstance(inpt, features.Image): + elif isinstance(inpt, datapoints.Image): output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size) - return tuple(features.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value] - elif isinstance(inpt, features.Video): + return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value] + elif isinstance(inpt, datapoints.Video): output = five_crop_video(inpt.as_subclass(torch.Tensor), size) - return tuple(features.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value] + return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value] elif isinstance(inpt, PIL.Image.Image): return five_crop_image_pil(inpt, size) else: raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, " + f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " f"but got {type(inpt)} instead." 
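Illustrative sketch (not part of the patch): per the five_crop hunk above, an Image datapoint is unwrapped, cropped by the tensor kernel, and each of the five crops is re-wrapped with Image.wrap_like, so the outputs stay Image datapoints. A hedged usage sketch, again assuming the Image constructor accepts a bare tensor.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 64, 64))
crops = F.five_crop(img, size=[32, 32])  # four corner crops plus the center crop
# each crop is expected to still be a datapoints.Image thanks to wrap_like
print(len(crops), all(isinstance(c, datapoints.Image) for c in crops))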
@@ -1927,22 +1950,22 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F


 def ten_crop(
-    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
+    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
+) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
     if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
         return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-    elif isinstance(inpt, features.Image):
+    elif isinstance(inpt, datapoints.Image):
         output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return [features.Image.wrap_like(inpt, item) for item in output]
-    elif isinstance(inpt, features.Video):
+        return [datapoints.Image.wrap_like(inpt, item) for item in output]
+    elif isinstance(inpt, datapoints.Video):
         output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return [features.Video.wrap_like(inpt, item) for item in output]
+        return [datapoints.Video.wrap_like(inpt, item) for item in output]
     elif isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 4605b433b28..a6b9c773891 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -2,8 +2,8 @@
 import PIL.Image
 import torch
-from torchvision.prototype import features
-from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
+from torchvision.prototype import datapoints
+from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
@@ -23,12 +23,12 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 get_dimensions_image_pil = _FP.get_dimensions


-def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
+def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
     if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
         return get_dimensions_image_tensor(inpt)
-    elif isinstance(inpt, (features.Image, features.Video)):
+    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
         channels = inpt.num_channels
         height, width = inpt.spatial_size
         return [channels, height, width]
@@ -36,7 +36,7 @@ def get_dimensions(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) ->
         return get_dimensions_image_pil(inpt)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
@@ -59,18 +59,18 @@ def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image_tensor(video)


-def get_num_channels(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
+def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
     if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
         return get_num_channels_image_tensor(inpt)
-    elif isinstance(inpt, (features.Image, features.Video)):
+    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
         return inpt.num_channels
     elif isinstance(inpt, PIL.Image.Image):
         return get_num_channels_image_pil(inpt)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
@@ -104,20 +104,22 @@ def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:


 @torch.jit.unused
-def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]:
+def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
     return list(bounding_box.spatial_size)


-def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return get_spatial_size_image_tensor(inpt)
-    elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
+    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
         return list(inpt.spatial_size)
     elif isinstance(inpt, PIL.Image.Image):
         return get_spatial_size_image_pil(inpt)  # type: ignore[no-any-return]
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
@@ -126,15 +128,13 @@ def get_num_frames_video(video: torch.Tensor) -> int:
     return video.shape[-4]


-def get_num_frames(inpt: features.VideoTypeJIT) -> int:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
+def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return get_num_frames_video(inpt)
-    elif isinstance(inpt, features.Video):
+    elif isinstance(inpt, datapoints.Video):
         return inpt.num_frames
     else:
-        raise TypeError(
-            f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
-        )
+        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")


 def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
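`get_dimensions`, `get_num_channels`, `get_spatial_size`, and `get_num_frames` read metadata straight off a datapoint (`num_channels`, `spatial_size`, `num_frames`) instead of inspecting tensor shapes. A small sketch using the same `torchvision.prototype.transforms.functional` import path that `utils.py` relies on later in this diff:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size

    image = datapoints.Image(torch.rand(3, 480, 640))
    print(get_dimensions(image))    # [3, 480, 640], taken from image.num_channels / image.spatial_size
    print(get_spatial_size(image))  # [480, 640]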
@@ -202,7 +202,7 @@ def clamp_bounding_box(
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
+        bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
@@ -309,12 +309,12 @@ def convert_color_space_video(


 def convert_color_space(
-    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
+    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
-) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
+) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
         if old_color_space is None:
             raise RuntimeError(
@@ -322,21 +322,21 @@ def convert_color_space(
                 "the `old_color_space=...` parameter needs to be passed."
             )
         return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
-    elif isinstance(inpt, features.Image):
+    elif isinstance(inpt, datapoints.Image):
         output = convert_color_space_image_tensor(
             inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
         )
-        return features.Image.wrap_like(inpt, output, color_space=color_space)
-    elif isinstance(inpt, features.Video):
+        return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
+    elif isinstance(inpt, datapoints.Video):
         output = convert_color_space_video(
             inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
         )
-        return features.Video.wrap_like(inpt, output, color_space=color_space)
+        return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
     elif isinstance(inpt, PIL.Image.Image):
         return convert_color_space_image_pil(inpt, color_space=color_space)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` tensor subclass, or a PIL image, "
+            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
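`convert_color_space` only needs `old_color_space` for plain tensors; an `Image` or `Video` datapoint already carries its `color_space`, and the result is re-wrapped with that metadata updated. A hedged sketch, assuming `convert_color_space` is re-exported from `torchvision.prototype.transforms.functional` and that the `ColorSpace.RGB`/`ColorSpace.GRAY` members are available as in the rest of the prototype API:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.datapoints import ColorSpace
    from torchvision.prototype.transforms.functional import convert_color_space  # assumed re-export

    plain = torch.rand(3, 64, 64)
    gray_plain = convert_color_space(plain, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)

    image = datapoints.Image(torch.rand(3, 64, 64), color_space=ColorSpace.RGB)
    gray_image = convert_color_space(image, color_space=ColorSpace.GRAY)  # old_color_space read from the datapoint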
@@ -415,20 +415,19 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -


 def convert_dtype(
-    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
+    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
 ) -> torch.Tensor:
     if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
         return convert_dtype_image_tensor(inpt, dtype)
-    elif isinstance(inpt, features.Image):
+    elif isinstance(inpt, datapoints.Image):
         output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
-        return features.Image.wrap_like(inpt, output)
-    elif isinstance(inpt, features.Video):
+        return datapoints.Image.wrap_like(inpt, output)
+    elif isinstance(inpt, datapoints.Video):
         output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
-        return features.Video.wrap_like(inpt, output)
+        return datapoints.Video.wrap_like(inpt, output)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
-            f"but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
         )
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 575e5c76c85..7799187373f 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -4,9 +4,12 @@
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d, pad as torch_pad
-from torchvision.prototype import features
+
+from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+from ..utils import is_simple_tensor
+

 def normalize_image_tensor(
     image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
@@ -48,17 +51,17 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in


 def normalize(
-    inpt: Union[features.TensorImageTypeJIT, features.TensorVideoTypeJIT],
+    inpt: Union[datapoints.TensorImageTypeJIT, datapoints.TensorVideoTypeJIT],
     mean: List[float],
     std: List[float],
     inplace: bool = False,
 ) -> torch.Tensor:
     if not torch.jit.is_scripting():
-        if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video)):
+        if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
             inpt = inpt.as_subclass(torch.Tensor)
         else:
             raise TypeError(
-                f"Input can either be a plain tensor or an `Image` or `Video` tensor subclass, "
+                f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
                 f"but got {type(inpt)} instead."
             )
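`normalize` now uses the shared `is_simple_tensor` helper (added to `torchvision/prototype/transforms/utils.py` later in this diff) instead of `features.is_simple_tensor`, and it intentionally returns a plain tensor: datapoint inputs are unwrapped via `as_subclass` before the kernel runs. A brief sketch, with the `F` import path assumed as above:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F  # assumed re-export path

    image = datapoints.Image(torch.rand(3, 32, 32))
    out = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    print(type(out))  # torch.Tensor, not datapoints.Image: normalize deliberately drops the wrapper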
@@ -163,16 +166,18 @@ def gaussian_blur_video(


 def gaussian_blur(
-    inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
-) -> features.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+    inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
+) -> datapoints.InputTypeJIT:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
+    ):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, features._Feature):
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, PIL.Image.Image):
         return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, one of the tensor subclasses TorchVision provides, or a PIL image, "
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
diff --git a/torchvision/prototype/transforms/functional/_temporal.py b/torchvision/prototype/transforms/functional/_temporal.py
index 15d9918ae9f..63b3baf942e 100644
--- a/torchvision/prototype/transforms/functional/_temporal.py
+++ b/torchvision/prototype/transforms/functional/_temporal.py
@@ -1,6 +1,6 @@
 import torch

-from torchvision.prototype import features
+from torchvision.prototype import datapoints


 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
@@ -11,18 +11,16 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp


 def uniform_temporal_subsample(
-    inpt: features.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
-) -> features.VideoTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
+    inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
+) -> datapoints.VideoTypeJIT:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
-    elif isinstance(inpt, features.Video):
+    elif isinstance(inpt, datapoints.Video):
         if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
             raise ValueError("Video inputs must have temporal_dim equivalent to -4")
         output = uniform_temporal_subsample_video(
             inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
         )
-        return features.Video.wrap_like(inpt, output)
+        return datapoints.Video.wrap_like(inpt, output)
     else:
-        raise TypeError(
-            f"Input can either be a plain tensor or a `Video` tensor subclass, but got {type(inpt)} instead."
-        )
+        raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
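`uniform_temporal_subsample` keeps the same contract: a plain tensor goes straight to the video kernel, while a `Video` datapoint is unwrapped, subsampled along dim -4, and re-wrapped with `wrap_like`. A sketch under the same import-path assumption; constructing a `Video` directly from a tensor is assumed to mirror `datapoints.Image(...)`:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.functional import uniform_temporal_subsample  # assumed re-export

    video = datapoints.Video(torch.rand(16, 3, 112, 112))  # (T, C, H, W), temporal dim is -4
    clip = uniform_temporal_subsample(video, num_samples=4)
    print(type(clip), clip.shape)  # expected: datapoints.Video, torch.Size([4, 3, 112, 112])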
diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py
index ff2a8bdf4d1..286aa7485da 100644
--- a/torchvision/prototype/transforms/functional/_type_conversion.py
+++ b/torchvision/prototype/transforms/functional/_type_conversion.py
@@ -3,12 +3,12 @@
 import numpy as np
 import PIL.Image
 import torch
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
 from torchvision.transforms import functional as _F


 @torch.jit.unused
-def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
+def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
     if isinstance(inpt, np.ndarray):
         output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
     elif isinstance(inpt, PIL.Image.Image):
@@ -17,7 +17,7 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> f
         output = inpt
     else:
         raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
-    return features.Image(output)
+    return datapoints.Image(output)


 to_image_pil = _F.to_pil_image
diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py
index 73ab3466154..9ab2ed2602b 100644
--- a/torchvision/prototype/transforms/utils.py
+++ b/torchvision/prototype/transforms/utils.py
@@ -1,14 +1,22 @@
+from __future__ import annotations
+
 from typing import Any, Callable, List, Tuple, Type, Union

 import PIL.Image
+import torch
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
+from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size


-def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
+def is_simple_tensor(inpt: Any) -> bool:
+    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
+
+
+def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
+    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBox)]
     if not bounding_boxes:
         raise TypeError("No bounding box was found in the sample")
     elif len(bounding_boxes) > 1:
@@ -20,7 +28,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(inpt)
+        if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt)
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -34,8 +42,10 @@ def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     sizes = {
         tuple(get_spatial_size(inpt))
         for inpt in flat_inputs
-        if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
-        or features.is_simple_tensor(inpt)
+        if isinstance(
+            inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBox)
+        )
+        or is_simple_tensor(inpt)
     }
     if not sizes:
         raise TypeError("No image, video, mask or bounding box was found in the sample")
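The new `is_simple_tensor` helper is the piece the rest of the diff switches to: a "simple" tensor is any `torch.Tensor` that is not a `Datapoint` subclass, which is exactly the case the dispatchers route to the plain-tensor kernels. For example:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms.utils import is_simple_tensor

    print(is_simple_tensor(torch.rand(3, 8, 8)))                    # True
    print(is_simple_tensor(datapoints.Image(torch.rand(3, 8, 8))))  # False: Image is a Datapoint subclass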