PyTorch 2.3 is introducing unsigned integer dtypes like `uint16`, `uint32` and `uint64` in pytorch/pytorch#116594. Quoting Ed:

> The dtypes are very useless right now (not even fill works), but it makes torch.uint16, uint32 and uint64 available as a dtype.
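As a quick probe of that gap (which ops fail depends on the exact 2.3 build, so this is written defensively):

```python
import torch

try:
    t = torch.zeros(4, dtype=torch.uint16)
    t.fill_(7)  # quoted as unsupported when uint16 first landed
    print(t)
except (RuntimeError, NotImplementedError) as e:
    print(e)
```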
I tried `uint16` on some of the transforms and the following would work:
```python
import torch
import torchvision.transforms.v2 as T

x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)
```
but stuff like flip or ColorJitter won't work. In general, it's safe to assume that uint16 doesn't really work in eager mode.
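For instance, a minimal flip repro (the exact error text varies across builds, so the failure is just printed):

```python
import torch

x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
try:
    torch.flip(x, dims=[-1])
except (RuntimeError, NotImplementedError) as e:
    print(e)  # e.g. a kernel "not implemented for 'UInt16'" error
```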
### What to do about `F.to_tensor()` and `F.pil_to_tensor()`?
Up until 2.3, passing a uint16 PIL image (mode = "I;16") to those would produce:

- `to_tensor()`: an `int16` tensor as output. This is completely wrong and a bug: the range of `int16` is smaller than that of `uint16`, so the resulting tensor is incorrect and has tons of negative values (coming from overflow).
- `pil_to_tensor()`: an error - this is OK.
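To make the overflow concrete: any uint16 value above 32767 (the int16 maximum) comes out negative once the same 16-bit buffer is read as signed:

```python
import numpy as np

x = np.array([1000, 40000, 65535], dtype=np.uint16)
print(x.view(np.int16))  # [  1000 -25536     -1]
```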
Now with 2.3 (or more precisely with the nightlies/RC):

- `to_tensor()`: still outputs an `int16` tensor, which is still incorrect.
- `pil_to_tensor()`: outputs a `uint16` tensor, which is correct - but that tensor won't work with a lot of the transforms.
### Proposed fix
- Keep `pil_to_tensor()` as-is, just write a few additional tests w.r.t. uint16 support.
- Make `to_tensor()` return a `uint16` tensor instead of `int16`. This is a bug fix. Users may get loud errors down the line when they use that uint16 tensor with the transforms (because uint16 is generally not well supported), but a loud error is much better than the silent error users have been getting so far. A sketch of the dtype fix follows this list.
- Document in https://pytorch.org/vision/main/transforms.html#supported-input-types-and-conventions that uint16, uint32 and uint64 aren't officially supported by the torchvision transforms - most users should stick to uint8 or float.
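Purely as an illustration of the dtype part of that fix (hypothetical helper name, not torchvision's actual code, and ignoring `to_tensor()`'s other modes and its float conversion):

```python
import numpy as np
import torch
from PIL import Image

def to_tensor_uint16_sketch(pic: Image.Image) -> torch.Tensor:
    # Read "I;16" buffers as unsigned 16-bit instead of the int16
    # mapping that produced the overflow shown above.
    if pic.mode == "I;16":
        arr = np.asarray(pic, dtype=np.uint16)
    else:
        arr = np.asarray(pic)
    return torch.from_numpy(arr.copy())[None]  # add a leading channel dim
```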
Dirty notebook to play with:
```python
# %%
%load_ext autoreload
%autoreload 2
import numpy as np
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image
import torch

torch.__version__
# %%
x = torch.randint(100, (512, 512), dtype=torch.int16)
# %%
x_pil = F.to_pil_image(x)
x_pil.mode  # I;16
# %%
F.pil_to_tensor(x_pil).dtype  # torch.uint16
# %%
F.to_tensor(x_pil).dtype  # torch.int16
# %%
x = np.random.randint(0, np.iinfo(np.uint16).max, (10, 10), dtype=np.uint16)
x_pil = Image.fromarray(x, mode="I;16")
x_pil.mode  # I;16
# %%
F.pil_to_tensor(x_pil).dtype  # torch.uint16
# %%
torch.testing.assert_close(torch.from_numpy(x)[None], F.pil_to_tensor(x_pil))  # passes on the 2.3 nightlies
# %%
F.to_tensor(x_pil).dtype  # torch.int16
# %%
torch.testing.assert_close(torch.from_numpy(x)[None].float(), F.to_tensor(x_pil).float())  # fails: to_tensor()'s int16 output has overflowed values
# %%
x = torch.randint(0, 256, size=(1, 3, 10, 10), dtype=torch.uint16)
transforms = T.Compose(
    [
        T.Pad(2),
        T.Resize(5),
        T.CenterCrop(3),
        # T.RandomHorizontalFlip(p=1),
        # T.ColorJitter(2, 2, 2, .1),
        T.ToDtype(torch.float32, scale=True),
    ]
)
transforms(x)
# %%
```