diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py
index 5f0daa4bee0..3d34383319f 100644
--- a/test/prototype_common_utils.py
+++ b/test/prototype_common_utils.py
@@ -640,14 +640,14 @@ def __init__(
         self.condition = condition or (lambda args_kwargs: True)
 
 
-def mark_framework_limitation(test_id, reason):
+def mark_framework_limitation(test_id, reason, condition=None):
     # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
     # framework cannot handle the kernel in general or a specific parameter combination.
     # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
     # still justified.
     # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
     # we are wasting CI resources for no reason for most of the time
-    return TestMark(test_id, pytest.mark.skip(reason=reason))
+    return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)
 
 
 class InfoBase:
diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 4c0af6703ec..ce80658ce8b 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -12,6 +12,7 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -25,6 +26,7 @@
     make_video_loader,
     make_video_loaders,
     mark_framework_limitation,
+    TensorLoader,
     TestMark,
 )
 from torch.utils._pytree import tree_map
@@ -2010,8 +2012,15 @@ def sample_inputs_adjust_saturation_video():
 
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
+        yield ArgsKwargs(bounding_box_loader)
+
+        simple_tensor_loader = TensorLoader(
+            # Bind `bounding_box_loader.fn` eagerly through a default argument; a plain closure over the
+            # loop variable would make every sample use the *last* loader's fn when loaded later.
+            fn=lambda shape, dtype, device, fn=bounding_box_loader.fn: fn(shape, dtype, device).as_subclass(torch.Tensor),
+            shape=bounding_box_loader.shape,
+            dtype=bounding_box_loader.dtype,
+        )
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
+            simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )
 
 
@@ -2020,6 +2029,19 @@
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
+        test_marks=[
+            mark_framework_limitation(
+                ("TestKernels", "test_scripted_vs_eager"),
+                reason=(
+                    "The function is a hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
+                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
+                    "`spatial_size` was passed"
+                ),
+                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
+                and arg_kwargs.kwargs.get("format") is None
+                and arg_kwargs.kwargs.get("spatial_size") is None,
+            )
+        ],
     )
 )
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 649620eda62..948143771ab 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -155,12 +155,14 @@ def _unbatch(self, batch, *, data_dims):
         if batched_tensor.ndim == data_dims:
             return batch
 
-        return [
-            self._unbatch(unbatched, data_dims=data_dims)
-            for unbatched in (
-                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-            )
-        ]
+        unbatcheds = []
+        for unbatched in (
+            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+        ):
+            if isinstance(batch, datapoints._datapoint.Datapoint):
+                unbatched = type(batch).wrap_like(batch, unbatched)
+            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
+        return unbatcheds
 
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -558,6 +560,36 @@ def assert_samples_from_standard_normal(t):
         assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
 
 
+class TestClampBoundingBox:
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(),
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+        ],
+    )
+    def test_simple_tensor_insufficient_metadata(self, metadata):
+        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+
+        with pytest.raises(ValueError, match="simple tensor"):
+            F.clamp_bounding_box(simple_tensor, **metadata)
+
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+            dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)),
+        ],
+    )
+    def test_datapoint_explicit_metadata(self, metadata):
+        datapoint = next(make_bounding_boxes())
+
+        with pytest.raises(ValueError, match="bounding box datapoint"):
+            F.clamp_bounding_box(datapoint, **metadata)
+
+
 # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
 # `prototype_transforms_kernel_infos.py`
 
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 1cef6eeb8f2..946c00b0ee6 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
     _transformed_types = (datapoints.BoundingBox,)
 
     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
-        # We need to unwrap here to avoid unnecessary `__torch_function__` calls,
-        # since `clamp_bounding_box` does not have a dispatcher function that would do that for us
-        output = F.clamp_bounding_box(
-            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
-        )
-        return datapoints.BoundingBox.wrap_like(inpt, output)
+        return F.clamp_bounding_box(inpt)  # type: ignore[return-value]
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 31d86bec256..2c5180a8644 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -209,12 +209,9 @@ def convert_format_bounding_box(
     return bounding_box
 
 
-def clamp_bounding_box(
+def _clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(clamp_bounding_box)
-
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
@@ -225,6 +222,29 @@ def clamp_bounding_box(
     return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
 
 
+def clamp_bounding_box(
+    inpt: datapoints.InputTypeJIT,
+    format: Optional[BoundingBoxFormat] = None,
+    spatial_size: Optional[Tuple[int, int]] = None,
+) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(clamp_bounding_box)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        if format is None or spatial_size is None:
+            raise ValueError("For simple tensor inputs, `format` and `spatial_size` have to be passed.")
+        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
+    elif isinstance(inpt, datapoints.BoundingBox):
+        if format is not None or spatial_size is not None:
+            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
+        output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
+        return datapoints.BoundingBox.wrap_like(inpt, output)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+        )
+
+
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
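
Usage sketch (not part of the patch): a minimal example of the hybrid kernel / dispatcher semantics that this diff gives `clamp_bounding_box`. It assumes the prototype namespaces as of this change (`torchvision.prototype.datapoints` and `torchvision.prototype.transforms.functional`), which may move before release; the variable names below are illustrative only.

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

# A single XYXY box that sticks out of a 10x10 image on both sides.
data = torch.tensor([[-1.0, 2.0, 15.0, 7.0]])

# Datapoint input: `format` and `spatial_size` are read off the datapoint,
# so passing either of them explicitly now raises a ValueError.
box = datapoints.BoundingBox(
    data, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)
)
clamped_box = F.clamp_bounding_box(box)
# -> BoundingBox([[0., 2., 10., 7.]]), metadata preserved via `wrap_like`

# Simple tensor input: there is no metadata to read, so both `format` and
# `spatial_size` are required; omitting either raises a ValueError.
clamped_tensor = F.clamp_bounding_box(
    data, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)
)
# -> plain torch.Tensor([[0., 2., 10., 7.]])

Making the metadata mutually exclusive with the datapoint input (rather than merely optional) is what lets the eager path detect ambiguous calls, while the `torch.jit.is_scripting()` branch keeps the tensor path scriptable; this is also why the scripted-vs-eager test is skipped for datapoint samples above.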