From 7c5ab88ce5e39bb0134f80c5b8a4fa6ef4e5e28b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 18:47:49 +0000 Subject: [PATCH 01/14] Add SanitizeBoundingBoxes transform --- torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_misc.py | 92 ++++++++++++++++---- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 132edb1b6fc..b189177c9a9 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -49,7 +49,7 @@ LinearTransformation, Normalize, PermuteDimensions, - RemoveSmallBoundingBoxes, + SanitizeBoundingBoxes, ToDtype, TransposeDimensions, ) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index b398227b480..fa86db26836 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,12 +1,13 @@ import warnings +from contextlib import suppress from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image import torch +from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import transforms as _transforms -from torchvision.ops import remove_small_boxes from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform @@ -225,28 +226,83 @@ def _transform( return inpt.transpose(*dims) -class RemoveSmallBoundingBoxes(Transform): - _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel) +class SanitizeBoundingBoxes(Transform): + # This removes boxes and their corresponding labels: + # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) + # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) + _transformed_types = (datapoints.BoundingBox, datapoints.Mas) - def __init__(self, min_size: float = 1.0) -> None: + def __init__(self, min_size: float = 1.0, labels="default") -> None: super().__init__() self.min_size = min_size + self.labels = labels + + def _find_label_default_heuristic(self, inputs): + # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive + # Returns None if nothing is found + candidate_key = None + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if key.lower() == "label") + if candidate_key is None: + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) + labels = inputs.get(candidate_key) + return labels + + def forward(self, *inputs: Any) -> Any: + inputs = inputs if len(inputs) > 1 else inputs[0] + if isinstance(labels, str) and not isinstance(inputs, dict): + raise ValueError( + f"If labels is a str or 'default' (got {labels}), then the input to forward() must be a dict. " + f"Got {type(inputs)} instead" + ) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - bounding_box = query_bounding_box(flat_inputs) - - # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to - # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH - # format,we need to convert first just to afterwards compute the width and height again, although they were - # there in the first place for these formats. 
- bounding_box = F.convert_format_bounding_box( - bounding_box.as_subclass(torch.Tensor), - old_format=bounding_box.format, + labels = None + if self.labels == "default": + labels = self._find_label_default_heuristic(inputs) + elif callable(self.labels): + labels = self.labels(inputs) + elif isinstance(self.labels, str): + labels = inputs[self.labels] + else: + raise ValueError( + "labels parameter should either be a str, callable, or 'default'. " + f"Got {labels} of type {type(labels)}." + ) + + flat_inputs, spec = tree_flatten(inputs) + # TODO: this enforces one single BoundingBox entry. + # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... + # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? + boxes = query_bounding_box(flat_inputs) + + boxes = F.convert_format_bounding_box( + boxes, new_format=datapoints.BoundingBoxFormat.XYXY, ) - valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size) - - return dict(valid_indices=valid_indices) + ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] + keep = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=1) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? + # TODO: Also... should this is <= instead of < ??? + image_h, image_w = boxes.spatial_size + keep &= (boxes[:, 0] < image_w).all() & (boxes[:, 2] < image_w).all() + keep &= (boxes[:, 1] < image_h).all() & (boxes[:, 3] < image_h).all() + valid_indices = torch.where(keep)[0] + + params = dict(valid_indices=valid_indices, labels=labels) + flat_outputs = [ + # Even-though it may look like we're transforming all inputs, we don't: + # _transform() will only care about BoundingBoxes and the labels + self._transform(inpt, params) + for inpt in flat_inputs + ] + + return tree_unflatten(flat_outputs, spec) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return inpt.wrap_like(inpt, inpt[params["valid_indices"]]) + + if inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox): + inpt = inpt[params["valid_indices"]] + + return inpt From 26929c053173b65c634156f0f66b23994aa6194f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 19:29:07 +0000 Subject: [PATCH 02/14] Added basic test, will improve upon --- test/test_prototype_transforms.py | 29 +++++++++++++++++++++++ torchvision/prototype/transforms/_misc.py | 6 ++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 7030d2d1b2e..6fd1d3d0cbb 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2355,3 +2355,32 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): out["label"] = torch.tensor(out["label"]) assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes + + +def test_sanitize_bounding_boxes(): + + H, W = 256, 256 + boxes = datapoints.BoundingBox( + [ + [1, 1, 30, 20], + [0, 1, 10, 1], + [0, 0, 10, 10], + [1, 1, 30, 20], + [0, 1, 0, 20], + ], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=(H, W), + ) + sample = { + "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8), + "labels": torch.arange(boxes.shape[0]), + "boxes": boxes, + "whatever": torch.rand(10), + } + + out = transforms.SanitizeBoundingBoxes()(sample) + + assert out["image"] is sample["image"] + assert out["whatever"] 
is sample["whatever"] + assert out["boxes"].shape[0] == out["labels"].shape[0] + assert out["labels"].tolist() == [0, 2, 3] \ No newline at end of file diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index fa86db26836..80899f03c15 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -230,7 +230,7 @@ class SanitizeBoundingBoxes(Transform): # This removes boxes and their corresponding labels: # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) - _transformed_types = (datapoints.BoundingBox, datapoints.Mas) + _transformed_types = (datapoints.BoundingBox) def __init__(self, min_size: float = 1.0, labels="default") -> None: super().__init__() @@ -240,6 +240,7 @@ def __init__(self, min_size: float = 1.0, labels="default") -> None: def _find_label_default_heuristic(self, inputs): # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found + labels = None candidate_key = None with suppress(StopIteration): candidate_key = next(key for key in inputs.keys() if key.lower() == "label") @@ -251,13 +252,12 @@ def _find_label_default_heuristic(self, inputs): def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] - if isinstance(labels, str) and not isinstance(inputs, dict): + if isinstance(self.labels, str) and not isinstance(inputs, dict): raise ValueError( f"If labels is a str or 'default' (got {labels}), then the input to forward() must be a dict. " f"Got {type(inputs)} instead" ) - labels = None if self.labels == "default": labels = self._find_label_default_heuristic(inputs) elif callable(self.labels): From 9ae43b210465810e0c1d51b787ab86f0664e3923 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 20:51:52 +0000 Subject: [PATCH 03/14] Address comments + add tests --- test/test_prototype_transforms.py | 53 ++++++++++++++++++----- torchvision/prototype/transforms/_misc.py | 52 ++++++++++++++-------- 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 6fd1d3d0cbb..afe66fac485 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,5 +1,6 @@ import itertools import pathlib +import random import re import warnings from collections import defaultdict @@ -2357,20 +2358,39 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes -def test_sanitize_bounding_boxes(): +@pytest.mark.parametrize("min_size", (1, 10)) +def test_sanitize_bounding_boxes(min_size): + + H, W = 256, 128 + + boxes_and_validity = [ + ([0, 1, 10, 1], False), # Y1 == Y2 + ([0, 1, 0, 20], False), # X1 == X2 + ([0, 0, min_size - 1, 10], False), # H < min_size + ([0, 0, 10, min_size - 1], False), # W < min_size + ([0, 0, 10, H + 1], False), # Y2 > H + ([0, 0, W + 1, 10], False), # X2 > W + ([-1, 1, 10, 20], False), # any < 0 + ([0, 0, -1, 20], False), # any < 0 + ([0, 0, -10, -1], False), # any < 0 + ([0, 0, min_size, 10], True), # H < min_size + ([0, 0, 10, min_size], True), # W < min_size + ([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1? 
+ ([1, 1, 30, 20], True), + ([0, 0, 10, 10], True), + ([1, 1, 30, 20], True), + ] + + random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases + boxes, is_valid_mask = zip(*boxes_and_validity) + valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid] - H, W = 256, 256 boxes = datapoints.BoundingBox( - [ - [1, 1, 30, 20], - [0, 1, 10, 1], - [0, 0, 10, 10], - [1, 1, 30, 20], - [0, 1, 0, 20], - ], + boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(H, W), ) + sample = { "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8), "labels": torch.arange(boxes.shape[0]), @@ -2378,9 +2398,20 @@ def test_sanitize_bounding_boxes(): "whatever": torch.rand(10), } - out = transforms.SanitizeBoundingBoxes()(sample) + out = transforms.SanitizeBoundingBoxes(min_size=min_size)(sample) assert out["image"] is sample["image"] assert out["whatever"] is sample["whatever"] + assert type(out["labels"]) is type(sample["labels"]) + + out["labels"] = torch.tensor(out["labels"]) assert out["boxes"].shape[0] == out["labels"].shape[0] - assert out["labels"].tolist() == [0, 2, 3] \ No newline at end of file + + # This works because we conveniently set labels to arange(num_boxes) + assert out["labels"].tolist() == valid_indices + + +# def test_sanitize_bounding_boxes_errors(): + +# with pytest.raises(ValueError, match=) +# out = transforms.SanitizeBoundingBoxes(min_size=min_size)(sample) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 80899f03c15..ba8f9042eb8 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -230,12 +230,27 @@ class SanitizeBoundingBoxes(Transform): # This removes boxes and their corresponding labels: # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) - _transformed_types = (datapoints.BoundingBox) + _transformed_types = datapoints.BoundingBox def __init__(self, min_size: float = 1.0, labels="default") -> None: super().__init__() + + if min_size < 1: + raise ValueError(f"min_size must be >= 1, got {min_size}.") self.min_size = min_size + self.labels = labels + if labels == "default": + self._get_labels = self._find_label_default_heuristic + elif callable(self.labels): + self._get_labels = labels + elif isinstance(self.labels, str): + self._get_labels = lambda inputs: inputs[labels] + else: + raise ValueError( + "labels parameter should either be a str, callable, or 'default'. " + f"Got {labels} of type {type(labels)}." + ) def _find_label_default_heuristic(self, inputs): # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive @@ -243,32 +258,29 @@ def _find_label_default_heuristic(self, inputs): labels = None candidate_key = None with suppress(StopIteration): - candidate_key = next(key for key in inputs.keys() if key.lower() == "label") + candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") if candidate_key is None: with suppress(StopIteration): candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) labels = inputs.get(candidate_key) + + if labels is None: + raise ValueError( + "Could not infer where the labels are in the sample. Try passing a callable as the label parameter?" 
+ ) return labels def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] + if isinstance(self.labels, str) and not isinstance(inputs, dict): raise ValueError( - f"If labels is a str or 'default' (got {labels}), then the input to forward() must be a dict. " + f"If labels is a str or 'default' (got {self.labels}), then the input to forward() must be a dict. " f"Got {type(inputs)} instead" ) - - if self.labels == "default": - labels = self._find_label_default_heuristic(inputs) - elif callable(self.labels): - labels = self.labels(inputs) - elif isinstance(self.labels, str): - labels = inputs[self.labels] - else: - raise ValueError( - "labels parameter should either be a str, callable, or 'default'. " - f"Got {labels} of type {type(labels)}." - ) + labels = self._get_labels(inputs) + if not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") flat_inputs, spec = tree_flatten(inputs) # TODO: this enforces one single BoundingBox entry. @@ -276,6 +288,11 @@ def forward(self, *inputs: Any) -> Any: # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? boxes = query_bounding_box(flat_inputs) + if boxes.shape[-2] != labels.shape[0]: + raise ValueError( + f"Number of boxes ({boxes.shape[-2]}) and number of labels ({labels.shape[0]}) do not match." + ) + boxes = F.convert_format_bounding_box( boxes, new_format=datapoints.BoundingBoxFormat.XYXY, @@ -284,10 +301,9 @@ def forward(self, *inputs: Any) -> Any: keep = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=1) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? - # TODO: Also... should this is <= instead of < ??? 
image_h, image_w = boxes.spatial_size - keep &= (boxes[:, 0] < image_w).all() & (boxes[:, 2] < image_w).all() - keep &= (boxes[:, 1] < image_h).all() & (boxes[:, 3] < image_h).all() + keep &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) + keep &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) valid_indices = torch.where(keep)[0] params = dict(valid_indices=valid_indices, labels=labels) From 2ac342adeb00ad5a9204abc74e2488964c747642 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 20:58:05 +0000 Subject: [PATCH 04/14] A little more --- test/test_prototype_transforms.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index afe66fac485..abc231aa43a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2359,7 +2359,8 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize("min_size", (1, 10)) -def test_sanitize_bounding_boxes(min_size): +@pytest.mark.parametrize("labels_param", ("default", "labels", lambda inputs: inputs["labels"])) +def test_sanitize_bounding_boxes(min_size, labels_param): H, W = 256, 128 @@ -2398,7 +2399,7 @@ def test_sanitize_bounding_boxes(min_size): "whatever": torch.rand(10), } - out = transforms.SanitizeBoundingBoxes(min_size=min_size)(sample) + out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels=labels_param)(sample) assert out["image"] is sample["image"] assert out["whatever"] is sample["whatever"] @@ -2413,5 +2414,11 @@ def test_sanitize_bounding_boxes(min_size): # def test_sanitize_bounding_boxes_errors(): -# with pytest.raises(ValueError, match=) -# out = transforms.SanitizeBoundingBoxes(min_size=min_size)(sample) +# with pytest.raises(ValueError, match="min_size must be >= 1"): +# transforms.SanitizeBoundingBoxes(min_size=0) +# with pytest.raises(ValueError, match="labels parameter should either be a str"): +# transforms.SanitizeBoundingBoxes(labels=12) + +# with pytest.raises(ValueError, match="labels parameter should either be a str"): +# bad_labels = +# transforms.SanitizeBoundingBoxes() From 87a849ecde66e1886b6a1a3b3e1e2a5d19928941 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 21:09:35 +0000 Subject: [PATCH 05/14] Add more input check tests --- test/test_prototype_transforms.py | 41 ++++++++++++++++++----- torchvision/prototype/transforms/_misc.py | 3 +- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index abc231aa43a..0ef678ef1a7 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2412,13 +2412,38 @@ def test_sanitize_bounding_boxes(min_size, labels_param): assert out["labels"].tolist() == valid_indices -# def test_sanitize_bounding_boxes_errors(): +@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) +def test_sanitize_bounding_boxes_default_heuristic(key): + labels = torch.arange(10) + d = {key: labels} + assert transforms.SanitizeBoundingBoxes._find_label_default_heuristic(d) is labels -# with pytest.raises(ValueError, match="min_size must be >= 1"): -# transforms.SanitizeBoundingBoxes(min_size=0) -# with pytest.raises(ValueError, match="labels parameter should either be a str"): -# transforms.SanitizeBoundingBoxes(labels=12) -# with pytest.raises(ValueError, match="labels parameter should either be a str"): -# bad_labels = -# 
transforms.SanitizeBoundingBoxes() +def test_sanitize_bounding_boxes_errors(): + + good_bbox = datapoints.BoundingBox( + [[0, 0, 10, 10]], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=(20, 20), + ) + + with pytest.raises(ValueError, match="min_size must be >= 1"): + transforms.SanitizeBoundingBoxes(min_size=0) + with pytest.raises(ValueError, match="labels parameter should either be a str"): + transforms.SanitizeBoundingBoxes(labels=12) + + with pytest.raises(ValueError, match="Could not infer where the labels are"): + bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} + transforms.SanitizeBoundingBoxes()(bad_labels_key) + + with pytest.raises(ValueError, match="If labels is a str or 'default'"): + not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0])) + transforms.SanitizeBoundingBoxes()(not_a_dict) + + with pytest.raises(ValueError, match="must be a tensor"): + not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()} + transforms.SanitizeBoundingBoxes()(not_a_tensor) + + with pytest.raises(ValueError, match="Number of boxes"): + different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} + transforms.SanitizeBoundingBoxes()(different_sizes) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index ba8f9042eb8..7aac49d7959 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -252,7 +252,8 @@ def __init__(self, min_size: float = 1.0, labels="default") -> None: f"Got {labels} of type {type(labels)}." ) - def _find_label_default_heuristic(self, inputs): + @staticmethod + def _find_label_default_heuristic(inputs): # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found labels = None From 7839dd8afb6eb358ff6acc521bedd95093e137b2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 09:43:11 +0000 Subject: [PATCH 06/14] Add support for batch dimension - only 1 batch tho --- test/test_prototype_transforms.py | 31 +++++++++++++++++----- torchvision/prototype/transforms/_misc.py | 32 ++++++++++++++--------- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 0ef678ef1a7..d4e5a914071 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2360,8 +2360,8 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize("min_size", (1, 10)) @pytest.mark.parametrize("labels_param", ("default", "labels", lambda inputs: inputs["labels"])) -def test_sanitize_bounding_boxes(min_size, labels_param): - +@pytest.mark.parametrize("batch", (True, False)) +def test_sanitize_bounding_boxes(min_size, labels_param, batch): H, W = 256, 128 boxes_and_validity = [ @@ -2386,6 +2386,14 @@ def test_sanitize_bounding_boxes(min_size, labels_param): boxes, is_valid_mask = zip(*boxes_and_validity) valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid] + boxes = torch.tensor(boxes) + labels = torch.arange(boxes.shape[-2]) + + if batch: + boxes = boxes[None] + labels = labels[None] + valid_indices = [valid_indices] + boxes = datapoints.BoundingBox( boxes, format=datapoints.BoundingBoxFormat.XYXY, @@ -2394,7 +2402,7 @@ def test_sanitize_bounding_boxes(min_size, labels_param): sample = { "image": torch.randint(0, 256, size=(1, 3, H, W), 
dtype=torch.uint8), - "labels": torch.arange(boxes.shape[0]), + "labels": labels, "boxes": boxes, "whatever": torch.rand(10), } @@ -2403,10 +2411,9 @@ def test_sanitize_bounding_boxes(min_size, labels_param): assert out["image"] is sample["image"] assert out["whatever"] is sample["whatever"] - assert type(out["labels"]) is type(sample["labels"]) + assert isinstance(out["labels"], torch.Tensor) - out["labels"] = torch.tensor(out["labels"]) - assert out["boxes"].shape[0] == out["labels"].shape[0] + assert out["boxes"].shape[:-1] == out["labels"].shape # This works because we conveniently set labels to arange(num_boxes) assert out["labels"].tolist() == valid_indices @@ -2447,3 +2454,15 @@ def test_sanitize_bounding_boxes_errors(): with pytest.raises(ValueError, match="Number of boxes"): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} transforms.SanitizeBoundingBoxes()(different_sizes) + + with pytest.raises(ValueError, match="boxes must be of shape"): + bad_bbox = datapoints.BoundingBox( # batch with 2 elements + [ + [[0, 0, 10, 10]], + [[0, 0, 10, 10]], + ], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=(20, 20), + ) + different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} + transforms.SanitizeBoundingBoxes()(different_sizes) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 7aac49d7959..8b0eeec57c5 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,3 +1,4 @@ +import collections import warnings from contextlib import suppress from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union @@ -274,7 +275,7 @@ def _find_label_default_heuristic(inputs): def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] - if isinstance(self.labels, str) and not isinstance(inputs, dict): + if isinstance(self.labels, str) and not isinstance(inputs, collections.abc.Mapping): raise ValueError( f"If labels is a str or 'default' (got {self.labels}), then the input to forward() must be a dict. " f"Got {type(inputs)} instead" @@ -289,25 +290,27 @@ def forward(self, *inputs: Any) -> Any: # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? boxes = query_bounding_box(flat_inputs) - if boxes.shape[-2] != labels.shape[0]: + if boxes.ndim > 3 or (boxes.ndim == 3 and boxes.shape[0] != 1): + raise ValueError(f"boxes must be of shape (num_boxes, 4) or (1, num_boxes, 4), got {boxes.shape}") + + if boxes.shape[:-1] != labels.shape: raise ValueError( - f"Number of boxes ({boxes.shape[-2]}) and number of labels ({labels.shape[0]}) do not match." + f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) boxes = F.convert_format_bounding_box( boxes, new_format=datapoints.BoundingBoxFormat.XYXY, ) - ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] - keep = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=1) + ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1] + mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=-1) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? 
image_h, image_w = boxes.spatial_size - keep &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) - keep &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) - valid_indices = torch.where(keep)[0] + mask &= (boxes[..., 0] <= image_w) & (boxes[..., 2] <= image_w) + mask &= (boxes[..., 1] <= image_h) & (boxes[..., 3] <= image_h) - params = dict(valid_indices=valid_indices, labels=labels) + params = dict(mask=mask, labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: # _transform() will only care about BoundingBoxes and the labels @@ -319,7 +322,12 @@ def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox): - inpt = inpt[params["valid_indices"]] + if not (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): + return inpt + + out = inpt[params["mask"]] + if inpt.ndim != out.ndim: + # Add extra batch dim + out = out[None] - return inpt + return out From 57aaab38e03ff7aa3f159f4a2b15163f3269b067 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 09:47:24 +0000 Subject: [PATCH 07/14] renamed labels parameter into labels_getter --- test/test_prototype_transforms.py | 12 +++++----- torchvision/prototype/transforms/_misc.py | 28 +++++++++++------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d4e5a914071..c0ca50550ac 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2359,9 +2359,9 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize("min_size", (1, 10)) -@pytest.mark.parametrize("labels_param", ("default", "labels", lambda inputs: inputs["labels"])) +@pytest.mark.parametrize("labels_getter", ("default", "labels", lambda inputs: inputs["labels"])) @pytest.mark.parametrize("batch", (True, False)) -def test_sanitize_bounding_boxes(min_size, labels_param, batch): +def test_sanitize_bounding_boxes(min_size, labels_getter, batch): H, W = 256, 128 boxes_and_validity = [ @@ -2407,7 +2407,7 @@ def test_sanitize_bounding_boxes(min_size, labels_param, batch): "whatever": torch.rand(10), } - out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels=labels_param)(sample) + out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) assert out["image"] is sample["image"] assert out["whatever"] is sample["whatever"] @@ -2436,14 +2436,14 @@ def test_sanitize_bounding_boxes_errors(): with pytest.raises(ValueError, match="min_size must be >= 1"): transforms.SanitizeBoundingBoxes(min_size=0) - with pytest.raises(ValueError, match="labels parameter should either be a str"): - transforms.SanitizeBoundingBoxes(labels=12) + with pytest.raises(ValueError, match="labels_getter should either be a str"): + transforms.SanitizeBoundingBoxes(labels_getter=12) with pytest.raises(ValueError, match="Could not infer where the labels are"): bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} transforms.SanitizeBoundingBoxes()(bad_labels_key) - with pytest.raises(ValueError, match="If labels is a str or 'default'"): + with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"): not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0])) transforms.SanitizeBoundingBoxes()(not_a_dict) diff --git a/torchvision/prototype/transforms/_misc.py 
b/torchvision/prototype/transforms/_misc.py index 8b0eeec57c5..b3f939ebebf 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -233,24 +233,24 @@ class SanitizeBoundingBoxes(Transform): # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) _transformed_types = datapoints.BoundingBox - def __init__(self, min_size: float = 1.0, labels="default") -> None: + def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: super().__init__() if min_size < 1: raise ValueError(f"min_size must be >= 1, got {min_size}.") self.min_size = min_size - self.labels = labels - if labels == "default": - self._get_labels = self._find_label_default_heuristic - elif callable(self.labels): - self._get_labels = labels - elif isinstance(self.labels, str): - self._get_labels = lambda inputs: inputs[labels] + self.labels_getter = labels_getter + if labels_getter == "default": + self._labels_getter= self._find_label_default_heuristic + elif callable(labels_getter): + self._labels_getter = labels_getter + elif isinstance(labels_getter, str): + self._labels_getter= lambda inputs: inputs[labels_getter] else: raise ValueError( - "labels parameter should either be a str, callable, or 'default'. " - f"Got {labels} of type {type(labels)}." + "labels_getter should either be a str, callable, or 'default'. " + f"Got {labels_getter} of type {type(labels_getter)}." ) @staticmethod @@ -275,12 +275,12 @@ def _find_label_default_heuristic(inputs): def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] - if isinstance(self.labels, str) and not isinstance(inputs, collections.abc.Mapping): + if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping): raise ValueError( - f"If labels is a str or 'default' (got {self.labels}), then the input to forward() must be a dict. " - f"Got {type(inputs)} instead" + f"If labels_getter is a str or 'default' (got {self.labels_getter}), " + f"then the input to forward() must be a dict. Got {type(inputs)} instead." 
) - labels = self._get_labels(inputs) + labels = self._labels_getter(inputs) if not isinstance(labels, torch.Tensor): raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") From e52ebb24afd0550a9e7bed20731343229b32e58a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 10:05:21 +0000 Subject: [PATCH 08/14] Add support for labels_getter=None --- test/test_prototype_transforms.py | 17 +++++++++++------ torchvision/prototype/transforms/_misc.py | 22 +++++++++++++--------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index c0ca50550ac..fc0177afeb2 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2359,7 +2359,9 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize("min_size", (1, 10)) -@pytest.mark.parametrize("labels_getter", ("default", "labels", lambda inputs: inputs["labels"])) +@pytest.mark.parametrize( + "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) +) @pytest.mark.parametrize("batch", (True, False)) def test_sanitize_bounding_boxes(min_size, labels_getter, batch): H, W = 256, 128 @@ -2405,18 +2407,21 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, batch): "labels": labels, "boxes": boxes, "whatever": torch.rand(10), + "None": None, } out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) assert out["image"] is sample["image"] assert out["whatever"] is sample["whatever"] - assert isinstance(out["labels"], torch.Tensor) - assert out["boxes"].shape[:-1] == out["labels"].shape - - # This works because we conveniently set labels to arange(num_boxes) - assert out["labels"].tolist() == valid_indices + if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): + assert out["labels"] is sample["labels"] + else: + assert isinstance(out["labels"], torch.Tensor) + assert out["boxes"].shape[:-1] == out["labels"].shape + # This works because we conveniently set labels to arange(num_boxes) + assert out["labels"].tolist() == valid_indices @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index b3f939ebebf..f079f4c309b 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -240,13 +240,15 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: raise ValueError(f"min_size must be >= 1, got {min_size}.") self.min_size = min_size - self.labels_getter = labels_getter + self.labels_getter = labels_getter if labels_getter == "default": - self._labels_getter= self._find_label_default_heuristic + self._labels_getter = self._find_label_default_heuristic elif callable(labels_getter): - self._labels_getter = labels_getter + self._labels_getter = labels_getter elif isinstance(labels_getter, str): - self._labels_getter= lambda inputs: inputs[labels_getter] + self._labels_getter = lambda inputs: inputs[labels_getter] + elif labels_getter is None: + self._labels_getter = None else: raise ValueError( "labels_getter should either be a str, callable, or 'default'. 
" @@ -268,7 +270,8 @@ def _find_label_default_heuristic(inputs): if labels is None: raise ValueError( - "Could not infer where the labels are in the sample. Try passing a callable as the label parameter?" + "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" + "If there are no samples and it is by design, pass labels_getter=None." ) return labels @@ -280,8 +283,9 @@ def forward(self, *inputs: Any) -> Any: f"If labels_getter is a str or 'default' (got {self.labels_getter}), " f"then the input to forward() must be a dict. Got {type(inputs)} instead." ) - labels = self._labels_getter(inputs) - if not isinstance(labels, torch.Tensor): + + labels = self._labels_getter(inputs) if self._labels_getter is not None else None + if labels is not None and not isinstance(labels, torch.Tensor): raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") flat_inputs, spec = tree_flatten(inputs) @@ -293,7 +297,7 @@ def forward(self, *inputs: Any) -> Any: if boxes.ndim > 3 or (boxes.ndim == 3 and boxes.shape[0] != 1): raise ValueError(f"boxes must be of shape (num_boxes, 4) or (1, num_boxes, 4), got {boxes.shape}") - if boxes.shape[:-1] != labels.shape: + if labels is not None and boxes.shape[:-1] != labels.shape: raise ValueError( f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) @@ -322,7 +326,7 @@ def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if not (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): + if inpt is None or not (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): return inpt out = inpt[params["mask"]] From 3a9619a53c8c5ada5277c96c94e18b733d321cc1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 10:12:13 +0000 Subject: [PATCH 09/14] minor test --- test/test_prototype_transforms.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index fc0177afeb2..dfed0f44232 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2430,6 +2430,12 @@ def test_sanitize_bounding_boxes_default_heuristic(key): d = {key: labels} assert transforms.SanitizeBoundingBoxes._find_label_default_heuristic(d) is labels + if key.lower() != "labels": + # If "labels" is in the dict (case-insensitive), + # it takes precedence over other keys which would otherwise be a match + d = {key: "something_else", "labels": labels} + assert transforms.SanitizeBoundingBoxes._find_label_default_heuristic(d) is labels + def test_sanitize_bounding_boxes_errors(): From b093987762d45d5a7030b72a6d61c1c36a1512aa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 10:19:23 +0000 Subject: [PATCH 10/14] Naming nit --- test/test_prototype_transforms.py | 4 ++-- torchvision/prototype/transforms/_misc.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index dfed0f44232..4ef320ce03f 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2428,13 +2428,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, batch): def test_sanitize_bounding_boxes_default_heuristic(key): labels = torch.arange(10) d = {key: labels} - assert transforms.SanitizeBoundingBoxes._find_label_default_heuristic(d) is labels + assert 
transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels if key.lower() != "labels": # If "labels" is in the dict (case-insensitive), # it takes precedence over other keys which would otherwise be a match d = {key: "something_else", "labels": labels} - assert transforms.SanitizeBoundingBoxes._find_label_default_heuristic(d) is labels + assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels def test_sanitize_bounding_boxes_errors(): diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index f079f4c309b..6411c38e61c 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -242,7 +242,7 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: self.labels_getter = labels_getter if labels_getter == "default": - self._labels_getter = self._find_label_default_heuristic + self._labels_getter = self._find_labels_default_heuristic elif callable(labels_getter): self._labels_getter = labels_getter elif isinstance(labels_getter, str): @@ -256,7 +256,7 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: ) @staticmethod - def _find_label_default_heuristic(inputs): + def _find_labels_default_heuristic(inputs): # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found labels = None From dbbebb73e946dc81aa012b08d6d04a05faa7e5dc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 11:29:34 +0000 Subject: [PATCH 11/14] Address comments --- torchvision/prototype/transforms/_misc.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 6411c38e61c..7ba17cbb5df 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -231,7 +231,6 @@ class SanitizeBoundingBoxes(Transform): # This removes boxes and their corresponding labels: # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) - _transformed_types = datapoints.BoundingBox def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: super().__init__() @@ -259,21 +258,18 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: def _find_labels_default_heuristic(inputs): # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found - labels = None candidate_key = None with suppress(StopIteration): candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") if candidate_key is None: with suppress(StopIteration): candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) - labels = inputs.get(candidate_key) - - if labels is None: + if candidate_key is None: raise ValueError( "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" "If there are no samples and it is by design, pass labels_getter=None." ) - return labels + return inputs[candidate_key] def forward(self, *inputs: Any) -> Any: inputs = inputs if len(inputs) > 1 else inputs[0] @@ -284,9 +280,12 @@ def forward(self, *inputs: Any) -> Any: f"then the input to forward() must be a dict. Got {type(inputs)} instead." 
) - labels = self._labels_getter(inputs) if self._labels_getter is not None else None - if labels is not None and not isinstance(labels, torch.Tensor): - raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") + if self._labels_getter is None: + labels = None + else: + labels = self._labels_getter(inputs) + if labels is not None and not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") flat_inputs, spec = tree_flatten(inputs) # TODO: this enforces one single BoundingBox entry. From 4621097b5272bc02fa64713976b1534902ac1f7d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 11:34:54 +0000 Subject: [PATCH 12/14] remove batch support --- test/test_prototype_transforms.py | 8 +------- torchvision/prototype/transforms/_misc.py | 23 +++++++++-------------- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 11627a9c7e5..a4fdfdfcf57 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2362,8 +2362,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize( "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) ) -@pytest.mark.parametrize("batch", (True, False)) -def test_sanitize_bounding_boxes(min_size, labels_getter, batch): +def test_sanitize_bounding_boxes(min_size, labels_getter): H, W = 256, 128 boxes_and_validity = [ @@ -2391,11 +2390,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, batch): boxes = torch.tensor(boxes) labels = torch.arange(boxes.shape[-2]) - if batch: - boxes = boxes[None] - labels = labels[None] - valid_indices = [valid_indices] - boxes = datapoints.BoundingBox( boxes, format=datapoints.BoundingBoxFormat.XYXY, diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 7ba17cbb5df..fef52abb708 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -293,10 +293,10 @@ def forward(self, *inputs: Any) -> Any: # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? boxes = query_bounding_box(flat_inputs) - if boxes.ndim > 3 or (boxes.ndim == 3 and boxes.shape[0] != 1): - raise ValueError(f"boxes must be of shape (num_boxes, 4) or (1, num_boxes, 4), got {boxes.shape}") + if boxes.ndim != 2: + raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") - if labels is not None and boxes.shape[:-1] != labels.shape: + if labels is not None and boxes.shape[0] != labels.shape[0]: raise ValueError( f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) @@ -305,13 +305,13 @@ def forward(self, *inputs: Any) -> Any: boxes, new_format=datapoints.BoundingBoxFormat.XYXY, ) - ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1] + ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=-1) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? 
image_h, image_w = boxes.spatial_size - mask &= (boxes[..., 0] <= image_w) & (boxes[..., 2] <= image_w) - mask &= (boxes[..., 1] <= image_h) & (boxes[..., 3] <= image_h) + mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) + mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) params = dict(mask=mask, labels=labels) flat_outputs = [ @@ -325,12 +325,7 @@ def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if inpt is None or not (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): - return inpt - - out = inpt[params["mask"]] - if inpt.ndim != out.ndim: - # Add extra batch dim - out = out[None] + if inpt is not None and (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): + inpt = inpt[params["mask"]] - return out + return inpt From 85b96c408344805f74a3025875cddc30c1a0cb41 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Feb 2023 12:30:07 +0000 Subject: [PATCH 13/14] nits --- torchvision/prototype/transforms/_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index fef52abb708..f195dec1d74 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -325,7 +325,7 @@ def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if inpt is not None and (inpt is params["labels"] or isinstance(inpt, datapoints.BoundingBox)): + if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): inpt = inpt[params["mask"]] return inpt From 66dcfc88f37742db1b3e41251609c185c6244900 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 15 Feb 2023 13:48:53 +0100 Subject: [PATCH 14/14] mypy --- torchvision/prototype/transforms/_misc.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index f195dec1d74..caed3eec904 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,7 +1,7 @@ import collections import warnings from contextlib import suppress -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -232,7 +232,11 @@ class SanitizeBoundingBoxes(Transform): # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) - def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: + def __init__( + self, + min_size: float = 1.0, + labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", + ) -> None: super().__init__() if min_size < 1: @@ -240,6 +244,7 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: self.min_size = min_size self.labels_getter = labels_getter + self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] if labels_getter == "default": self._labels_getter = self._find_labels_default_heuristic elif callable(labels_getter): @@ -255,7 +260,7 @@ def __init__(self, min_size: float = 1.0, labels_getter="default") -> None: ) @staticmethod - def _find_labels_default_heuristic(inputs): + def _find_labels_default_heuristic(inputs: Dict[str, 
Any]) -> Optional[torch.Tensor]: # Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found candidate_key = None @@ -301,12 +306,15 @@ def forward(self, *inputs: Any) -> Any: f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." ) - boxes = F.convert_format_bounding_box( - boxes, - new_format=datapoints.BoundingBoxFormat.XYXY, + boxes = cast( + datapoints.BoundingBox, + F.convert_format_bounding_box( + boxes, + new_format=datapoints.BoundingBoxFormat.XYXY, + ), ) ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] - mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(axis=-1) + mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? image_h, image_w = boxes.spatial_size
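
For reference, a minimal usage sketch of the SanitizeBoundingBoxes transform as it stands at the end of this series, mirroring the sample dict built in test_sanitize_bounding_boxes. The import paths assume the prototype namespaces (torchvision.prototype.transforms / torchvision.prototype.datapoints) used throughout these patches.

    import torch

    from torchvision.prototype import datapoints, transforms

    H, W = 256, 128

    # XYXY boxes; the second box is degenerate (X1 == X2) and gets removed.
    boxes = datapoints.BoundingBox(
        [[1, 1, 30, 20], [0, 1, 0, 20]],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(H, W),
    )

    sample = {
        "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
        "boxes": boxes,
        "labels": torch.tensor([7, 3]),  # found by the "default" labels_getter heuristic
    }

    out = transforms.SanitizeBoundingBoxes(min_size=1)(sample)

    assert out["image"] is sample["image"]  # non-box, non-label entries pass through untouched
    assert out["boxes"].shape[0] == out["labels"].shape[0] == 1
    assert out["labels"].tolist() == [7]  # the label of the dropped box is removed as well

Passing labels_getter=None (or a callable that returns None) leaves the labels untouched and only filters the boxes; a callable is also the way to locate labels when the sample is not a plain dict.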