Skip to content

Commit 24fd7d6

Browse files
committed
Fixed wrong imports
1 parent eafeb5b commit 24fd7d6

File tree

4 files changed

+46
-44
lines changed

4 files changed

+46
-44
lines changed

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,8 @@ def test_transforms(self, config):
584584

585585
@test_all_configs
586586
def test_transforms_v2_wrapper(self, config):
587-
from torchvision.datapoints import wrap_dataset_for_transforms_v2
588587
from torchvision.datapoints._datapoint import Datapoint
588+
from torchvision.datasets import wrap_dataset_for_transforms_v2
589589

590590
try:
591591
with self.create_dataset(config) as (dataset, _):

test/test_datasets.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pathlib
99
import pickle
1010
import random
11+
import re
1112
import shutil
1213
import string
1314
import unittest
@@ -3309,5 +3310,47 @@ def test_bad_input(self):
33093310
pass
33103311

33113312

3313+
class TestDatasetWrapper:
3314+
def test_unknown_type(self):
3315+
unknown_object = object()
3316+
with pytest.raises(
3317+
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
3318+
):
3319+
datasets.wrap_dataset_for_transforms_v2(unknown_object)
3320+
3321+
def test_unknown_dataset(self):
3322+
class MyVisionDataset(datasets.VisionDataset):
3323+
pass
3324+
3325+
dataset = MyVisionDataset("root")
3326+
3327+
with pytest.raises(TypeError, match="No wrapper exist"):
3328+
datasets.wrap_dataset_for_transforms_v2(dataset)
3329+
3330+
def test_missing_wrapper(self):
3331+
dataset = datasets.FakeData()
3332+
3333+
with pytest.raises(TypeError, match="please open an issue"):
3334+
datasets.wrap_dataset_for_transforms_v2(dataset)
3335+
3336+
def test_subclass(self, mocker):
3337+
from torchvision import datapoints
3338+
3339+
sentinel = object()
3340+
mocker.patch.dict(
3341+
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
3342+
clear=False,
3343+
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
3344+
)
3345+
3346+
class MyFakeData(datasets.FakeData):
3347+
pass
3348+
3349+
dataset = MyFakeData()
3350+
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
3351+
3352+
assert wrapped_dataset[0] is sentinel
3353+
3354+
33123355
if __name__ == "__main__":
33133356
unittest.main()

test/test_prototype_datapoints.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from PIL import Image
77

8-
from torchvision import datapoints, datasets
98
from torchvision.prototype import datapoints as proto_datapoints
109

1110

@@ -163,43 +162,3 @@ def test_bbox_instance(data, format):
163162
if isinstance(format, str):
164163
format = datapoints.BoundingBoxFormat.from_str(format.upper())
165164
assert bboxes.format == format
166-
167-
168-
class TestDatasetWrapper:
169-
def test_unknown_type(self):
170-
unknown_object = object()
171-
with pytest.raises(
172-
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
173-
):
174-
datapoints.wrap_dataset_for_transforms_v2(unknown_object)
175-
176-
def test_unknown_dataset(self):
177-
class MyVisionDataset(datasets.VisionDataset):
178-
pass
179-
180-
dataset = MyVisionDataset("root")
181-
182-
with pytest.raises(TypeError, match="No wrapper exist"):
183-
datapoints.wrap_dataset_for_transforms_v2(dataset)
184-
185-
def test_missing_wrapper(self):
186-
dataset = datasets.FakeData()
187-
188-
with pytest.raises(TypeError, match="please open an issue"):
189-
datapoints.wrap_dataset_for_transforms_v2(dataset)
190-
191-
def test_subclass(self, mocker):
192-
sentinel = object()
193-
mocker.patch.dict(
194-
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
195-
clear=False,
196-
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
197-
)
198-
199-
class MyFakeData(datasets.FakeData):
200-
pass
201-
202-
dataset = MyFakeData()
203-
wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)
204-
205-
assert wrapped_dataset[0] is sentinel

torchvision/datapoints/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
2+
13
from ._bounding_box import BoundingBox, BoundingBoxFormat
24
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
35
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
46
from ._mask import Mask
57
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
68

7-
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
8-
99
if _WARN_ABOUT_BETA_TRANSFORMS:
1010
import warnings
1111

0 commit comments

Comments
 (0)