Skip to content

Commit cb8b43d

Browse files
committed
fix ImageNet
1 parent e0734cd commit cb8b43d

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_coverage():
4242
)
4343

4444

45-
@pytest.mark.filterwarnings("error")
45+
# @pytest.mark.filterwarnings("error")
4646
class TestCommon:
4747
@pytest.mark.parametrize("name", datasets.list_datasets())
4848
def test_info(self, name):

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import enum
21
import functools
32
import pathlib
43
import re
@@ -10,7 +9,6 @@
109
IterKeyZipper,
1110
Mapper,
1211
Filter,
13-
Demultiplexer,
1412
TarArchiveLoader,
1513
Enumerator,
1614
)
@@ -27,6 +25,7 @@
2725
hint_shuffling,
2826
read_categories_file,
2927
path_accessor,
28+
path_comparator,
3029
)
3130
from torchvision.prototype.features import Label, EncodedImage
3231

@@ -46,9 +45,9 @@ def __init__(self, **kwargs: Any) -> None:
4645
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
4746

4847

49-
class ImageNetDemux(enum.IntEnum):
50-
META = 0
51-
LABEL = 1
48+
# class ImageNetDemux(enum.IntEnum):
49+
# META = 0
50+
# LABEL = 1
5251

5352

5453
@register_dataset(NAME)
@@ -108,12 +107,6 @@ def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label,
108107
def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
109108
return None, data
110109

111-
def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
112-
return {
113-
"meta.mat": ImageNetDemux.META,
114-
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
115-
}.get(pathlib.Path(data[0]).name)
116-
117110
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
118111
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
119112
_WNID_MAP = {
@@ -172,13 +165,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
172165
else: # config.split == "val":
173166
images_dp, devkit_dp = resource_dps
174167

175-
meta_dp, label_dp = Demultiplexer(
176-
devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
177-
)
178-
168+
meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
179169
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
180-
_, wnids = zip(*next(iter(meta_dp)))
170+
_, wnids = zip(*list(meta_dp)[0])
181171

172+
label_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
182173
label_dp = LineReader(label_dp, decode=True, return_path=False)
183174
# We cannot use self._wnids here, since we use a different order than the dataset
184175
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
@@ -204,15 +195,12 @@ def __len__(self) -> int:
204195
"test": 100_000,
205196
}[self._split]
206197

207-
def _filter_meta(self, data: Tuple[str, Any]) -> bool:
208-
return self._classifiy_devkit(data) == ImageNetDemux.META
209-
210198
def _generate_categories(self) -> List[Tuple[str, ...]]:
211199
self._split = "val"
212200
resources = self._resources()
213201

214202
devkit_dp = resources[1].load(self._root)
215-
meta_dp = Filter(devkit_dp, self._filter_meta)
203+
meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
216204
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
217205

218206
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))

0 commit comments

Comments
 (0)