
[RFC] Add scriptable transforms #1375

Closed
@fmassa


TL;DR

I propose that we have a separate set of functional transforms that take tensors as input, return tensors, and are TorchScript-able.
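As a minimal sketch of what "tensor in, tensor out, scriptable" could look like, here are flips written purely with tensor ops. The names mirror the existing hflip/vflip transforms, but these bodies are illustrative assumptions, not torchvision's implementation.

```python
import torch
from torch import Tensor


def hflip(img: Tensor) -> Tensor:
    # Horizontal flip: reverse the last (width) dimension.
    # Works on a CHW image or an NCHW batch, on CPU or GPU.
    return torch.flip(img, [-1])


def vflip(img: Tensor) -> Tensor:
    # Vertical flip: reverse the height dimension.
    return torch.flip(img, [-2])


# Only tensor ops plus type annotations, so TorchScript can compile them.
scripted_hflip = torch.jit.script(hflip)
scripted_vflip = torch.jit.script(vflip)

x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 2, 3)
print(scripted_hflip(x))  # rows [2, 1, 0] and [5, 4, 3]
```

The scripted versions can be saved and loaded without Python, which is exactly what a PIL-based transform cannot offer.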

Background

TorchVision currently relies on PIL for most of its transforms.
While PIL is reasonably fast and widely adopted, relying on an external library makes it impossible to trace or script our transforms.

One of the biggest drawbacks of this is that preprocessing is generally a crucial part of reproducing a model's results, and different preprocessing (e.g., due to OpenCV / PIL differences) can affect the final results of the model.

When torchvision was initially developed, far fewer operations were implemented in PyTorch that could be used to perform image transformations such as resizing, rotations and affine warps.
Relying on PIL also creates an awkward situation where some operations expect PIL Images and others expect torch Tensors (normalize is a notable case).

Since then, we have improved support for image resizing in PyTorch (thanks to the upsample function, now exposed as torch.nn.functional.interpolate), which supports several interpolation modes, as well as grid_sample, which lets us implement rotations, affine warps and other geometric transforms efficiently.
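As a rough illustration of the grid_sample route, rotation can be expressed as an affine grid followed by sampling. The rotate helper and its sign convention are assumptions for this sketch; a real torchvision rotate may differ, for example in how it handles non-square images, rotation centers and fill values.

```python
import math

import torch
import torch.nn.functional as F
from torch import Tensor


def rotate(img: Tensor, angle_deg: float, mode: str = "bilinear") -> Tensor:
    """Rotate a float NCHW batch about the image center using grid_sample.

    Assumes square images: affine_grid works in normalized [-1, 1]
    coordinates, so non-square inputs would need an aspect-ratio correction.
    """
    a = angle_deg * math.pi / 180.0
    c, s = math.cos(a), math.sin(a)
    n = img.shape[0]
    # affine_grid maps each *output* location to the *input* location to sample.
    theta = torch.tensor(
        [[c, -s, 0.0], [s, c, 0.0]], dtype=img.dtype, device=img.device
    ).unsqueeze(0).repeat(n, 1, 1)
    grid = F.affine_grid(theta, list(img.shape), align_corners=False)
    # Locations that fall outside the input are filled with zeros.
    return F.grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)


img = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
rotated = rotate(img, 90.0, mode="nearest")
```

With this convention, a 90 degree rotation in "nearest" mode matches torch.rot90(img, 1, [2, 3]), which is a convenient sanity check.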

Pros of using PyTorch ops

  • GPU support
  • Batching support
  • Enables tracing and scripting the transforms
  • Autodiff support
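To make the batching, GPU and autodiff points concrete, here is a hypothetical tensor-based adjust_brightness (the name matches an item in the list below, but the body is an assumption, and the clamp to [0, 1] assumes float images in that range) applied to a whole batch with per-image factors, with gradients flowing back to the factors.

```python
import torch
from torch import Tensor


def adjust_brightness(img: Tensor, factor: Tensor) -> Tensor:
    # Scale pixel values and clamp to the valid [0, 1] float range.
    # `factor` broadcasts, so a (N, 1, 1, 1) tensor gives each image its own factor.
    return (img * factor).clamp(0.0, 1.0)


device = "cuda" if torch.cuda.is_available() else "cpu"

# A batch of 4 RGB images in [0, 0.5), so no value gets clamped at the top below.
imgs = torch.rand(4, 3, 8, 8, device=device) * 0.5
factors = torch.tensor([0.5, 1.0, 1.5, 2.0], device=device, requires_grad=True)

out = adjust_brightness(imgs, factors.view(-1, 1, 1, 1))
out.sum().backward()

# d(sum(out_i)) / d(factor_i) == sum(img_i) when nothing is clamped.
print(factors.grad)
```

The same code runs unchanged on CPU or GPU and processes the whole batch in one call.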

Cons of using PyTorch ops

  • Not bit-wise equivalent to PIL
  • Some (but not many) cases are not yet supported

It should be noted that using PyTorch ops should not be a hard constraint: users can still implement their own functionality using PIL or OpenCV. However, only the transforms based on PyTorch ops will be exportable to TorchScript.

This means that the lingua franca for passing objects around in torchvision transforms would be a torch.Tensor, not a PIL Image.

How to implement it

Most of the transforms in torchvision can already be expressed with PyTorch native operators, like torch.nn.functional.interpolate or torch.nn.functional.grid_sample, so we should not need to write specialized ops for them in torchvision.
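For instance, a resize built directly on torch.nn.functional.interpolate needs no custom op. This sketch assumes the float NCHW input discussed under Gotchas; the resize signature is hypothetical, and since no antialiasing is applied here, bilinear downsampling will generally not match PIL's output.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor


def resize(img: Tensor, size: List[int], mode: str = "bilinear") -> Tensor:
    # img: float NCHW batch; size: [out_height, out_width].
    if mode == "nearest":
        # align_corners must not be passed for "nearest".
        return F.interpolate(img, size=size, mode=mode)
    return F.interpolate(img, size=size, mode=mode, align_corners=False)


scripted_resize = torch.jit.script(resize)

batch = torch.rand(2, 3, 10, 10)
print(scripted_resize(batch, [5, 7]).shape)  # torch.Size([2, 3, 5, 7])
```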

An initial PR adding support for video was sent in #1353, and I think we should build on top of it to cover more ops and also support images.

Gotchas

Using torch operators has a drawback: operators such as interpolate and grid_sample currently only support batched tensors in NCHW format with floating-point values, which differs from the format used by our current set of transforms (HWC and uint8 in most cases).

For now, let's assume that tensors are float32 and in NCHW format. We might consider explicitly keeping a memory_format=torch.channels_last layout for compatibility (TBD).
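Converting today's HWC uint8 data into that layout is a permute plus a dtype change. The helper below is a hypothetical sketch (dividing by 255 mirrors what ToTensor does for uint8 input); the last lines show what the channels_last option would mean: the logical shape stays NCHW while memory is laid out as NHWC.

```python
import torch
from torch import Tensor


def to_float_nchw(img: Tensor) -> Tensor:
    # img: HWC uint8 in [0, 255], e.g. from torch.from_numpy(np.asarray(pil_img)).
    # Returns a 1xCxHxW float32 batch in [0, 1].
    return img.permute(2, 0, 1).unsqueeze(0).to(torch.float32) / 255.0


hwc = torch.randint(0, 256, (32, 48, 3), dtype=torch.uint8)
x = to_float_nchw(hwc)
print(x.shape, x.dtype)  # torch.Size([1, 3, 32, 48]) torch.float32

# Same logical NCHW shape, but stored pixel-interleaved (NHWC) in memory.
x_cl = x.contiguous(memory_format=torch.channels_last)
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
```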

Long-term, we should add support for uint8 (and other integer types) to interpolate and make it more generic over which dimensions it interpolates (pytorch/pytorch#10482), but that's a larger task.

List of transforms that could be readily available with PyTorch ops

  • normalize
  • resize (only nearest, bilinear and bicubic, for floating types)
  • pad (except symmetric pad)
  • crop
  • center_crop
  • resized_crop
  • hflip
  • vflip
  • five_crop
  • ten_crop
  • adjust_brightness
  • adjust_contrast
  • adjust_saturation
  • adjust_hue
  • adjust_gamma
  • rotate (only for nearest and bilinear, for floating types)
  • affine (only for nearest and bilinear, for floating types)
  • grayscale
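As a rough sketch of how several items on this list reduce to a few lines of tensor code, here are hypothetical normalize, crop, center_crop, pad and grayscale functions, assuming float NCHW input. Signatures and rounding choices are illustrative, not a committed API. Note that torch.nn.functional.pad offers "constant", "reflect", "replicate" and "circular" modes but no "symmetric" one, which is why symmetric pad is excluded above.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor


def normalize(img: Tensor, mean: List[float], std: List[float]) -> Tensor:
    # Per-channel (x - mean) / std on a float NCHW batch.
    m = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    s = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    return (img - m) / s


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    # Plain slicing over the last two (H, W) dimensions.
    return img[..., top:top + height, left:left + width]


def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    h, w = img.shape[-2], img.shape[-1]
    th, tw = output_size[0], output_size[1]
    # Integer division: an odd leftover margin goes to the bottom/right.
    return crop(img, (h - th) // 2, (w - tw) // 2, th, tw)


def pad(img: Tensor, padding: List[int], fill: float = 0.0) -> Tensor:
    # padding = [left, right, top, bottom]; constant fill only in this sketch.
    return F.pad(img, padding, mode="constant", value=fill)


def rgb_to_grayscale(img: Tensor) -> Tensor:
    # ITU-R 601-2 luma weights (the ones PIL uses for "L" conversion).
    r, g, b = img.unbind(dim=-3)
    return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(-3)


# Every function sticks to tensor ops, so each one can be scripted.
scripted = {
    f.__name__: torch.jit.script(f)
    for f in (normalize, crop, center_crop, pad, rgb_to_grayscale)
}
```

Ops such as hflip, vflip, resized_crop, five_crop and ten_crop can be composed from the same building blocks (slicing, flip and interpolate).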
