
Add raft builders and presets in prototypes #5043


Merged: 14 commits, merged Dec 8, 2021. Changes from 9 commits are shown below.
4 changes: 2 additions & 2 deletions torchvision/models/optical_flow/raft.py
@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
"""

if pretrained:
raise ValueError("Pretrained weights aren't available yet")
raise ValueError("No checkpoint is available for raft_large")

return _raft(
# Feature encoder
@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
"""

if pretrained:
raise ValueError("Pretrained weights aren't available yet")
raise ValueError("No checkpoint is available for raft_small")

return _raft(
# Feature encoder
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
@@ -12,6 +12,7 @@
from .vgg import *
from .vision_transformer import *
from . import detection
+from . import optical_flow
from . import quantization
from . import segmentation
from . import video
1 change: 1 addition & 0 deletions torchvision/prototype/models/optical_flow/__init__.py
@@ -0,0 +1 @@
from .raft import RAFT, raft_large, raft_small
168 changes: 168 additions & 0 deletions torchvision/prototype/models/optical_flow/raft.py
@@ -0,0 +1,168 @@
from typing import Optional

from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
# from torchvision.prototype.transforms import RaftEval

from .._api import WeightsEnum
# from .._api import Weights
from .._utils import handle_legacy_interface


__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
)


class Raft_Large_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     # Chairs + Things
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_K_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
    #     # Same as C_T_SKHT with extra fine-tuning on Kitti
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel and then on Kitti
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # default = C_T_V1


class Raft_Small_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     url="",  # TODO
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # default = C_T_V1


@handle_legacy_interface(weights=("pretrained", None))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Args:
weights(Raft_Large_weights, optinal): TODO not implemented yet
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
nn.Module: The model.
"""

    if weights is not None:
        raise ValueError("No checkpoint is available for raft_large")

    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )


@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
"""RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

Args:
weights(Raft_Small_weights, optinal): TODO not implemented yet
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.

Returns:
nn.Module: The model.

"""

    if weights is not None:
        raise ValueError("No checkpoint is available for raft_small")

    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )
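
For reviewers who want to try the branch, here is a minimal usage sketch of the new builders. This is a sketch only: the module path comes from the diff above, no checkpoints exist yet so weights must stay None, and the 520x960 input size is an arbitrary choice that satisfies RAFT's requirement of spatial dimensions divisible by 8.

import torch
from torchvision.prototype.models.optical_flow import raft_large

# weights=None is currently the only valid value; passing anything else raises.
model = raft_large(weights=None).eval()

# Two consecutive frames, batched; H and W must be divisible by 8.
img1 = torch.rand(1, 3, 520, 960)
img2 = torch.rand(1, 3, 520, 960)

with torch.no_grad():
    # torchvision's RAFT returns one flow estimate per refinement iteration;
    # the last element is the most refined prediction.
    flow_predictions = model(img1, img2)

print(flow_predictions[-1].shape)  # torch.Size([1, 2, 520, 960])

raft_small works the same way, with the smaller configuration shown above.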
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -3,4 +3,4 @@

from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
-from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval
+from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
35 changes: 35 additions & 0 deletions torchvision/prototype/transforms/_presets.py
@@ -97,3 +97,38 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
            target = F.pil_to_tensor(target)
            target = target.squeeze(0).to(torch.int64)
        return img, target


class RaftEval(nn.Module):
    def forward(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:

        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)

        img1 = F.convert_image_dtype(img1, torch.float32)
        img2 = F.convert_image_dtype(img2, torch.float32)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask

    def _pil_or_numpy_to_tensor(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        if not isinstance(img1, Tensor):
Review comments on the isinstance check above:

Member Author: I find it a bit confusing that we type the input as Tensor and actually handle the case where it's not a tensor. I saw that on the other presets, so I did the same. I assume that this is only temporary, to check these presets on the current transforms (which return PIL images), and that we will remove the conversions eventually?

Contributor: Good point. It's also because the user is supposed to use these transforms during inference. At that point, you don't know if they chose to read the image with PIL or with TV's io. So here we support both.

This is also done because the reference scripts for other tasks only support PIL. BTW, now that you added a prototype section for your model, you should add support for it in your reference scripts on the other PR.
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        if flow is not None and not isinstance(flow, Tensor):
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask
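
A quick sketch of the RaftEval preset in use. The uint8 CHW tensors below stand in for images read with torchvision.io; PIL images and numpy flow arrays would go through the conversion branches above instead.

import torch
from torchvision.prototype.transforms import RaftEval

preset = RaftEval()

# Tensor inputs, such as those returned by torchvision.io.read_image (uint8, CHW).
img1 = torch.randint(0, 256, (3, 520, 960), dtype=torch.uint8)
img2 = torch.randint(0, 256, (3, 520, 960), dtype=torch.uint8)

# At inference time there is no ground truth, so flow and the valid mask are None.
img1, img2, flow, valid_flow_mask = preset(img1, img2, None, None)

# convert_image_dtype maps uint8 into [0, 1], and normalizing with
# mean=std=0.5 then maps [0, 1] into [-1, 1]: (x - 0.5) / 0.5 == 2x - 1.
assert img1.dtype == torch.float32
assert img1.min() >= -1 and img1.max() <= 1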