1
- import enum
2
1
import functools
3
2
import pathlib
4
3
import re
10
9
IterKeyZipper ,
11
10
Mapper ,
12
11
Filter ,
13
- Demultiplexer ,
14
12
TarArchiveLoader ,
15
13
Enumerator ,
16
14
)
27
25
hint_shuffling ,
28
26
read_categories_file ,
29
27
path_accessor ,
28
+ path_comparator ,
30
29
)
31
30
from torchvision .prototype .features import Label , EncodedImage
32
31
@@ -46,9 +45,9 @@ def __init__(self, **kwargs: Any) -> None:
46
45
super ().__init__ ("Register on https://image-net.org/ and follow the instructions there." , ** kwargs )
47
46
48
47
49
- class ImageNetDemux (enum .IntEnum ):
50
- META = 0
51
- LABEL = 1
48
+ # class ImageNetDemux(enum.IntEnum):
49
+ # META = 0
50
+ # LABEL = 1
52
51
53
52
54
53
@register_dataset (NAME )
@@ -108,12 +107,6 @@ def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label,
108
107
def _prepare_test_data (self , data : Tuple [str , BinaryIO ]) -> Tuple [None , Tuple [str , BinaryIO ]]:
109
108
return None , data
110
109
111
- def _classifiy_devkit (self , data : Tuple [str , BinaryIO ]) -> Optional [int ]:
112
- return {
113
- "meta.mat" : ImageNetDemux .META ,
114
- "ILSVRC2012_validation_ground_truth.txt" : ImageNetDemux .LABEL ,
115
- }.get (pathlib .Path (data [0 ]).name )
116
-
117
110
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
118
111
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
119
112
_WNID_MAP = {
@@ -172,13 +165,11 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
172
165
else : # config.split == "val":
173
166
images_dp , devkit_dp = resource_dps
174
167
175
- meta_dp , label_dp = Demultiplexer (
176
- devkit_dp , 2 , self ._classifiy_devkit , drop_none = True , buffer_size = INFINITE_BUFFER_SIZE
177
- )
178
-
168
+ meta_dp = Filter (devkit_dp , path_comparator ("name" , "meta.mat" ))
179
169
meta_dp = Mapper (meta_dp , self ._extract_categories_and_wnids )
180
- _ , wnids = zip (* next ( iter ( meta_dp )) )
170
+ _ , wnids = zip (* list ( meta_dp )[ 0 ] )
181
171
172
+ label_dp = Filter (devkit_dp , path_comparator ("name" , "ILSVRC2012_validation_ground_truth.txt" ))
182
173
label_dp = LineReader (label_dp , decode = True , return_path = False )
183
174
# We cannot use self._wnids here, since we use a different order than the dataset
184
175
label_dp = Mapper (label_dp , functools .partial (self ._imagenet_label_to_wnid , wnids = wnids ))
@@ -204,15 +195,12 @@ def __len__(self) -> int:
204
195
"test" : 100_000 ,
205
196
}[self ._split ]
206
197
207
- def _filter_meta (self , data : Tuple [str , Any ]) -> bool :
208
- return self ._classifiy_devkit (data ) == ImageNetDemux .META
209
-
210
198
def _generate_categories (self ) -> List [Tuple [str , ...]]:
211
199
self ._split = "val"
212
200
resources = self ._resources ()
213
201
214
202
devkit_dp = resources [1 ].load (self ._root )
215
- meta_dp = Filter (devkit_dp , self . _filter_meta )
203
+ meta_dp = Filter (devkit_dp , path_comparator ( "name" , "meta.mat" ) )
216
204
meta_dp = Mapper (meta_dp , self ._extract_categories_and_wnids )
217
205
218
206
categories_and_wnids = cast (List [Tuple [str , ...]], next (iter (meta_dp )))
0 commit comments