-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add pretrained weights on Chairs and Things for raft_large #5060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
dce22b3
Add pretrained weights on Chairs and Things for raft_large
NicolasHug 228a17c
Merge branch 'main' of github.com:pytorch/vision into raft_pretrained_CT
NicolasHug d244401
Minor stuff
NicolasHug 0406c83
Merge branch 'main' of github.com:pytorch/vision into raft_pretrained_CT
NicolasHug f186973
Add pretrained weights from paper's repo as V1
NicolasHug 57aff36
Address comments
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Optical flow reference training scripts | ||
|
||
This folder contains reference training scripts for optical flow. | ||
They serve as a log of how to train specific models, so as to provide baseline | ||
training and evaluation scripts to quickly bootstrap research. | ||
|
||
|
||
### RAFT Large | ||
|
||
The RAFT large model was trained on Flying Chairs and then on Flying Things. | ||
Both used 8 A100 GPUs and a batch size of 2 (so effective batch size is 16). The | ||
rest of the hyper-parameters are exactly the same as the original RAFT training | ||
recipe from https://github.com/princeton-vl/RAFT. | ||
|
||
``` | ||
torchrun --nproc_per_node 8 --nnodes 1 train.py \ | ||
--dataset-root $dataset_root \ | ||
--name $name_chairs \ | ||
--train-dataset chairs \ | ||
--batch-size 2 \ | ||
--lr 0.0004 \ | ||
--weight-decay 0.0001 \ | ||
--num-steps 100000 \ | ||
--output-dir $chairs_dir | ||
``` | ||
|
||
``` | ||
torchrun --nproc_per_node 8 --nnodes 1 train.py \ | ||
--dataset-root $dataset_root \ | ||
--name $name_things \ | ||
--train-dataset things \ | ||
--batch-size 2 \ | ||
--lr 0.000125 \ | ||
--weight-decay 0.0001 \ | ||
--num-steps 100000 \ | ||
--freeze-batch-norm \ | ||
--output-dir $things_dir\ | ||
--resume $chairs_dir/$name_chairs.pth | ||
``` | ||
|
||
|
||
### Evaluation | ||
|
||
``` | ||
torchrun --nproc_per_node 8 --nnodes 1 train.py --val-dataset sintel --batch-size 10 --dataset-root $dataset_root --model raft_large --pretrained | ||
``` | ||
|
||
This should give an epe of about 1.3825 on the clean pass and 2.7148 on the | ||
final pass of Sintel. Results may vary slightly depending on the batch size and | ||
the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`. | ||
|
||
``` | ||
Sintel val clean epe: 1.3825 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3782 f1: 4.0234 | ||
Sintel val final epe: 2.7148 1px: 0.8526 3px: 0.9203 5px: 0.9392 per_image_epe: 2.7199 f1: 7.6100 | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from torch.nn.modules.instancenorm import InstanceNorm2d | ||
from torchvision.ops import ConvNormActivation | ||
|
||
from ..._internally_replaced_utils import load_state_dict_from_url | ||
from ...utils import _log_api_usage_once | ||
from ._utils import grid_sample, make_coords_grid, upsample_flow | ||
|
||
|
@@ -19,6 +20,9 @@ | |
) | ||
|
||
|
||
_MODELS_URLS = {"raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once PR is merged I will upload this to manifold There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI: all current models use |
||
|
||
|
||
class ResidualBlock(nn.Module): | ||
"""Slightly modified Residual block with extra relu and biases.""" | ||
|
||
|
@@ -474,8 +478,8 @@ def forward(self, image1, image2, num_flow_updates: int = 12): | |
hidden_state = torch.tanh(hidden_state) | ||
context = F.relu(context) | ||
|
||
coords0 = make_coords_grid(batch_size, h // 8, w // 8).cuda() | ||
coords1 = make_coords_grid(batch_size, h // 8, w // 8).cuda() | ||
coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) | ||
coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) | ||
|
||
flow_predictions = [] | ||
for _ in range(num_flow_updates): | ||
|
@@ -496,6 +500,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12): | |
|
||
def _raft( | ||
*, | ||
arch=None, | ||
pretrained=False, | ||
progress=False, | ||
# Feature encoder | ||
feature_encoder_layers, | ||
feature_encoder_block, | ||
|
@@ -560,14 +567,19 @@ def _raft( | |
multiplier=0.25, # See comment in MaskPredictor about this | ||
) | ||
|
||
return RAFT( | ||
model = RAFT( | ||
feature_encoder=feature_encoder, | ||
context_encoder=context_encoder, | ||
corr_block=corr_block, | ||
update_block=update_block, | ||
mask_predictor=mask_predictor, | ||
**kwargs, # not really needed, all params should be consumed by now | ||
) | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) | ||
model.load_state_dict(state_dict) | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return model | ||
|
||
|
||
def raft_large(*, pretrained=False, progress=True, **kwargs): | ||
|
@@ -584,10 +596,10 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): | |
nn.Module: The model. | ||
""" | ||
|
||
if pretrained: | ||
raise ValueError("No checkpoint is available for raft_large") | ||
|
||
return _raft( | ||
arch="raft_large", | ||
pretrained=pretrained, | ||
progress=progress, | ||
# Feature encoder | ||
feature_encoder_layers=(64, 64, 96, 128, 256), | ||
feature_encoder_block=ResidualBlock, | ||
|
@@ -629,11 +641,13 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): | |
nn.Module: The model. | ||
|
||
""" | ||
|
||
if pretrained: | ||
raise ValueError("No checkpoint is available for raft_small") | ||
|
||
return _raft( | ||
arch="raft_small", | ||
pretrained=pretrained, | ||
progress=progress, | ||
# Feature encoder | ||
feature_encoder_layers=(32, 32, 64, 96, 128), | ||
feature_encoder_block=BottleneckBlock, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.