Skip to content

add a vision_collate function that honours prototype features #6680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 111 additions & 62 deletions test/test_prototype_features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import torch
from prototype_common_utils import make_image, make_label, make_segmentation_mask
from torch.utils.data import DataLoader
from torchvision.prototype import features


Expand All @@ -10,104 +12,151 @@ def test_isinstance():
)


def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
class TestTorchFunction:
def test_wrapping_no_copy(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

assert label.data_ptr() == tensor.data_ptr()
assert label.data_ptr() == tensor.data_ptr()

def test_to_wrapping(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32)

label_to = label.to(torch.int32)
assert type(label_to) is features.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories

assert type(label_to) is features.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories
def test_to_feature_reference(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)

tensor_to = tensor.to(label)

def test_to_feature_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.int32

tensor_to = tensor.to(label)
def test_clone_wrapping(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.int32
label_clone = label.clone()

assert type(label_clone) is features.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories

def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
def test_requires_grad__wrapping(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
label = features.Label(tensor, categories=["foo", "bar"])

label_clone = label.clone()
assert not label.requires_grad

assert type(label_clone) is features.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories
label_requires_grad = label.requires_grad_(True)

assert type(label_requires_grad) is features.Label
assert label.requires_grad
assert label_requires_grad.requires_grad

def test_requires_grad__wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
label = features.Label(tensor, categories=["foo", "bar"])
def test_other_op_no_wrapping(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

# any operation besides .to() and .clone() will do here
output = label * 2

assert not label.requires_grad
assert type(output) is torch.Tensor

label_requires_grad = label.requires_grad_(True)
@pytest.mark.parametrize(
"op",
[
lambda t: t.numpy(),
lambda t: t.tolist(),
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(self, op):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

output = op(label)

assert type(output) is not features.Label

def test_inplace_op_no_wrapping(self):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

output = label.add_(0)

assert type(label_requires_grad) is features.Label
assert label.requires_grad
assert label_requires_grad.requires_grad
assert type(output) is torch.Tensor
assert type(label) is features.Label


def test_other_op_no_wrapping():
def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])

# any operation besides .to() and .clone() will do here
output = label * 2

assert type(output) is torch.Tensor
label_new = features.Label.new_like(label, output)

assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories


class TestVisionCollate:
def check_collation(self, dataset, expected_batch, *, collate_fn=features.vision_collate):
data_loader = DataLoader(dataset, num_workers=0, batch_size=len(dataset), collate_fn=collate_fn)

@pytest.mark.parametrize(
"op",
[
lambda t: t.numpy(),
lambda t: t.tolist(),
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(op):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
actual_batch = list(data_loader)[0]

output = op(label)
torch.testing.assert_close(actual_batch, expected_batch)

assert type(output) is not features.Label
return actual_batch

@pytest.mark.parametrize("with_labels", [True, False])
def test_classification(self, with_labels):
image_size = (16, 17)
categories = ["foo", "bar", "baz"]

def test_inplace_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
dataset = []
for _ in range(4):
image = make_image(size=image_size)
label = make_label(categories=categories) if with_labels else None

output = label.add_(0)
dataset.append((image, label))

assert type(output) is torch.Tensor
assert type(label) is features.Label
expected_images, expected_labels = zip(*dataset)
expected_batch = [
features.Image.new_like(expected_images[0], torch.stack(expected_images)),
features.Label.new_like(expected_labels[0], torch.stack(expected_labels))
if with_labels
else list(expected_labels),
]

actual_batch = self.check_collation(dataset, expected_batch)

def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
if with_labels:
assert actual_batch[1].categories == categories

# any operation besides .to() and .clone() will do here
output = label * 2
def test_segmentation(self):
image_size = (16, 17)

label_new = features.Label.new_like(label, output)
dataset = []
for _ in range(4):
image = make_image(size=image_size)
mask = make_segmentation_mask(size=image_size, num_categories=10)

assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories
dataset.append((image, mask))

expected_batch = [
type(expected_features[0]).new_like(expected_features[0], torch.stack(expected_features))
for expected_features in zip(*dataset)
]

self.check_collation(dataset, expected_batch)
2 changes: 2 additions & 0 deletions torchvision/prototype/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
)
from ._label import Label, OneHotLabel
from ._mask import Mask

from ._collate import vision_collate # usort: skip
29 changes: 29 additions & 0 deletions torchvision/prototype/features/_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've put this into features for now, but when everything is rolled out it should probably be in datasets.


from torch.utils.data._utils.collate import collate, collate_tensor_fn, default_collate_fn_map
from torchvision.prototype.features import Image, Label, Mask, OneHotLabel


def no_collate_fn(
    batch: Sequence[Any], *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
) -> Any:
    """Return ``batch`` unchanged instead of collating it.

    ``vision_collate_fn_map`` registers this function for ``type(None)``.

    Args:
        batch: The samples that would otherwise be collated.
        collate_fn_map: Unused. Accepted so the signature matches the other
            collate functions in this module.

    Returns:
        The very same ``batch`` object that was passed in.
    """
    return batch


def new_like_collate_fn(
    batch: Sequence[Any], *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None
) -> Any:
    """Collate a batch of features into a single feature of the same type.

    The samples are combined with ``collate_tensor_fn`` and the result is
    wrapped again by calling ``new_like`` on the type of the first sample,
    using that first sample as the reference.

    Args:
        batch: Non-empty sequence of feature samples. ``batch[0]`` is used as
            the reference for re-wrapping.
        collate_fn_map: Forwarded unchanged to ``collate_tensor_fn``.

    Returns:
        The value returned by ``type(batch[0]).new_like`` for the collated
        tensor.
    """
    reference = batch[0]
    collated = collate_tensor_fn(batch, collate_fn_map=collate_fn_map)
    # Re-wrap with the reference sample's own type; the reference is passed
    # along so `new_like` can take whatever it needs from it.
    return type(reference).new_like(reference, collated)


# Maps sample types to the function used to collate a batch of them. The
# custom entries are listed before the defaults; presumably this lets them take
# precedence over the generic tensor handling -- confirm against the lookup
# order used by `collate`.
vision_collate_fn_map = {
    # Batches of these features go through `new_like_collate_fn`.
    (Image, Mask, Label, OneHotLabel): new_like_collate_fn,
    # `None` samples are returned as-is. Author's review note: needed for #5233.
    type(None): no_collate_fn,
    **default_collate_fn_map,
}


def vision_collate(batch: Sequence[Any]) -> Any:
    """Collate ``batch`` using ``vision_collate_fn_map``.

    Thin wrapper around ``collate`` that supplies this module's collate
    function map, so it can be passed as a single-argument ``collate_fn``
    (for example to a ``DataLoader``).

    Args:
        batch: The samples to collate.

    Returns:
        Whatever ``collate`` produces for ``batch`` with
        ``vision_collate_fn_map``.
    """
    return collate(batch, collate_fn_map=vision_collate_fn_map)