Add transforms and presets for optical flow models #5026
@@ -0,0 +1,64 @@
import torch
import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.ValidateModelInput(),
            ]
        )

    def __call__(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)


class OpticalFlowPresetTrain(torch.nn.Module):
    def __init__(
        self,
        # MaybeResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        asymmetric_jitter_prob=0.2,
        # Random[H,V]Flip params
        do_flip=True,
    ):
        super().__init__()

        transforms = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.RandomApply([T.RandomErase()], p=0.5),
            T.MaybeResizeAndCrop(
                crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

        transforms += [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.MakeValidFlowMask(),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(transforms)

    def __call__(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
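For context, here is a minimal usage sketch of the training preset. It is not part of the diff: the frame size, crop size and random flow below are fabricated for illustration, and real samples would come from one of the optical-flow datasets.

import numpy as np
from PIL import Image

preset = OpticalFlowPresetTrain(crop_size=(368, 496))    # crop size chosen arbitrarily for this sketch

img1 = Image.new("RGB", (1024, 436))                     # Sintel-sized frames, all black here
img2 = Image.new("RGB", (1024, 436))
flow = np.random.randn(2, 436, 1024).astype(np.float32)  # fabricated dense ground-truth flow

img1, img2, flow, valid = preset(img1, img2, flow, valid=None)
print(img1.shape, flow.shape, valid.shape)  # all cropped to 368 x 496 spatially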
@@ -0,0 +1,267 @@
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
    def __init__(self):
        super().__init__()

    def forward(self, img1, img2, flow, valid_flow_mask):

        assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
        assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)

        assert img1.shape == img2.shape
        h, w = img1.shape[-2:]
        if flow is not None:
            assert flow.shape == (2, h, w)
        if valid_flow_mask is not None:
            assert valid_flow_mask.shape == (h, w)
            assert valid_flow_mask.dtype == torch.bool

        return img1, img2, flow, valid_flow_mask


class MakeValidFlowMask(torch.nn.Module):
    # This transform generates a valid_flow_mask if it doesn't exist.
    # The flow is considered valid if ||flow||_inf < threshold
    # This is a noop for Kitti and HD1K which already come with a built-in flow mask.
    def __init__(self, threshold=1000):
        super().__init__()
        self.threshold = threshold

    def forward(self, img1, img2, flow, valid_flow_mask):
        if flow is not None and valid_flow_mask is None:
            valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
        return img1, img2, flow, valid_flow_mask


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.convert_image_dtype(img1, dtype=self.dtype)
        img2 = F.convert_image_dtype(img2, dtype=self.dtype)

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.normalize(img1, mean=self.mean, std=self.std)
        img2 = F.normalize(img2, mean=self.mean, std=self.std)

        return img1, img2, flow, valid_flow_mask


class PILToTensor(torch.nn.Module):
    # Converts all inputs to tensors
    # Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
    # for consistency with the rest, e.g. the segmentation reference.
    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.pil_to_tensor(img1)
        img2 = F.pil_to_tensor(img2)
        if flow is not None:
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None:
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask

class AsymmetricColorJitter(T.ColorJitter):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(self, img1, img2, flow, valid_flow_mask):

        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img1 = super().forward(img1)
            img2 = super().forward(img2)
        else:
            # symmetric: same transform for img1 and img2
[Review discussion attached to this line]
> @NicolasHug So does the […] @pmeier Could you please check this strange transform to confirm it's supported by the new Transforms API?
> As it stands, this would not be supported. A transform always treats a sample as an atomic unit, so multiple images in the same sample would be transformed with the same parameters.
> OK, I'll clarify. Ultimately this is a special case of […].
> @NicolasHug Sounds good, just add comments. No need to use RandomApply here. @pmeier No worries, this is why we give the option for someone to write custom transforms without the magic of the new API, for weird cases like this. Could you now confirm that this is indeed a workaround we can apply?
> I'm guessing […]. I think this is one of the cases @datumbox mentioned where we need to circumvent the automatic dispatch a little. In case we want to transform the two images separately, we could split the sample and perform the transformation once for the sample minus image2 and once for image2. The problem I see with this is that it can't be automated without assumptions about how the sample is structured, so we either need to use the same structure for every dataset (for example a flat dictionary with […]).
> Each transform would receive the entire input (which IIRC is a dict) and operate on a subset of that dict. Are you suggesting that img1 and img2 would be concatenated?
            batch = torch.stack([img1, img2])
            batch = super().forward(batch)
            img1, img2 = batch[0], batch[1]

        return img1, img2, flow, valid_flow_mask

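To make the asymmetric/symmetric distinction discussed in the thread above concrete, here is a small standalone sketch (not part of the diff) of the same logic written against the stock torchvision ColorJitter; the helper name and default values are made up for illustration.

import torch
import torchvision.transforms as T

def jitter_pair(img1, img2, p=0.2, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14):
    # Hypothetical helper, not from the PR: jitter a pair of CHW tensors either with
    # independently sampled parameters (asymmetric) or with one set of parameters
    # shared by both frames (symmetric).
    jitter = T.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
    if torch.rand(1) < p:
        # asymmetric: two forward passes, hence two independent parameter draws
        return jitter(img1), jitter(img2)
    # symmetric: stacking the pair means a single forward pass, so one parameter draw
    # is applied identically to both frames
    batch = jitter(torch.stack([img1, img2]))
    return batch[0], batch[1]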
class RandomErase(torch.nn.Module):
[Review discussion attached to this line]
> This is similar to our existing RandomErasing […]
> I understand that this class is supposed to be used BEFORE converting the image from int to floats. This is unlike TorchVision's RandomErasing […]. Concerning using the mean vs. a fixed value, note that TorchVision's RandomErasing implementation uses a fixed zero value because the images are expected to be de-meaned. I wonder if it would be possible to reuse TorchVision's transform for now if you passed a normalized image. Perhaps that's worth doing even if the outcome is not 100% the same, just to maintain parity between transforms defined in the references and those in the legacy transforms (this will simplify porting to the new API).
    def forward(self, img1, img2, flow, valid_flow_mask):
        bounds = [50, 100]
        ht, wd = img2.shape[-2:]  # img2 is a (C, H, W) tensor at this point in the pipeline

        # Warning : This won't work with image values in [0, 1] because of round()
        mean_color = img2.view(3, -1).float().mean(axis=-1).round()
        for _ in range(torch.randint(1, 3, size=(1,)).item()):
            x0 = torch.randint(0, wd, size=(1,)).item()
            y0 = torch.randint(0, ht, size=(1,)).item()
            dx, dy = torch.randint(bounds[0], bounds[1], size=(2,))
            img2[:, y0 : y0 + dy, x0 : x0 + dx] = mean_color[:, None, None]

        return img1, img2, flow, valid_flow_mask

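Following the reviewer's suggestion above about reusing the built-in transform, one possible shape for this (a sketch under the stated assumptions, not the PR's implementation) is to subclass torchvision's RandomErasing and apply it to the second image only; it assumes float, already normalized images so that the constant zero fill value is roughly the mean.

import torchvision.transforms as T

class EraseSecondImage(T.RandomErasing):
    # Hypothetical variant, not part of the diff: reuse the stock RandomErasing
    # (which fills the erased region with a constant, 0 by default) but only on img2,
    # leaving img1, the flow and the valid mask untouched.
    def forward(self, img1, img2, flow, valid_flow_mask):
        img2 = super().forward(img2)
        return img1, img2, flow, valid_flow_mask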
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.hflip(img1)
        img2 = F.hflip(img2)
        flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.hflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
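A quick sanity check of the sign convention used above (toy values made up for this sketch): after a horizontal flip, the horizontal component of the flow must change sign while the vertical component is left untouched.

import torch
import torchvision.transforms.functional as F

flow = torch.zeros(2, 1, 4)
flow[:, 0, 1] = torch.tensor([3.0, 2.0])   # at column 1: 3 px to the right, 2 px down
flipped = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
print(flipped[:, 0, 2])                    # tensor([-3., 2.]): mirrored to column 2, now pointing left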

class RandomVerticalFlip(T.RandomVerticalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.vflip(img1)
        img2 = F.vflip(img2)
        flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.vflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask

class MaybeResizeAndCrop(torch.nn.Module):
    # This transform will resize the input with a given proba, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much.
    # The reason we don't rely on RandomResizedCrop is because of a significant
    # difference in the parametrization of both transforms.
    #
    # There *is* a mapping between the inputs of MaybeResizeAndCrop and those of
    # RandomResizedCrop, but the issue is that the parameters are sampled at
    # random, with different distributions. Plotting (the equivalent of) `scale`
    # and `ratio` from MaybeResizeAndCrop shows that the distributions of these
    # parameters are very different from what can be obtained from the
    # parametrization of RandomResizedCrop. I tried training RAFT by using
    # RandomResizedCrop and tweaking the parameters a bit, but couldn't get
    # an epe as good as with MaybeResizeAndCrop.
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.stretch_prob = stretch_prob
        self.resize_prob = 0.8
        self.max_stretch = 0.2

    def forward(self, img1, img2, flow, valid_flow_mask):
        # randomly sample scale
        h, w = img1.shape[-2:]
        # Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
        # It shouldn't matter much
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

        scale = 2 ** torch.FloatTensor(1).uniform_(self.min_scale, self.max_scale).item()

        scale_x = scale
        scale_y = scale
        if torch.rand(1) < self.stretch_prob:
            scale_x *= 2 ** torch.FloatTensor(1).uniform_(-self.max_stretch, self.max_stretch).item()
            scale_y *= 2 ** torch.FloatTensor(1).uniform_(-self.max_stretch, self.max_stretch).item()

        scale_x = max(scale_x, min_scale)
        scale_y = max(scale_y, min_scale)

        new_h, new_w = round(h * scale_y), round(w * scale_x)

        if torch.rand(1).item() < self.resize_prob:
            # rescale the images
            img1 = F.resize(img1, size=(new_h, new_w))
            img2 = F.resize(img2, size=(new_h, new_w))
            if valid_flow_mask is None:
                flow = F.resize(flow, size=(new_h, new_w))
                flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
            else:
                flow, valid_flow_mask = self._resize_sparse_flow(
                    flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
                )

        # Note: For sparse datasets (Kitti), the original code uses a "margin"
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't, not sure it matters much
        y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()

        img1 = img1[:, y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
        img2 = img2[:, y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
        flow = flow[:, y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]

        if valid_flow_mask is not None:
            valid_flow_mask = valid_flow_mask[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]

        return img1, img2, flow, valid_flow_mask

    def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse)
        # There are as many non-zero values in the resized flow as in the original flow (up to out-of-bounds drops)
        # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4

        h, w = flow.shape[-2:]

        h_new = int(round(h * scale_y))
        w_new = int(round(w * scale_x))
        flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
        valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

        jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

        ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

        ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
        jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

        within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

        ii_valid = ii_valid[within_bounds_mask]
        jj_valid = jj_valid[within_bounds_mask]
        ii_valid_new = ii_valid_new[within_bounds_mask]
        jj_valid_new = jj_valid_new[within_bounds_mask]

        valid_flow_new = flow[:, ii_valid, jj_valid]
        valid_flow_new[0] *= scale_x
        valid_flow_new[1] *= scale_y

        flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
        valid_new[ii_valid_new, jj_valid_new] = 1

        return flow_new, valid_new
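As a small illustration of the property described in the comment of _resize_sparse_flow (the number of valid flow values is preserved, up to out-of-bounds drops), here is a toy example; the tensors and scales are fabricated for this sketch.

import torch

flow = torch.zeros(2, 4, 4)
valid = torch.zeros(4, 4, dtype=torch.bool)
flow[:, 1, 2] = torch.tensor([3.0, -1.0])   # two valid pixels in a 4x4 map
flow[:, 3, 0] = torch.tensor([0.5, 2.0])
valid[1, 2] = valid[3, 0] = True

t = MaybeResizeAndCrop(crop_size=(2, 2))
new_flow, new_valid = t._resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)
print(new_flow.shape, int(new_valid.sum()))  # torch.Size([2, 8, 8]) 2 -- still exactly two valid pixels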

class RandomApply(T.RandomApply):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if self.p < torch.rand(1):
            return img1, img2, flow, valid_flow_mask
        for t in self.transforms:
            img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
        return img1, img2, flow, valid_flow_mask

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img1, img2, flow, valid_flow_mask):
        for t in self.transforms:
            img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
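Finally, a minimal sketch (inputs fabricated here, not taken from the diff) of how Compose threads a pair of frames without ground-truth flow through the pipeline; every transform simply passes the None flow and mask along.

import torch
from PIL import Image

transforms = Compose(
    [PILToTensor(), ConvertImageDtype(torch.float32), Normalize(mean=0.5, std=0.5), ValidateModelInput()]
)
img1 = Image.new("RGB", (64, 48))
img2 = Image.new("RGB", (64, 48))
img1, img2, flow, valid = transforms(img1, img2, flow=None, valid_flow_mask=None)
print(img1.shape, img1.dtype, flow)  # torch.Size([3, 48, 64]) torch.float32 None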