Skip to content

Commit c178e44

Browse files
authored
Merge branch 'main' into main
2 parents 7cb9dd4 + e0c5cc4 commit c178e44

File tree

5 files changed

+104
-46
lines changed

5 files changed

+104
-46
lines changed

references/classification/sampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
1515
https://github.com/facebookresearch/deit/blob/main/samplers.py
1616
"""
1717

18-
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
18+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
1919
if num_replicas is None:
2020
if not dist.is_available():
2121
raise RuntimeError("Requires distributed package to be available!")
@@ -32,11 +32,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
3232
self.total_size = self.num_samples * self.num_replicas
3333
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
3434
self.shuffle = shuffle
35+
self.seed = seed
3536

3637
def __iter__(self):
3738
# Deterministically shuffle based on epoch
3839
g = torch.Generator()
39-
g.manual_seed(self.epoch)
40+
g.manual_seed(self.seed + self.epoch)
4041
if self.shuffle:
4142
indices = torch.randperm(len(self.dataset), generator=g).tolist()
4243
else:

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torchvision
1010
import transforms
1111
import utils
12-
from references.classification.sampler import RASampler
12+
from sampler import RASampler
1313
from torch import nn
1414
from torch.utils.data.dataloader import default_collate
1515
from torchvision.transforms.functional import InterpolationMode

references/optical_flow/README.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-siz
4848
```
4949

5050
This should give an epe of about 1.3822 on the clean pass and 2.7161 on the
51-
final pass of Sintel. Results may vary slightly depending on the batch size and
52-
the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`:
51+
final pass of Sintel-train. Results may vary slightly depending on the batch
52+
size and the number of GPUs. For the most accurate results use 1 GPU and
53+
`--batch-size 1`:
5354

5455
```
5556
Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248
5657
Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964
5758
```
59+
60+
You can also evaluate on Kitti train:
61+
62+
```
63+
torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained
64+
Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679
65+
```

torchvision/models/optical_flow/raft.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
)
2121

2222

23-
_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"}
23+
_MODELS_URLS = {
24+
"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
25+
"raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
26+
}
2427

2528

2629
class ResidualBlock(nn.Module):
@@ -587,7 +590,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
587590
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
588591
589592
Args:
590-
pretrained (bool): TODO not implemented yet
593+
pretrained (bool): Whether to use pretrained weights.
591594
progress (bool): If True, displays a progress bar of the download to stderr
592595
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
593596
to override any default.
@@ -632,7 +635,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
632635
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
633636
634637
Args:
635-
pretrained (bool): TODO not implemented yet
638+
pretrained (bool): Whether to use pretrained weights.
636639
progress (bool): If True, displays a progress bar of the download to stderr
637640
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
638641
to override any default.
@@ -641,8 +644,6 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
641644
nn.Module: The model.
642645
643646
"""
644-
if pretrained:
645-
raise ValueError("No checkpoint is available for raft_small")
646647

647648
return _raft(
648649
arch="raft_small",

torchvision/prototype/models/optical_flow/raft.py

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum):
3434
"recipe": "https://github.com/princeton-vl/RAFT",
3535
"sintel_train_cleanpass_epe": 1.4411,
3636
"sintel_train_finalpass_epe": 2.7894,
37+
"kitti_train_per_image_epe": 5.0172,
38+
"kitti_train_f1-all": 17.4506,
3739
},
3840
)
3941

@@ -46,48 +48,94 @@ class Raft_Large_Weights(WeightsEnum):
4648
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
4749
"sintel_train_cleanpass_epe": 1.3822,
4850
"sintel_train_finalpass_epe": 2.7161,
51+
"kitti_train_per_image_epe": 4.5118,
52+
"kitti_train_f1-all": 16.0679,
4953
},
5054
)
5155

52-
# C_T_SKHT_V1 = Weights(
53-
# # Chairs + Things + Sintel fine-tuning, i.e.:
54-
# # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
55-
# # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
56-
# url="",
57-
# transforms=RaftEval,
58-
# meta={
59-
# "recipe": "",
60-
# "epe": -1234,
61-
# },
62-
# )
63-
64-
# C_T_SKHT_K_V1 = Weights(
65-
# # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
66-
# # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
67-
# # Same as CT_SKHT with extra fine-tuning on Kitti
68-
# # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
69-
# url="",
70-
# transforms=RaftEval,
71-
# meta={
72-
# "recipe": "",
73-
# "epe": -1234,
74-
# },
75-
# )
56+
C_T_SKHT_V1 = Weights(
57+
# Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
58+
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
59+
transforms=RaftEval,
60+
meta={
61+
**_COMMON_META,
62+
"recipe": "https://github.com/princeton-vl/RAFT",
63+
"sintel_test_cleanpass_epe": 1.94,
64+
"sintel_test_finalpass_epe": 3.18,
65+
},
66+
)
67+
68+
C_T_SKHT_V2 = Weights(
69+
# Chairs + Things + Sintel fine-tuning, i.e.:
70+
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
71+
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
72+
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
73+
transforms=RaftEval,
74+
meta={
75+
**_COMMON_META,
76+
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
77+
"sintel_test_cleanpass_epe": 1.819,
78+
"sintel_test_finalpass_epe": 3.067,
79+
},
80+
)
81+
82+
C_T_SKHT_K_V1 = Weights(
83+
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
84+
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
85+
transforms=RaftEval,
86+
meta={
87+
**_COMMON_META,
88+
"recipe": "https://github.com/princeton-vl/RAFT",
89+
"kitti_test_f1-all": 5.10,
90+
},
91+
)
92+
93+
C_T_SKHT_K_V2 = Weights(
94+
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
95+
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
96+
# Same as CT_SKHT with extra fine-tuning on Kitti
97+
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
98+
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
99+
transforms=RaftEval,
100+
meta={
101+
**_COMMON_META,
102+
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
103+
"kitti_test_f1-all": 5.19,
104+
},
105+
)
76106

77107
default = C_T_V2
78108

79109

80110
class Raft_Small_Weights(WeightsEnum):
81-
pass
82-
# C_T_V1 = Weights(
83-
# url="", # TODO
84-
# transforms=RaftEval,
85-
# meta={
86-
# "recipe": "",
87-
# "epe": -1234,
88-
# },
89-
# )
90-
# default = C_T_V1
111+
C_T_V1 = Weights(
112+
# Chairs + Things, ported from original paper repo (raft-small.pth)
113+
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
114+
transforms=RaftEval,
115+
meta={
116+
**_COMMON_META,
117+
"recipe": "https://github.com/princeton-vl/RAFT",
118+
"sintel_train_cleanpass_epe": 2.1231,
119+
"sintel_train_finalpass_epe": 3.2790,
120+
"kitti_train_per_image_epe": 7.6557,
121+
"kitti_train_f1-all": 25.2801,
122+
},
123+
)
124+
C_T_V2 = Weights(
125+
# Chairs + Things
126+
url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
127+
transforms=RaftEval,
128+
meta={
129+
**_COMMON_META,
130+
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
131+
"sintel_train_cleanpass_epe": 1.9901,
132+
"sintel_train_finalpass_epe": 3.2831,
133+
"kitti_train_per_image_epe": 7.5978,
134+
"kitti_train_f1-all": 25.2369,
135+
},
136+
)
137+
138+
default = C_T_V2
91139

92140

93141
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2))
@@ -140,13 +188,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
140188
return model
141189

142190

143-
@handle_legacy_interface(weights=("pretrained", None))
191+
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
144192
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
145193
"""RAFT "small" model from
146194
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
147195
148196
Args:
149-
weights(Raft_Small_weights, optinal): TODO not implemented yet
197+
weights (Raft_Small_Weights, optional): pretrained weights to use.
150198
progress (bool): If True, displays a progress bar of the download to stderr
151199
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
152200
to override any default.

0 commit comments

Comments
 (0)