Skip to content

Add torchscriptable adjust_gamma transform #2459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def _create_data(self, height=3, width=3, channels=3):

def compareTensorToPIL(self, tensor, pil_image, msg=None):
    """Assert that ``tensor`` is exactly equal to ``pil_image`` viewed as a tensor.

    ``pil_image`` is converted through ``np.array`` and its axes are permuted
    with ``(2, 0, 1)`` before the comparison.

    Args:
        tensor: Tensor to check.
        pil_image: Image (or array-like accepted by ``np.array``) holding the
            expected values.
        msg: Optional failure message; a default showing both tensors is built
            when it is ``None``.
    """
    expected = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
    failure_msg = msg
    if failure_msg is None:
        failure_msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, expected)
    self.assertTrue(tensor.equal(expected), failure_msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
Expand Down Expand Up @@ -293,6 +295,33 @@ def test_pad(self):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

def test_adjust_gamma(self):
    """Check F_t.adjust_gamma against its scripted form and F_pil.adjust_gamma."""
    scripted_adjust = torch.jit.script(F_t.adjust_gamma)
    tensor, pil_img = self._create_data(26, 36)

    # (gamma, gain) pairs exercised on every dtype pass.
    gamma_gain_pairs = [(0.8, 0.7), (1.0, 1.0), (1.2, 1.3)]

    for dt in (torch.float64, torch.float32, None):
        # NOTE(review): `tensor` is overwritten here, so each pass converts the
        # previous pass's tensor and the final `dt=None` pass reuses the float32
        # tensor instead of the one returned by _create_data. Confirm whether
        # the unconverted tensor was meant to be covered.
        if dt is not None:
            tensor = F.convert_image_dtype(tensor, dt)

        for gamma, gain in gamma_gain_pairs:
            eager_out = F_t.adjust_gamma(tensor, gamma, gain)
            pil_out = F_pil.adjust_gamma(pil_img, gamma, gain)
            scripted_out = scripted_adjust(tensor, gamma, gain)

            self.assertEqual(eager_out.dtype, scripted_out.dtype)
            # PIL reports (width, height); the tensor's trailing dims are reversed.
            self.assertEqual(eager_out.size()[1:], pil_out.size[::-1])

            # The PIL comparison is done on uint8 data.
            if eager_out.dtype != torch.uint8:
                rgb_tensor = F.convert_image_dtype(eager_out, torch.uint8)
            else:
                rgb_tensor = eager_out

            self.compareTensorToPIL(rgb_tensor, pil_out)

            self.assertTrue(eager_out.equal(scripted_out))

def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36)
Expand Down
24 changes: 7 additions & 17 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was giving results that didn't match PIL

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, we had quite a lot of discussion about this behavior in #2078 (comment). I believe if we make the multiplication go to dtype.max, we will end up with a non-uniform distribution over the last values.

@pmeier thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa is right about the intention. I think this boils down to

  • do we want it "right" or
  • do we want it compatible with other packages.

I'm in favor of the former (hence my implementation), but I can see why the latter is also feasible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my issue was internal consistency, not a difference between us and PIL. That comment thread is helpful -- we certainly expect more than one floating point value to map to 255.

I was able to fix it by just making the gamma adjustment consistent with convert_image_dtype.

This raises another issue, though: I'm not sure whether it's still OK to express the equation for adjust_gamma as 255 * gain * (img/255) ** gamma in the docs when in reality the leading factor is 255.999 rather than 255. I want to be accurate but also don't want to be unnecessarily confusing, and the doc does say "based on."

max = torch.iinfo(dtype).max
return image.mul(torch.iinfo(dtype).max).clamp(0, max).to(dtype)
else:
# int to float
if dtype.is_floating_point:
Expand Down Expand Up @@ -760,7 +760,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(img, gamma: float, gain: float = 1):
r"""Perform gamma correction on an image.

Also known as Power Law Transform. Intensities in RGB mode are adjusted
Expand All @@ -774,26 +774,16 @@ def adjust_gamma(img, gamma, gain=1):
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
if F_pil._is_pil_image(img):
return F_pil.adjust_gamma(img, gamma, gain)

input_mode = img.mode
img = img.convert('RGB')

gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part

img = img.convert(input_mode)
return img
return F_t.adjust_gamma(img, gamma, gain)


def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
Expand Down
37 changes: 37 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,43 @@ def adjust_hue(img, hue_factor):
return img


@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
    r"""Apply gamma correction (Power Law Transform) to a PIL Image.

    The image is converted to RGB, each intensity is remapped based on

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    and the result is converted back to the input's original mode.

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image: The adjusted image, in the same mode as ``img``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``gamma`` is negative.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    original_mode = img.mode

    # One table entry per 8-bit level, repeated once for each of the three RGB bands.
    band_lut = [255 * gain * pow(level / 255., gamma) for level in range(256)]
    rgb_lut = band_lut * 3

    # PIL's point() applies the table, which accelerates the remapping.
    corrected = img.convert('RGB').point(rgb_lut)
    return corrected.convert(original_mode)


@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
r"""Pad the given PIL.Image on all sides with the given "pad" value.
Expand Down
38 changes: 38 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,44 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Adjust gamma of an RGB image tensor.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (Tensor): Image tensor to be adjusted. Floating point tensors are
            handled on a [0, 1] scale; all other dtypes are handled on a
            [0, 255] scale.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        Tensor: Gamma-adjusted image with the same dtype as ``img``. Values are
        clamped to [0, 1] for floating point input and to [0, 255] otherwise.

    Raises:
        TypeError: If ``img`` is not a Tensor.
        ValueError: If ``gamma`` is negative.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be a Tensor. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    if torch.is_floating_point(img):
        # Clamp so that gain > 1 saturates instead of leaving the [0, 1] range,
        # mirroring the clamping applied on the integer path below.
        return (gain * img ** gamma).clamp(0, 1)

    # TODO: support other integer dtypes as well (review suggestion): convert to
    # float with convert_image_dtype, clamp gain * img ** gamma to [0, 1], then
    # convert back to the original dtype. The constants below assume 8-bit data.
    result = 255.0 * gain * (img / 255.0) ** gamma
    # PIL clamps, to(torch.uint8) would wrap
    return result.clamp(0, 255).to(img.dtype)


def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""Crop the Image Tensor and resize it to desired size.

Expand Down