add tests for the output types of prototype functional dispatchers #7118

Merged: 1 commit, merged Jan 20, 2023
test/prototype_transforms_dispatcher_infos.py (11 additions, 2 deletions)
@@ -112,6 +112,15 @@ def xfail_jit_list_of_ints(name, *, reason=None):
     pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
 )
 
+multi_crop_skips = [
+    TestMark(
+        ("TestDispatchers", test_name),
+        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
+    )
+    for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+]
+multi_crop_skips.append(skip_dispatch_datapoint)
+
 
 def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
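The reason for these skips: multi-crop dispatchers such as five_crop and ten_crop return a sequence of images rather than a single image, so the single-output type assertions added below do not apply to them. A minimal sketch of that behavior, assuming the prototype import path and that F.five_crop follows the stable five_crop signature:

    import torch
    from torchvision.prototype.transforms import functional as F

    image = torch.rand(3, 32, 32)
    crops = F.five_crop(image, size=[8, 8])

    # A sequence of five crops comes back, not a single tensor, so a check
    # like `type(output) is torch.Tensor` cannot be applied to it directly.
    assert isinstance(crops, (tuple, list)) and len(crops) == 5
    assert all(type(crop) is torch.Tensor for crop in crops)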
@@ -404,7 +413,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
         pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_datapoint,
+            *multi_crop_skips,
         ],
     ),
     DispatcherInfo(
@@ -415,7 +424,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_datapoint,
+            *multi_crop_skips,
         ],
         pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
     ),
test/test_prototype_transforms_functional.py (37 additions, 0 deletions)
@@ -362,6 +362,16 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
 
         spy.assert_called_once()
 
+    @image_sample_inputs
+    def test_simple_tensor_output_type(self, info, args_kwargs):
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
+
+        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
+
+        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
+        assert type(output) is torch.Tensor
+
     @make_info_args_kwargs_parametrization(
         [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
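The exact `type(...) is torch.Tensor` check matters because every datapoint class subclasses torch.Tensor, so isinstance cannot tell a plain tensor from a datapoint. A minimal sketch of the distinction, assuming the prototype datapoints API:

    import torch
    from torchvision.prototype import datapoints

    image = datapoints.Image(torch.rand(3, 8, 8))
    plain = image.as_subclass(torch.Tensor)

    assert isinstance(image, torch.Tensor)  # also True for datapoints
    assert type(image) is not torch.Tensor  # the exact check tells them apart
    assert type(plain) is torch.Tensor      # plain tensors pass it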
@@ -381,6 +391,22 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on):
 
         spy.assert_called_once()
 
+    @make_info_args_kwargs_parametrization(
+        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
+        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+    )
+    def test_pil_output_type(self, info, args_kwargs):
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+
+        if image_datapoint.ndim > 3:
+            pytest.skip("Input is batched")
+
+        image_pil = F.to_image_pil(image_datapoint)
+
+        output = info.dispatcher(image_pil, *other_args, **kwargs)
+
+        assert isinstance(output, PIL.Image.Image)
+
     @make_info_args_kwargs_parametrization(
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
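The ndim guard is there because a PIL image can only represent a single image, so batched sample inputs of shape (N, C, H, W) cannot be converted. A short sketch of the guard, assuming F.to_image_pil behaves like the stable to_pil_image helper:

    import torch
    from torchvision.prototype.transforms import functional as F

    single = torch.rand(3, 8, 8)        # (C, H, W) converts fine
    image_pil = F.to_image_pil(single)

    batched = torch.rand(2, 3, 8, 8)    # (N, C, H, W) has no PIL equivalent
    assert batched.ndim > 3             # such inputs are skipped by the test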
@@ -397,6 +423,17 @@ def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
 
         spy.assert_called_once()
 
+    @make_info_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(),
+    )
+    def test_datapoint_output_type(self, info, args_kwargs):
+        (datapoint, *other_args), kwargs = args_kwargs.load()
+
+        output = info.dispatcher(datapoint, *other_args, **kwargs)
+
+        assert isinstance(output, type(datapoint))
+
     @pytest.mark.parametrize(
         ("dispatcher_info", "datapoint_type", "kernel_info"),
         [
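The `isinstance(output, type(datapoint))` assertion pins down the round-trip property of datapoint dispatch: passing a datapoint subclass in should yield the same subclass back, never a plain tensor. A minimal sketch of that contract, assuming the prototype datapoints API and the F.horizontal_flip dispatcher:

    import torch
    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F

    image = datapoints.Image(torch.rand(3, 16, 16))
    output = F.horizontal_flip(image)

    # The dispatcher wraps its result back into the input's datapoint type.
    assert isinstance(output, datapoints.Image)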