@@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum):
34
34
"recipe" : "https://github.com/princeton-vl/RAFT" ,
35
35
"sintel_train_cleanpass_epe" : 1.4411 ,
36
36
"sintel_train_finalpass_epe" : 2.7894 ,
37
+ "kitti_train_per_image_epe" : 5.0172 ,
38
+ "kitti_train_f1-all" : 17.4506 ,
37
39
},
38
40
)
39
41
@@ -46,48 +48,94 @@ class Raft_Large_Weights(WeightsEnum):
46
48
"recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
47
49
"sintel_train_cleanpass_epe" : 1.3822 ,
48
50
"sintel_train_finalpass_epe" : 2.7161 ,
51
+ "kitti_train_per_image_epe" : 4.5118 ,
52
+ "kitti_train_f1-all" : 16.0679 ,
49
53
},
50
54
)
51
55
52
- # C_T_SKHT_V1 = Weights(
53
- # # Chairs + Things + Sintel fine-tuning, i.e.:
54
- # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
55
- # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
56
- # url="",
57
- # transforms=RaftEval,
58
- # meta={
59
- # "recipe": "",
60
- # "epe": -1234,
61
- # },
62
- # )
63
-
64
- # C_T_SKHT_K_V1 = Weights(
65
- # # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
66
- # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
67
- # # Same as CT_SKHT with extra fine-tuning on Kitti
68
- # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
69
- # url="",
70
- # transforms=RaftEval,
71
- # meta={
72
- # "recipe": "",
73
- # "epe": -1234,
74
- # },
75
- # )
56
+ C_T_SKHT_V1 = Weights (
57
+ # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
58
+ url = "https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth" ,
59
+ transforms = RaftEval ,
60
+ meta = {
61
+ ** _COMMON_META ,
62
+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
63
+ "sintel_test_cleanpass_epe" : 1.94 ,
64
+ "sintel_test_finalpass_epe" : 3.18 ,
65
+ },
66
+ )
67
+
68
+ C_T_SKHT_V2 = Weights (
69
+ # Chairs + Things + Sintel fine-tuning, i.e.:
70
+ # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
71
+ # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
72
+ url = "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth" ,
73
+ transforms = RaftEval ,
74
+ meta = {
75
+ ** _COMMON_META ,
76
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
77
+ "sintel_test_cleanpass_epe" : 1.819 ,
78
+ "sintel_test_finalpass_epe" : 3.067 ,
79
+ },
80
+ )
81
+
82
+ C_T_SKHT_K_V1 = Weights (
83
+ # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
84
+ url = "https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth" ,
85
+ transforms = RaftEval ,
86
+ meta = {
87
+ ** _COMMON_META ,
88
+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
89
+ "kitti_test_f1-all" : 5.10 ,
90
+ },
91
+ )
92
+
93
+ C_T_SKHT_K_V2 = Weights (
94
+ # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
95
+ # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
96
+ # Same as CT_SKHT with extra fine-tuning on Kitti
97
+ # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
98
+ url = "https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth" ,
99
+ transforms = RaftEval ,
100
+ meta = {
101
+ ** _COMMON_META ,
102
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
103
+ "kitti_test_f1-all" : 5.19 ,
104
+ },
105
+ )
76
106
77
107
default = C_T_V2
78
108
79
109
80
110
class Raft_Small_Weights (WeightsEnum ):
81
- pass
82
- # C_T_V1 = Weights(
83
- # url="", # TODO
84
- # transforms=RaftEval,
85
- # meta={
86
- # "recipe": "",
87
- # "epe": -1234,
88
- # },
89
- # )
90
- # default = C_T_V1
111
+ C_T_V1 = Weights (
112
+ # Chairs + Things, ported from original paper repo (raft-small.pth)
113
+ url = "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth" ,
114
+ transforms = RaftEval ,
115
+ meta = {
116
+ ** _COMMON_META ,
117
+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
118
+ "sintel_train_cleanpass_epe" : 2.1231 ,
119
+ "sintel_train_finalpass_epe" : 3.2790 ,
120
+ "kitti_train_per_image_epe" : 7.6557 ,
121
+ "kitti_train_f1-all" : 25.2801 ,
122
+ },
123
+ )
124
+ C_T_V2 = Weights (
125
+ # Chairs + Things
126
+ url = "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth" ,
127
+ transforms = RaftEval ,
128
+ meta = {
129
+ ** _COMMON_META ,
130
+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
131
+ "sintel_train_cleanpass_epe" : 1.9901 ,
132
+ "sintel_train_finalpass_epe" : 3.2831 ,
133
+ "kitti_train_per_image_epe" : 7.5978 ,
134
+ "kitti_train_f1-all" : 25.2369 ,
135
+ },
136
+ )
137
+
138
+ default = C_T_V2
91
139
92
140
93
141
@handle_legacy_interface (weights = ("pretrained" , Raft_Large_Weights .C_T_V2 ))
@@ -140,13 +188,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
140
188
return model
141
189
142
190
143
- @handle_legacy_interface (weights = ("pretrained" , None ))
191
+ @handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights . C_T_V2 ))
144
192
def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ):
145
193
"""RAFT "small" model from
146
194
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
147
195
148
196
Args:
149
- weights(Raft_Small_weights, optinal ): TODO not implemented yet
197
+ weights(Raft_Small_weights, optional ): pretrained weights to use.
150
198
progress (bool): If True, displays a progress bar of the download to stderr
151
199
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
152
200
to override any default.
0 commit comments