-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add SanitizeBoundingBoxes transform #7246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
7c5ab88
Add SanitizeBoundingBoxes transform
NicolasHug 26929c0
Added basic test, will improve upon
NicolasHug 9ae43b2
Address comments + add tests
NicolasHug 2ac342a
A little more
NicolasHug 87a849e
Add more input check tests
NicolasHug 7839dd8
Add support for batch dimension - only 1 batch tho
NicolasHug 57aaab3
renamed labels parameter into labels_getter
NicolasHug e52ebb2
Add support for labels_getter=None
NicolasHug 3a9619a
minor test
NicolasHug b093987
Naming nit
NicolasHug 96ded4c
Merge branch 'main' of github.com:pytorch/vision into sanitize_boundi…
NicolasHug dbbebb7
Address comments
NicolasHug 4621097
remove batch support
NicolasHug ed030a5
Merge branch 'main' of github.com:pytorch/vision into sanitize_boundi…
NicolasHug 85b96c4
nits
NicolasHug 66dcfc8
mypy
pmeier dda8810
Merge branch 'main' into sanitize_boundingboxes
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
import collections | ||
import warnings | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union | ||
from contextlib import suppress | ||
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union | ||
|
||
import PIL.Image | ||
|
||
import torch | ||
from torch.utils._pytree import tree_flatten, tree_unflatten | ||
|
||
from torchvision import transforms as _transforms | ||
from torchvision.ops import remove_small_boxes | ||
from torchvision.prototype import datapoints | ||
from torchvision.prototype.transforms import functional as F, Transform | ||
|
||
|
@@ -225,28 +227,113 @@ def _transform( | |
return inpt.transpose(*dims) | ||
|
||
|
||
class RemoveSmallBoundingBoxes(Transform): | ||
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel) | ||
class SanitizeBoundingBoxes(Transform): | ||
# This removes boxes and their corresponding labels: | ||
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) | ||
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size) | ||
|
||
def __init__(self, min_size: float = 1.0) -> None: | ||
def __init__( | ||
self, | ||
min_size: float = 1.0, | ||
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", | ||
) -> None: | ||
super().__init__() | ||
|
||
if min_size < 1: | ||
raise ValueError(f"min_size must be >= 1, got {min_size}.") | ||
self.min_size = min_size | ||
|
||
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | ||
bounding_box = query_bounding_box(flat_inputs) | ||
|
||
# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to | ||
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH | ||
format, we need to convert first just to afterwards compute the width and height again, although they were | ||
# there in the first place for these formats. | ||
bounding_box = F.convert_format_bounding_box( | ||
bounding_box.as_subclass(torch.Tensor), | ||
old_format=bounding_box.format, | ||
new_format=datapoints.BoundingBoxFormat.XYXY, | ||
) | ||
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size) | ||
self.labels_getter = labels_getter | ||
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] | ||
if labels_getter == "default": | ||
self._labels_getter = self._find_labels_default_heuristic | ||
elif callable(labels_getter): | ||
self._labels_getter = labels_getter | ||
elif isinstance(labels_getter, str): | ||
self._labels_getter = lambda inputs: inputs[labels_getter] | ||
elif labels_getter is None: | ||
self._labels_getter = None | ||
else: | ||
raise ValueError( | ||
"labels_getter should either be a str, callable, or 'default'. " | ||
f"Got {labels_getter} of type {type(labels_getter)}." | ||
) | ||
|
||
@staticmethod | ||
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: | ||
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive | ||
# Returns None if nothing is found | ||
candidate_key = None | ||
with suppress(StopIteration): | ||
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") | ||
if candidate_key is None: | ||
with suppress(StopIteration): | ||
candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) | ||
if candidate_key is None: | ||
raise ValueError( | ||
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter? " | ||
"If there are no samples and it is by design, pass labels_getter=None." | ||
) | ||
return inputs[candidate_key] | ||
|
||
def forward(self, *inputs: Any) -> Any: | ||
inputs = inputs if len(inputs) > 1 else inputs[0] | ||
|
||
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping): | ||
raise ValueError( | ||
f"If labels_getter is a str or 'default' (got {self.labels_getter}), " | ||
f"then the input to forward() must be a dict. Got {type(inputs)} instead." | ||
) | ||
|
||
if self._labels_getter is None: | ||
labels = None | ||
else: | ||
labels = self._labels_getter(inputs) | ||
if labels is not None and not isinstance(labels, torch.Tensor): | ||
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.") | ||
|
||
return dict(valid_indices=valid_indices) | ||
flat_inputs, spec = tree_flatten(inputs) | ||
# TODO: this enforces one single BoundingBox entry. | ||
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... | ||
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? | ||
boxes = query_bounding_box(flat_inputs) | ||
|
||
if boxes.ndim != 2: | ||
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") | ||
|
||
if labels is not None and boxes.shape[0] != labels.shape[0]: | ||
raise ValueError( | ||
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." | ||
) | ||
|
||
boxes = cast( | ||
datapoints.BoundingBox, | ||
F.convert_format_bounding_box( | ||
boxes, | ||
new_format=datapoints.BoundingBoxFormat.XYXY, | ||
), | ||
) | ||
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) | ||
# TODO: Do we really need to check for out of bounds here? All | ||
# transforms should be clamping anyway, so this should never happen? | ||
Comment on lines
+318
to
+319
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would keep this for now until we are sure about this, i.e. we have tests that guarantee this. Happy to remove if it turns out we don't need it. |
||
image_h, image_w = boxes.spatial_size | ||
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) | ||
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) | ||
|
||
params = dict(mask=mask, labels=labels) | ||
flat_outputs = [ | ||
# Even-though it may look like we're transforming all inputs, we don't: | ||
# _transform() will only care about BoundingBoxes and the labels | ||
self._transform(inpt, params) | ||
for inpt in flat_inputs | ||
] | ||
|
||
return tree_unflatten(flat_outputs, spec) | ||
|
||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: | ||
return inpt.wrap_like(inpt, inpt[params["valid_indices"]]) | ||
|
||
if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): | ||
inpt = inpt[params["mask"]] | ||
|
||
return inpt |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.