diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 89dfe7e08d8..ae5cf71a95d 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes
diff --git a/test/datasets_utils.py b/test/datasets_utils.py
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ ... @@ def create_random_string(length: int, *digits: str) -> str:
         digits = "".join(itertools.chain(*digits))
 
     return "".join(random.choice(digits) for _ in range(length))
+
+
+def make_fake_flo_file(h, w, file_name):
+    """Creates a fake flow file in .flo format."""
+    values = list(range(2 * h * w))
+    content = b"PIEH" + struct.pack("i", w) + struct.pack("i", h) + struct.pack("f" * len(values), *values)
+    with open(file_name, "wb") as f:
+        f.write(content)
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 57c2a80181a..e355cfc5b40 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
 class SintelTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Sintel
     ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
-    # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
-    # which is something we want to # avoid.
-    _FAKE_FLOW = "Fake Flow"
-    EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
-    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
 
     def inject_fake_data(self, tmpdir, config):
         root = pathlib.Path(tmpdir) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
             num_examples=num_images_per_scene,
         )
 
-        # For the ground truth flow value we just create empty files so that they're properly discovered,
-        # see comment above about EXTRA_PATCHES
         flow_root = root / "training" / "flow"
         for scene_id in range(num_scenes):
             scene_dir = flow_root / f"scene_{scene_id}"
             os.makedirs(scene_dir)
             for i in range(num_images_per_scene - 1):
-                open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
+                file_name = str(scene_dir / f"frame_000{i}.flo")
+                datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
 
         # with e.g. num_images_per_scene = 3, for a single scene with have 3 images
         # which are frame_0000, frame_0001 and frame_0002
@@ -1920,7 +1917,8 @@ def test_flow(self):
         with self.create_dataset(split="train") as (dataset, _):
             assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
             for _, _, flow in dataset:
-                assert flow == self._FAKE_FLOW
+                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
+                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
 
         # Make sure flow is always None for test split
         with self.create_dataset(split="test") as (dataset, _):
@@ -1929,11 +1927,11 @@
             assert flow is None
 
     def test_bad_input(self):
-        with pytest.raises(ValueError, match="split must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
             with self.create_dataset(split="bad"):
                 pass
 
-        with pytest.raises(ValueError, match="pass_name must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
             with self.create_dataset(pass_name="bad"):
                 pass
 
@@ -1993,10 +1991,62 @@ def test_flow_and_valid(self):
             assert valid is None
 
     def test_bad_input(self):
-        with pytest.raises(ValueError, match="split must be either"):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
             with self.create_dataset(split="bad"):
                 pass
 
 
+class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.FlyingChairs
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def _make_split_file(self, root, num_examples):
+        # We create a fake split file here, but users are asked to download the real one from the authors website
+        split_ids = [1] * num_examples["train"] + [2] * num_examples["val"]
+        random.shuffle(split_ids)
+        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
+            for split_id in split_ids:
+                split_file.write(f"{split_id}\n")
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "FlyingChairs"
+
+        num_examples = {"train": 5, "val": 3}
+        num_examples_total = sum(num_examples.values())
+
+        datasets_utils.create_image_folder(  # img1
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img1.ppm",
+            num_examples=num_examples_total,
+        )
+        datasets_utils.create_image_folder(  # img2
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
+            num_examples=num_examples_total,
+        )
+        for i in range(num_examples_total):
+            file_name = str(root / "data" / f"00{i}_flow.flo")
+            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
+
+        self._make_split_file(root, num_examples)
+
+        return num_examples[config["split"]]
+
+    @datasets_utils.test_all_configs
+    def test_flow(self, config):
+        # Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
+        # Also make sure the flow is properly decoded
+        with self.create_dataset(config=config) as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
+                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 5edcd1bc584..dfad4770a93 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -1,4 +1,4 @@
-from ._optical_flow import KittiFlow, Sintel
+from ._optical_flow import KittiFlow, Sintel, FlyingChairs
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import CIFAR10, CIFAR100
@@ -74,4 +74,5 @@
     "LFWPairs",
     "KittiFlow",
     "Sintel",
+    "FlyingChairs",
 )
diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py
index 7cb19e8d8c4..f26127039d1 100644
--- a/torchvision/datasets/_optical_flow.py
+++ b/torchvision/datasets/_optical_flow.py
@@ -8,12 +8,14 @@
 from PIL import Image
 
 from ..io.image import _read_png_16
+from .utils import verify_str_arg
 from .vision import VisionDataset
 
 
 __all__ = (
     "KittiFlow",
     "Sintel",
+    "FlyingChairs",
 )
@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
     def __init__(self, root, split="train", pass_name="clean", transforms=None):
         super().__init__(root=root, transforms=transforms)
 
-        if split not in ("train", "test"):
-            raise ValueError("split must be either 'train' or 'test'")
-
-        if pass_name not in ("clean", "final"):
-            raise ValueError("pass_name must be either 'clean' or 'final'")
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final"))
 
         root = Path(root) / "Sintel"
@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
     def __init__(self, root, split="train", transforms=None):
         super().__init__(root=root, transforms=transforms)
 
-        if split not in ("train", "test"):
-            raise ValueError("split must be either 'train' or 'test'")
+        verify_str_arg(split, "split", valid_values=("train", "test"))
 
         root = Path(root) / "Kitti" / (split + "ing")
         images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
@@ -208,6 +206,71 @@ def _read_flow(self, file_name):
         return _read_16bits_png_with_flow_and_valid_mask(file_name)
 
 
+class FlyingChairs(FlowDataset):
+    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html>`_ Dataset for optical flow.
+
+    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingChairs
+                data
+                    00001_flow.flo
+                    00001_img1.ppm
+                    00001_img2.ppm
+                    ...
+                FlyingChairs_train_val.txt
+
+
+    Args:
+        root (string): Root directory of the FlyingChairs Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "val"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid`` and returns a transformed version.
+            ``valid`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(self, root, split="train", transforms=None):
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "val"))
+
+        root = Path(root) / "FlyingChairs"
+        images = sorted(glob(str(root / "data" / "*.ppm")))
+        flows = sorted(glob(str(root / "data" / "*.flo")))
+
+        split_file_name = "FlyingChairs_train_val.txt"
+
+        if not os.path.exists(root / split_file_name):
+            raise FileNotFoundError(
+                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
+            )
+
+        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
+        for i in range(len(flows)):
+            split_id = split_list[i]
+            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
+                self._flow_list += [flows[i]]
+                self._image_list += [[images[2 * i], images[2 * i + 1]]]
+
+    def __getitem__(self, index):
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name):
+        return _read_flo(file_name)
+
+
 def _read_flo(file_name):
     """Read .flo file in Middlebury format"""
     # Code adapted from:
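
A note on the fake .flo files used by the new tests: make_fake_flo_file writes the "PIEH" magic tag, then the width and height packed as C ints, then 2 * h * w float32 values. The sketch below round-trips one such file to illustrate that layout. It copies the test helper so the snippet is self-contained; the parsing code is an illustration only, not the _read_flo reader from this patch.

```python
# Round-trip sketch for the fake .flo layout used by the tests above.
# make_fake_flo_file is copied from the test/datasets_utils.py addition;
# the header parsing below is for illustration, not torchvision's _read_flo.
import os
import struct
import tempfile

import numpy as np


def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format."""
    values = list(range(2 * h * w))
    content = b"PIEH" + struct.pack("i", w) + struct.pack("i", h) + struct.pack("f" * len(values), *values)
    with open(file_name, "wb") as f:
        f.write(content)


with tempfile.TemporaryDirectory() as tmpdir:
    file_name = os.path.join(tmpdir, "frame_0000.flo")
    make_fake_flo_file(h=3, w=4, file_name=file_name)

    with open(file_name, "rb") as f:
        assert f.read(4) == b"PIEH"                       # magic tag
        w, h = struct.unpack("ii", f.read(8))             # width, height (native byte order)
        data = np.frombuffer(f.read(), dtype=np.float32)  # the 2 * h * w flow values

    assert (w, h) == (4, 3)
    assert data.size == 2 * h * w
```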
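The updated `pytest.raises(..., match=...)` patterns follow from the switch to `verify_str_arg`: it raises a ValueError whose message starts with "Unknown value '<value>' for argument <name>" rather than the hand-written "must be either" messages. A small sketch of the behaviour the tests now match; only the message prefix is relied on, the rest of the wording may differ.

```python
# Sketch of the error message prefix the updated tests match.
from torchvision.datasets.utils import verify_str_arg

try:
    verify_str_arg("bad", "split", valid_values=("train", "val"))
except ValueError as err:
    print(err)  # Unknown value 'bad' for argument split. ...
```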
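For trying the new dataset locally, here is a minimal usage sketch. It assumes the FlyingChairs data and the FlyingChairs_train_val.txt split file are already extracted under the layout shown in the class docstring; "/path/to/root" is a placeholder.

```python
# Minimal usage sketch for the new dataset (not part of the patch).
# Assumes <root>/FlyingChairs/data/*.ppm, *.flo and the
# FlyingChairs_train_val.txt split file are already in place.
from torchvision.datasets import FlyingChairs

dataset = FlyingChairs(root="/path/to/root", split="train")
img1, img2, flow = dataset[0]

print(img1.size, img2.size)  # PIL images
print(flow.shape)            # numpy array of shape (2, H, W)
```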