diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4bb18cf6b48..4d95d314669 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -91,7 +91,8 @@ the tensor dtype. Tensor images with a float dtype are expected to have values in ``[0, 1]``. Tensor images with an integer dtype are expected to have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value that can be represented in that dtype. Typically, images of dtype -``torch.uint8`` are expected to have values in ``[0, 255]``. +``torch.uint8`` are expected to have values in ``[0, 255]``. Note that dtypes +like ``torch.uint16`` or ``torch.uint32`` aren't fully supported. Use :class:`~torchvision.transforms.v2.ToDtype` to convert both the dtype and range of the inputs. diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f9218c3e840..32acc5d6930 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5478,6 +5478,18 @@ def test_functional_error(self): F.pil_to_tensor(object()) +@pytest.mark.parametrize("f", [F.to_tensor, F.pil_to_tensor]) +def test_I16_to_tensor(f): + # See https://github.com/pytorch/vision/issues/8359 + I16_pil_img = PIL.Image.fromarray(np.random.randint(0, 2**16, (10, 10), dtype=np.uint16)) + assert I16_pil_img.mode == "I;16" + + cm = pytest.warns(UserWarning, match="deprecated") if f is F.to_tensor else contextlib.nullcontext() + with cm: + out = f(I16_pil_img) + assert out.dtype == torch.uint16 + + class TestLambda: @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0]) @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)]) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8efe2a8878a..09edfeea35e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -164,7 +164,7 @@ def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: return torch.from_numpy(nppic).to(dtype=default_float_dtype) # handle PIL Image - mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32} + mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.uint16, "F": np.float32} img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)) if pic.mode == "1":