-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Fill color support for tensor affine transforms #2904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
696c15a
229c140
adae0f6
b2721e8
1c4e48a
62abb37
a585dbd
d616210
417f6ea
50d311d
6b0eb53
5589c14
731a5a9
4389f80
2ea1003
4c59964
9e7cb7a
16e9b97
96c70bc
9d9fd08
87560cb
bc7e9fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -552,24 +552,24 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine): | |
def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): | ||
# 4) Test rotation + translation + scale + share | ||
test_configs = [ | ||
(45, [5, 6], 1.0, [0.0, 0.0]), | ||
(33, (5, -4), 1.0, [0.0, 0.0]), | ||
(45, [-5, 4], 1.2, [0.0, 0.0]), | ||
(33, (-4, -8), 2.0, [0.0, 0.0]), | ||
(85, (10, -10), 0.7, [0.0, 0.0]), | ||
(0, [0, 0], 1.0, [35.0, ]), | ||
(-25, [0, 0], 1.2, [0.0, 15.0]), | ||
(-45, [-10, 0], 0.7, [2.0, 5.0]), | ||
(-45, [-10, -10], 1.2, [4.0, 5.0]), | ||
(-90, [0, 0], 1.0, [0.0, 0.0]), | ||
(45.5, [5, 6], 1.0, [0.0, 0.0], None), | ||
(33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), | ||
(45, [-5, 4], 1.2, [0.0, 0.0], [1, 2, 3]), | ||
(33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), | ||
(85, (10, -10), 0.7, [0.0, 0.0], None), | ||
(0, [0, 0], 1.0, [35.0, ], None), | ||
(-25, [0, 0], 1.2, [0.0, 15.0], None), | ||
(-45, [-10, 0], 0.7, [2.0, 5.0], None), | ||
(-45, [-10, -10], 1.2, [4.0, 5.0], None), | ||
(-90, [0, 0], 1.0, [0.0, 0.0], None), | ||
] | ||
for r in [NEAREST, ]: | ||
for a, t, s, sh in test_configs: | ||
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r) | ||
for a, t, s, sh, f in test_configs: | ||
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
|
||
for fn in [F.affine, scripted_affine]: | ||
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu() | ||
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu() | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
|
@@ -582,7 +582,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): | |
ratio_diff_pixels, | ||
tol, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] | ||
(r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] | ||
) | ||
) | ||
|
||
|
@@ -643,35 +643,35 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): | |
for a in range(-180, 180, 17): | ||
for e in [True, False]: | ||
for c in centers: | ||
|
||
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
for fn in [F.rotate, scripted_rotate]: | ||
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu() | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
|
||
self.assertEqual( | ||
out_tensor.shape, | ||
out_pil_tensor.shape, | ||
msg="{}: {} vs {}".format( | ||
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape | ||
) | ||
) | ||
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 | ||
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] | ||
# Tolerance : less than 3% of different pixels | ||
self.assertLess( | ||
ratio_diff_pixels, | ||
0.03, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(img_size, r, dt, a, e, c), | ||
for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, let's add single int and float values (as |
||
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
for fn in [F.rotate, scripted_rotate]: | ||
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu() | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
|
||
self.assertEqual( | ||
out_tensor.shape, | ||
out_pil_tensor.shape, | ||
msg="{}: {} vs {}".format( | ||
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape | ||
)) | ||
|
||
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 | ||
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] | ||
# Tolerance : less than 3% of different pixels | ||
self.assertLess( | ||
ratio_diff_pixels, | ||
out_tensor[0, :7, :7], | ||
out_pil_tensor[0, :7, :7] | ||
0.03, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(img_size, r, dt, a, e, c, f), | ||
ratio_diff_pixels, | ||
out_tensor[0, :7, :7], | ||
out_pil_tensor[0, :7, :7] | ||
) | ||
) | ||
) | ||
|
||
def test_rotate(self): | ||
# Tests on square image | ||
|
@@ -721,30 +721,32 @@ def test_rotate(self): | |
|
||
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): | ||
dt = tensor.dtype | ||
for r in [NEAREST, ]: | ||
for spoints, epoints in test_configs: | ||
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255]]: | ||
for r in [NEAREST, ]: | ||
for spoints, epoints in test_configs: | ||
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r, | ||
fill=f) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
|
||
for fn in [F.perspective, scripted_transform]: | ||
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu() | ||
for fn in [F.perspective, scripted_transform]: | ||
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu() | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
|
||
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 | ||
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] | ||
# Tolerance : less than 5% of different pixels | ||
self.assertLess( | ||
ratio_diff_pixels, | ||
0.05, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(r, dt, spoints, epoints), | ||
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 | ||
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] | ||
# Tolerance : less than 5% of different pixels | ||
self.assertLess( | ||
ratio_diff_pixels, | ||
out_tensor[0, :7, :7], | ||
out_pil_tensor[0, :7, :7] | ||
0.05, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(f, r, dt, spoints, epoints), | ||
ratio_diff_pixels, | ||
out_tensor[0, :7, :7], | ||
out_pil_tensor[0, :7, :7] | ||
) | ||
) | ||
) | ||
|
||
def test_perspective(self): | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -557,7 +557,7 @@ def perspective( | |
startpoints: List[List[int]], | ||
endpoints: List[List[int]], | ||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||
fill: Optional[int] = None | ||
fill: Optional[List[float]] = None | ||
) -> Tensor: | ||
"""Perform perspective transform of the given image. | ||
The image can be a PIL Image or a Tensor, in which case it is expected | ||
|
@@ -573,10 +573,9 @@ def perspective( | |
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. | ||
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. | ||
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. | ||
fill (n-tuple or int or float): Pixel fill value for area outside the rotated | ||
fill (sequence or int or float, optional): Pixel fill value for area outside the rotated | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make the docstring more explicit about how it works for PIL, tensor and torchscript: fill (sequence or int or float, optional): Pixel fill value for the area outside the rotated
image. If int or float, the value is used for all bands respectively.
This option is supported for PIL image and Tensor inputs.
In torchscript mode single int/float value is not supported, please use a tuple
or list of length 1: ``[value, ]``.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. same for other docstrings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vfdev-5 I did some sorting. It seems the current version means affine functions support: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought PIL and Tensor could support the same types. Where is actually the problem with tuple if input is Tensor ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not very sure, let me test for it a bit. I wrote the code with only list in mind. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right, torch.tensor() can also convert tuples. There is no functional mismatch between tensor and PIL. |
||
image. If int or float, the value is used for all bands respectively. | ||
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor | ||
input. Fill value for the area outside the transform in the output image is always 0. | ||
This option is only available for ``pillow>=5.0.0``. | ||
|
||
Returns: | ||
PIL Image or Tensor: transformed Image. | ||
|
@@ -871,7 +870,7 @@ def _get_inverse_affine_matrix( | |
def rotate( | ||
img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, | ||
expand: bool = False, center: Optional[List[int]] = None, | ||
fill: Optional[int] = None, resample: Optional[int] = None | ||
fill: Optional[List[float]] = None, resample: Optional[int] = None | ||
) -> Tensor: | ||
"""Rotate the image by angle. | ||
The image can be a PIL Image or a Tensor, in which case it is expected | ||
|
@@ -890,13 +889,9 @@ def rotate( | |
Note that the expand flag assumes rotation around the center and no translation. | ||
center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner. | ||
Default is the center of the image. | ||
fill (n-tuple or int or float): Pixel fill value for area outside the rotated | ||
fill (sequence or int or float, optional): Pixel fill value for area outside the rotated | ||
image. If int or float, the value is used for all bands respectively. | ||
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. | ||
This option is not supported for Tensor input. Fill value for the area outside the transform in the output | ||
image is always 0. | ||
resample (int, optional): deprecated argument and will be removed since v0.10.0. | ||
Please use `arg`:interpolation: instead. | ||
|
||
Returns: | ||
PIL Image or Tensor: Rotated image. | ||
|
@@ -945,8 +940,8 @@ def rotate( | |
|
||
def affine( | ||
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], | ||
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None, | ||
resample: Optional[int] = None, fillcolor: Optional[int] = None | ||
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, | ||
resample: Optional[int] = None, fillcolor: Optional[List[float]] = None | ||
) -> Tensor: | ||
"""Apply affine transformation on the image keeping image center invariant. | ||
The image can be a PIL Image or a Tensor, in which case it is expected | ||
|
@@ -964,10 +959,9 @@ def affine( | |
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. | ||
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. | ||
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. | ||
fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). | ||
This option is not supported for Tensor input. Fill value for the area outside the transform in the output | ||
image is always 0. | ||
fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. | ||
fill (sequence, int, float): Optional fill color for the area outside the transform in the output image | ||
(Pillow>=5.0.0). | ||
fillcolor (sequence, int, float): deprecated argument and will be removed since v0.10.0. | ||
Please use `arg`:fill: instead. | ||
resample (int, optional): deprecated argument and will be removed since v0.10.0. | ||
Please use `arg`:interpolation: instead. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,10 +465,13 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"): | |
fill = 0 | ||
if isinstance(fill, (int, float)) and num_bands > 1: | ||
fill = tuple([fill] * num_bands) | ||
if not isinstance(fill, (int, float)) and len(fill) != num_bands: | ||
msg = ("The number of elements in 'fill' does not match the number of " | ||
"bands of the image ({} != {})") | ||
raise ValueError(msg.format(len(fill), num_bands)) | ||
if not isinstance(fill, (int, float)): | ||
if len(fill) != num_bands: | ||
msg = ("The number of elements in 'fill' does not match the number of " | ||
"bands of the image ({} != {})") | ||
raise ValueError(msg.format(len(fill), num_bands)) | ||
else: | ||
fill = tuple(fill) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @voldemortX seems like I missed this modification in the previous review. Why do we need to modify the code here ? It's a PIL side and I think it can be kept as is ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't quite remember changing that actually... I'll get it back to what it was. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I remember now. I need the tuple(fill) conversion since now the input is formatted as List, also that means I can't do tests with tuple fill inputs. @vfdev-5 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. You are right ! Pillow accepts only tuples and we can give a list too. Maybe, we can make the check more straightforward : if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! I'll do that together in the next commit. |
||
|
||
return {name: fill} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a test here with a single int and a single float as fill values, e.g.
[a_int, ]
and (b_float, )
.There was a problem hiding this comment.
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right! I'll add the tests and docstring changes together in the next commit.