make clamp_bounding_box a kernel / dispatcher hybrid #7227
```diff
@@ -12,6 +12,7 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -22,6 +23,7 @@
     make_mask_loaders,
     make_video_loaders,
     mark_framework_limitation,
+    TensorLoader,
     TestMark,
 )
 from torch.utils._pytree import tree_map
```
```diff
@@ -1988,8 +1990,15 @@ def sample_inputs_adjust_saturation_video():


 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
+        yield ArgsKwargs(bounding_box_loader)
+
+        simple_tensor_loader = TensorLoader(
+            fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
+            shape=bounding_box_loader.shape,
+            dtype=bounding_box_loader.dtype,
+        )
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
+            simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )

```
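For context, the `simple_tensor_loader` above relies on `Tensor.as_subclass` to strip the datapoint subclass, so the second sample exercises the plain-tensor (kernel) path with explicit `format` / `spatial_size`. A minimal illustration of that trick; the `MyBox` class here is made up for the snippet and is not part of torchvision:

```python
import torch


class MyBox(torch.Tensor):
    """Made-up stand-in for a datapoint subclass."""


box = torch.zeros(4).as_subclass(MyBox)
plain = box.as_subclass(torch.Tensor)

assert type(box) is MyBox
assert type(plain) is torch.Tensor  # subclass stripped; storage is shared
```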
```diff
@@ -1998,6 +2007,17 @@ def sample_inputs_clamp_bounding_box():
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
+        test_marks=[
+            mark_framework_limitation(
+                ("TestKernels", "test_scripted_vs_eager"),
+                reason=(
+                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
+                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
+                    "`spatial_size` was passed"
+                ),
+                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader),
+            )
+        ],
     )
 )
```

Review comment on the `condition=...` line: Note that this only applies to
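The skip reason above describes the hybrid behavior: given a `datapoints.BoundingBox`, the function acts as a dispatcher and reads `format` / `spatial_size` from the datapoint; given a plain tensor, it acts as a kernel and requires both to be passed explicitly. The following is a rough, self-contained sketch of that shape, not the code from this PR; the `BoundingBox` stand-in class, the `_clamp_xyxy` kernel stub, and all error messages are made up for illustration.

```python
import torch


class BoundingBox(torch.Tensor):
    """Made-up stand-in for datapoints.BoundingBox carrying format / spatial_size metadata."""

    @classmethod
    def wrap_like(cls, other, tensor):
        # Re-wrap a plain tensor and copy the metadata from an existing box.
        wrapped = tensor.as_subclass(cls)
        wrapped.format = other.format
        wrapped.spatial_size = other.spatial_size
        return wrapped


def _clamp_xyxy(bounding_box, spatial_size):
    # Kernel stub for the sketch: clamps XYXY boxes only; a real kernel would
    # also handle other formats.
    height, width = spatial_size
    out = bounding_box.clone()
    out[..., 0::2].clamp_(min=0, max=width)
    out[..., 1::2].clamp_(min=0, max=height)
    return out


def clamp_bounding_box(inpt, format=None, spatial_size=None):
    if isinstance(inpt, BoundingBox):
        # Dispatcher behavior: the metadata comes from the datapoint itself.
        if format is not None or spatial_size is not None:
            raise ValueError("format and spatial_size must not be passed for a BoundingBox input")
        output = _clamp_xyxy(inpt.as_subclass(torch.Tensor), inpt.spatial_size)
        return BoundingBox.wrap_like(inpt, output)
    else:
        # Kernel behavior: a plain tensor needs explicit metadata. JIT scripting
        # erases the BoundingBox subclass to a plain Tensor, so scripted calls land
        # here without format / spatial_size -- the situation the skipped
        # test_scripted_vs_eager would run into.
        if format is None or spatial_size is None:
            raise ValueError("format and spatial_size must be passed for a plain tensor input")
        return _clamp_xyxy(inpt, spatial_size)


# Usage of the sketch:
boxes = torch.tensor([[-5.0, 2.0, 30.0, 40.0]]).as_subclass(BoundingBox)
boxes.format, boxes.spatial_size = "XYXY", (20, 20)
clamped_box = clamp_bounding_box(boxes)  # dispatcher path, returns a BoundingBox
clamped_tensor = clamp_bounding_box(     # kernel path, returns a plain Tensor
    boxes.as_subclass(torch.Tensor), format="XYXY", spatial_size=(20, 20)
)
```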
```diff
@@ -155,12 +155,14 @@ def _unbatch(self, batch, *, data_dims):
         if batched_tensor.ndim == data_dims:
             return batch

-        return [
-            self._unbatch(unbatched, data_dims=data_dims)
-            for unbatched in (
-                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-            )
-        ]
+        unbatcheds = []
+        for unbatched in (
+            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+        ):
+            if isinstance(batch, datapoints._datapoint.Datapoint):
+                unbatched = type(batch).wrap_like(batch, unbatched)
+            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
+        return unbatcheds

     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
```

Review comment on lines +162 to +163: Small fix since this didn't respect datapoints types before. This was not an issue, since this is called from a kernel test and so far all kernels operated only with plain tensors. Meaning, all datapoints would have been unwrapped anyway.
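A tiny, self-contained illustration of why the re-wrap is needed; the `Box` class and its `wrap_like` are made-up stand-ins, not the torchvision datapoints API:

```python
import torch


class Box(torch.Tensor):
    """Made-up tensor subclass carrying spatial_size metadata."""

    @classmethod
    def wrap_like(cls, other, tensor):
        wrapped = tensor.as_subclass(cls)
        wrapped.spatial_size = other.spatial_size
        return wrapped


batch = torch.zeros(2, 4).as_subclass(Box)
batch.spatial_size = (10, 10)

# The slices returned by unbind do not carry the metadata of the original batch,
# so each one is re-wrapped from it -- the same pattern as the `wrap_like` call
# added to `_unbatch` above.
slices = [Box.wrap_like(batch, t) for t in batch.unbind(0)]
assert all(s.spatial_size == (10, 10) for s in slices)
```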
Review comment: This is quick and dirty. If we have more such cases in the future, we should have something like an `unwrap` method or the like to get the plain tensor.
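For illustration only, such an `unwrap` helper could be a thin wrapper around `Tensor.as_subclass`; the `Datapoint` class below is a hypothetical stand-in, not the torchvision implementation.

```python
import torch


class Datapoint(torch.Tensor):
    """Hypothetical stand-in for a datapoint base class."""

    def unwrap(self) -> torch.Tensor:
        # Return the same storage viewed as a plain torch.Tensor.
        return self.as_subclass(torch.Tensor)
```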