diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index 2701dd66be0..5a8c64bb73c 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -1,5 +1,7 @@ import pytest import torch +from prototype_common_utils import make_image, make_label, make_segmentation_mask +from torch.utils.data import DataLoader from torchvision.prototype import features @@ -10,104 +12,151 @@ def test_isinstance(): ) -def test_wrapping_no_copy(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) +class TestTorchFunction: + def test_wrapping_no_copy(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) - assert label.data_ptr() == tensor.data_ptr() + assert label.data_ptr() == tensor.data_ptr() + def test_to_wrapping(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) -def test_to_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + label_to = label.to(torch.int32) - label_to = label.to(torch.int32) + assert type(label_to) is features.Label + assert label_to.dtype is torch.int32 + assert label_to.categories is label.categories - assert type(label_to) is features.Label - assert label_to.dtype is torch.int32 - assert label_to.categories is label.categories + def test_to_feature_reference(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + tensor_to = tensor.to(label) -def test_to_feature_reference(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + assert type(tensor_to) is torch.Tensor + assert tensor_to.dtype is torch.int32 - tensor_to = tensor.to(label) + def test_clone_wrapping(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) - assert type(tensor_to) is torch.Tensor - assert tensor_to.dtype is torch.int32 + label_clone = label.clone() + assert type(label_clone) is features.Label + assert label_clone.data_ptr() != label.data_ptr() + assert label_clone.categories is label.categories -def test_clone_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + def test_requires_grad__wrapping(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.float32) + label = features.Label(tensor, categories=["foo", "bar"]) - label_clone = label.clone() + assert not label.requires_grad - assert type(label_clone) is features.Label - assert label_clone.data_ptr() != label.data_ptr() - assert label_clone.categories is label.categories + label_requires_grad = label.requires_grad_(True) + assert type(label_requires_grad) is features.Label + assert label.requires_grad + assert label_requires_grad.requires_grad -def test_requires_grad__wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.float32) - label = features.Label(tensor, categories=["foo", "bar"]) + def test_other_op_no_wrapping(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + # any operation besides .to() and .clone() will do here + output = label * 2 - assert not label.requires_grad + assert type(output) is torch.Tensor - label_requires_grad = label.requires_grad_(True) + @pytest.mark.parametrize( + "op", + [ + lambda t: t.numpy(), + lambda t: t.tolist(), + lambda t: t.max(dim=-1), + ], + ) + def test_no_tensor_output_op_no_wrapping(self, op): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + output = op(label) + + assert type(output) is not features.Label + + def test_inplace_op_no_wrapping(self): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + output = label.add_(0) - assert type(label_requires_grad) is features.Label - assert label.requires_grad - assert label_requires_grad.requires_grad + assert type(output) is torch.Tensor + assert type(label) is features.Label -def test_other_op_no_wrapping(): +def test_new_like(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) label = features.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 - assert type(output) is torch.Tensor + label_new = features.Label.new_like(label, output) + + assert type(label_new) is features.Label + assert label_new.data_ptr() == output.data_ptr() + assert label_new.categories is label.categories + +class TestVisionCollate: + def check_collation(self, dataset, expected_batch, *, collate_fn=features.vision_collate): + data_loader = DataLoader(dataset, num_workers=0, batch_size=len(dataset), collate_fn=collate_fn) -@pytest.mark.parametrize( - "op", - [ - lambda t: t.numpy(), - lambda t: t.tolist(), - lambda t: t.max(dim=-1), - ], -) -def test_no_tensor_output_op_no_wrapping(op): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + actual_batch = list(data_loader)[0] - output = op(label) + torch.testing.assert_close(actual_batch, expected_batch) - assert type(output) is not features.Label + return actual_batch + @pytest.mark.parametrize("with_labels", [True, False]) + def test_classification(self, with_labels): + image_size = (16, 17) + categories = ["foo", "bar", "baz"] -def test_inplace_op_no_wrapping(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + dataset = [] + for _ in range(4): + image = make_image(size=image_size) + label = make_label(categories=categories) if with_labels else None - output = label.add_(0) + dataset.append((image, label)) - assert type(output) is torch.Tensor - assert type(label) is features.Label + expected_images, expected_labels = zip(*dataset) + expected_batch = [ + features.Image.new_like(expected_images[0], torch.stack(expected_images)), + features.Label.new_like(expected_labels[0], torch.stack(expected_labels)) + if with_labels + else list(expected_labels), + ] + actual_batch = self.check_collation(dataset, expected_batch) -def test_new_like(): - tensor = torch.tensor([0, 1, 0], dtype=torch.int64) - label = features.Label(tensor, categories=["foo", "bar"]) + if with_labels: + assert actual_batch[1].categories == categories - # any operation besides .to() and .clone() will do here - output = label * 2 + def test_segmentation(self): + image_size = (16, 17) - label_new = features.Label.new_like(label, output) + dataset = [] + for _ in range(4): + image = make_image(size=image_size) + mask = make_segmentation_mask(size=image_size, num_categories=10) - assert type(label_new) is features.Label - assert label_new.data_ptr() == output.data_ptr() - assert label_new.categories is label.categories + dataset.append((image, mask)) + + expected_batch = [ + type(expected_features[0]).new_like(expected_features[0], torch.stack(expected_features)) + for expected_features in zip(*dataset) + ] + + self.check_collation(dataset, expected_batch) diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index df77e8b77b3..4b1b208c3f6 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -13,3 +13,5 @@ ) from ._label import Label, OneHotLabel from ._mask import Mask + +from ._collate import vision_collate # usort: skip diff --git a/torchvision/prototype/features/_collate.py b/torchvision/prototype/features/_collate.py new file mode 100644 index 00000000000..1b9816432df --- /dev/null +++ b/torchvision/prototype/features/_collate.py @@ -0,0 +1,29 @@ +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union + +from torch.utils.data._utils.collate import collate, collate_tensor_fn, default_collate_fn_map +from torchvision.prototype.features import Image, Label, Mask, OneHotLabel + + +def no_collate_fn( + batch: Sequence[Any], *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None +) -> Any: + return batch + + +def new_like_collate_fn( + batch: Sequence[Any], *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None +) -> Any: + feature = batch[0] + tensor = collate_tensor_fn(batch, collate_fn_map=collate_fn_map) + return type(feature).new_like(feature, tensor) + + +vision_collate_fn_map = { + (Image, Mask, Label, OneHotLabel): new_like_collate_fn, + type(None): no_collate_fn, + **default_collate_fn_map, +} + + +def vision_collate(batch: Sequence[Any]) -> Any: + return collate(batch, collate_fn_map=vision_collate_fn_map)