Skip to content

Commit 8dbae71

Browse files
NicolasHugcyyever
authored andcommitted
Add FlyingThings3D dataset for optical flow (pytorch#4858)
1 parent 200e8f1 commit 8dbae71

File tree

5 files changed

+198
-8
lines changed

5 files changed

+198
-8
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4444
Flickr8k
4545
Flickr30k
4646
FlyingChairs
47+
FlyingThings3D
4748
HMDB51
4849
ImageNet
4950
INaturalist

test/datasets_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
204204
``transforms``, or ``download``.
205205
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
206206
available, the tests are skipped.
207-
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function
208207
209208
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
210209
The fake data should resemble the original data as close as necessary, while containing only few examples. During
@@ -256,8 +255,6 @@ def test_baz(self):
256255
ADDITIONAL_CONFIGS = None
257256
REQUIRED_PACKAGES = None
258257

259-
EXTRA_PATCHES = None
260-
261258
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
262259
_TRANSFORM_KWARGS = {
263260
"transform",
@@ -383,17 +380,14 @@ def create_dataset(
383380
if patch_checks:
384381
patchers.update(self._patch_checks())
385382

386-
if self.EXTRA_PATCHES is not None:
387-
patchers.update(self.EXTRA_PATCHES)
388-
389383
with get_tmp_dir() as tmpdir:
390384
args = self.dataset_args(tmpdir, complete_config)
391385
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
392386

393387
with self._maybe_apply_patches(patchers), disable_console_output():
394388
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
395389

396-
yield dataset, info
390+
yield dataset, info
397391

398392
@classmethod
399393
def setUpClass(cls):
@@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str:
925919
return "".join(random.choice(digits) for _ in range(length))
926920

927921

922+
def make_fake_pfm_file(h, w, file_name):
923+
values = list(range(3 * h * w))
924+
# Note: we pack everything in little endian: -1.0, and "<"
925+
content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
926+
with open(file_name, "wb") as f:
927+
f.write(content)
928+
929+
928930
def make_fake_flo_file(h, w, file_name):
929931
"""Creates a fake flow file in .flo format."""
930932
values = list(range(2 * h * w))

test/test_datasets.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,5 +2048,72 @@ def test_flow(self, config):
20482048
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
20492049

20502050

2051+
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
2052+
DATASET_CLASS = datasets.FlyingThings3D
2053+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
2054+
split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
2055+
)
2056+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
2057+
2058+
FLOW_H, FLOW_W = 3, 4
2059+
2060+
def inject_fake_data(self, tmpdir, config):
2061+
root = pathlib.Path(tmpdir) / "FlyingThings3D"
2062+
2063+
num_images_per_camera = 3 if config["split"] == "train" else 4
2064+
passes = ("frames_cleanpass", "frames_finalpass")
2065+
splits = ("TRAIN", "TEST")
2066+
letters = ("A", "B", "C")
2067+
subfolders = ("0000", "0001")
2068+
cameras = ("left", "right")
2069+
for pass_name, split, letter, subfolder, camera in itertools.product(
2070+
passes, splits, letters, subfolders, cameras
2071+
):
2072+
current_folder = root / pass_name / split / letter / subfolder
2073+
datasets_utils.create_image_folder(
2074+
current_folder,
2075+
name=camera,
2076+
file_name_fn=lambda image_idx: f"00{image_idx}.png",
2077+
num_examples=num_images_per_camera,
2078+
)
2079+
2080+
directions = ("into_future", "into_past")
2081+
for split, letter, subfolder, direction, camera in itertools.product(
2082+
splits, letters, subfolders, directions, cameras
2083+
):
2084+
current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera
2085+
os.makedirs(str(current_folder), exist_ok=True)
2086+
for i in range(num_images_per_camera):
2087+
datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm"))
2088+
2089+
num_cameras = 2 if config["camera"] == "both" else 1
2090+
num_passes = 2 if config["pass_name"] == "both" else 1
2091+
num_examples = (
2092+
(num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes
2093+
)
2094+
return num_examples
2095+
2096+
@datasets_utils.test_all_configs
2097+
def test_flow(self, config):
2098+
with self.create_dataset(config=config) as (dataset, _):
2099+
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
2100+
for _, _, flow in dataset:
2101+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
2102+
# We don't check the values because the reshaping and flipping makes it hard to figure out
2103+
2104+
def test_bad_input(self):
2105+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
2106+
with self.create_dataset(split="bad"):
2107+
pass
2108+
2109+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
2110+
with self.create_dataset(pass_name="bad"):
2111+
pass
2112+
2113+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument camera"):
2114+
with self.create_dataset(camera="bad"):
2115+
pass
2116+
2117+
20512118
if __name__ == "__main__":
20522119
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._optical_flow import KittiFlow, Sintel, FlyingChairs
1+
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
22
from .caltech import Caltech101, Caltech256
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
@@ -75,4 +75,5 @@
7575
"KittiFlow",
7676
"Sintel",
7777
"FlyingChairs",
78+
"FlyingThings3D",
7879
)

torchvision/datasets/_optical_flow.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import itertools
12
import os
3+
import re
24
from abc import ABC, abstractmethod
35
from glob import glob
46
from pathlib import Path
@@ -15,6 +17,7 @@
1517
__all__ = (
1618
"KittiFlow",
1719
"Sintel",
20+
"FlyingThings3D",
1821
"FlyingChairs",
1922
)
2023

@@ -271,6 +274,94 @@ def _read_flow(self, file_name):
271274
return _read_flo(file_name)
272275

273276

277+
class FlyingThings3D(FlowDataset):
278+
"""`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
279+
280+
The dataset is expected to have the following structure: ::
281+
282+
root
283+
FlyingThings3D
284+
frames_cleanpass
285+
TEST
286+
TRAIN
287+
frames_finalpass
288+
TEST
289+
TRAIN
290+
optical_flow
291+
TEST
292+
TRAIN
293+
294+
Args:
295+
root (string): Root directory of the intel FlyingThings3D Dataset.
296+
split (string, optional): The dataset split, either "train" (default) or "test"
297+
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
298+
details on the different passes.
299+
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
300+
transforms (callable, optional): A function/transform that takes in
301+
``img1, img2, flow, valid`` and returns a transformed version.
302+
``valid`` is expected for consistency with other datasets which
303+
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
304+
"""
305+
306+
def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
307+
super().__init__(root=root, transforms=transforms)
308+
309+
verify_str_arg(split, "split", valid_values=("train", "test"))
310+
split = split.upper()
311+
312+
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
313+
passes = {
314+
"clean": ["frames_cleanpass"],
315+
"final": ["frames_finalpass"],
316+
"both": ["frames_cleanpass", "frames_finalpass"],
317+
}[pass_name]
318+
319+
verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
320+
cameras = ["left", "right"] if camera == "both" else [camera]
321+
322+
root = Path(root) / "FlyingThings3D"
323+
324+
directions = ("into_future", "into_past")
325+
for pass_name, camera, direction in itertools.product(passes, cameras, directions):
326+
image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
327+
image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])
328+
329+
flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
330+
flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])
331+
332+
if not image_dirs or not flow_dirs:
333+
raise FileNotFoundError(
334+
"Could not find the FlyingThings3D flow images. "
335+
"Please make sure the directory structure is correct."
336+
)
337+
338+
for image_dir, flow_dir in zip(image_dirs, flow_dirs):
339+
images = sorted(glob(str(image_dir / "*.png")))
340+
flows = sorted(glob(str(flow_dir / "*.pfm")))
341+
for i in range(len(flows) - 1):
342+
if direction == "into_future":
343+
self._image_list += [[images[i], images[i + 1]]]
344+
self._flow_list += [flows[i]]
345+
elif direction == "into_past":
346+
self._image_list += [[images[i + 1], images[i]]]
347+
self._flow_list += [flows[i + 1]]
348+
349+
def __getitem__(self, index):
350+
"""Return example at given index.
351+
352+
Args:
353+
index(int): The index of the example to retrieve
354+
355+
Returns:
356+
tuple: A 3-tuple with ``(img1, img2, flow)``.
357+
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
358+
"""
359+
return super().__getitem__(index)
360+
361+
def _read_flow(self, file_name):
362+
return _read_pfm(file_name)
363+
364+
274365
def _read_flo(file_name):
275366
"""Read .flo file in Middlebury format"""
276367
# Code adapted from:
@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
295386

296387
# For consistency with other datasets, we convert to numpy
297388
return flow.numpy(), valid.numpy()
389+
390+
391+
def _read_pfm(file_name):
392+
"""Read flow in .pfm format"""
393+
394+
with open(file_name, "rb") as f:
395+
header = f.readline().rstrip()
396+
if header != b"PF":
397+
raise ValueError("Invalid PFM file")
398+
399+
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
400+
if not dim_match:
401+
raise Exception("Malformed PFM header.")
402+
w, h = (int(dim) for dim in dim_match.groups())
403+
404+
scale = float(f.readline().rstrip())
405+
if scale < 0: # little-endian
406+
endian = "<"
407+
scale = -scale
408+
else:
409+
endian = ">" # big-endian
410+
411+
data = np.fromfile(f, dtype=endian + "f")
412+
413+
data = data.reshape(h, w, 3).transpose(2, 0, 1)
414+
data = np.flip(data, axis=1) # flip on h dimension
415+
data = data[:2, :, :]
416+
return data.astype(np.float32)

0 commit comments

Comments
 (0)