Skip to content

Commit 32bccc5

Browse files
authored
Port _test_adjust_fn to pytest (#3845)
1 parent 0fece1f commit 32bccc5

File tree

1 file changed

+214
-150
lines changed

1 file changed

+214
-150
lines changed

test/test_functional_tensor.py

Lines changed: 214 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -324,85 +324,6 @@ def test_pad(self):
324324

325325
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
326326

327-
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
                    dts=(None, torch.float32, torch.float64)):
    """Check that the tensor kernel, PIL kernel and scripted functional agree
    over every (dtype, config) combination."""
    script_fn = torch.jit.script(fn)
    torch.manual_seed(15)
    tensor, pil_img = self._create_data(26, 34, device=self.device)
    batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

    for dt in dts:
        if dt is not None:
            tensor = F.convert_image_dtype(tensor, dt)
            batch_tensors = F.convert_image_dtype(batch_tensors, dt)

        for config in configs:
            out_tensor = fn_t(tensor, **config)
            out_pil = fn_pil(pil_img, **config)
            out_scripted = script_fn(tensor, **config)
            msg = "{}, {}".format(dt, config)
            self.assertEqual(out_tensor.dtype, out_scripted.dtype, msg=msg)
            self.assertEqual(out_tensor.size()[1:], out_pil.size[::-1], msg=msg)

            rgb_tensor = out_tensor
            if out_tensor.dtype != torch.uint8:
                rgb_tensor = F.convert_image_dtype(out_tensor, torch.uint8)

            # Check that max difference does not exceed 2 in [0, 255] range.
            # Exact matching is not possible due to incompatibility of
            # convert_image_dtype and PIL results.
            self.approxEqualTensorToPIL(rgb_tensor.float(), out_pil, tol=tol, msg=msg, agg_method=agg_method)

            # uint8 on CUDA rounds differently from the scripted path, so relax tolerance.
            atol = 1.0 if (out_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type) else 1e-6
            self.assertTrue(out_tensor.allclose(out_scripted, atol=atol), msg=msg)

            self._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
363-
364-
def test_adjust_brightness(self):
    """Brightness adjustment: tensor vs PIL vs scripted."""
    factors = [0.1, 0.5, 1.0, 1.34, 2.5]
    self._test_adjust_fn(
        fn=F.adjust_brightness,
        fn_pil=F_pil.adjust_brightness,
        fn_t=F_t.adjust_brightness,
        configs=[{"brightness_factor": f} for f in factors],
    )
371-
372-
def test_adjust_contrast(self):
    """Contrast adjustment: tensor vs PIL vs scripted."""
    factors = [0.2, 0.5, 1.0, 1.5, 2.0]
    self._test_adjust_fn(
        fn=F.adjust_contrast,
        fn_pil=F_pil.adjust_contrast,
        fn_t=F_t.adjust_contrast,
        configs=[{"contrast_factor": f} for f in factors],
    )
379-
380-
def test_adjust_saturation(self):
    """Saturation adjustment: tensor vs PIL vs scripted."""
    factors = [0.5, 0.75, 1.0, 1.5, 2.0]
    self._test_adjust_fn(
        fn=F.adjust_saturation,
        fn_pil=F_pil.adjust_saturation,
        fn_t=F_t.adjust_saturation,
        configs=[{"saturation_factor": f} for f in factors],
    )
387-
388-
def test_adjust_hue(self):
    """Hue adjustment: tensor vs PIL vs scripted (wider tolerance)."""
    factors = [-0.45, -0.25, 0.0, 0.25, 0.45]
    self._test_adjust_fn(
        fn=F.adjust_hue,
        fn_pil=F_pil.adjust_hue,
        fn_t=F_t.adjust_hue,
        configs=[{"hue_factor": f} for f in factors],
        tol=16.1,
        agg_method="max",
    )
397-
398-
def test_adjust_gamma(self):
    """Gamma adjustment: tensor vs PIL vs scripted."""
    params = zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])
    self._test_adjust_fn(
        fn=F.adjust_gamma,
        fn_pil=F_pil.adjust_gamma,
        fn_t=F_t.adjust_gamma,
        configs=[{"gamma": g1, "gain": g2} for g1, g2 in params],
    )
405-
406327
def test_resize(self):
407328
script_fn = torch.jit.script(F.resize)
408329
tensor, pil_img = self._create_data(26, 36, device=self.device)
@@ -833,77 +754,6 @@ def test_gaussian_blur(self):
833754
msg="{}, {}".format(ksize, sigma)
834755
)
835756

836-
def test_invert(self):
    """Image inversion: tensor vs PIL vs scripted."""
    self._test_adjust_fn(
        fn=F.invert,
        fn_pil=F_pil.invert,
        fn_t=F_t.invert,
        configs=[{}],
        tol=1.0,
        agg_method="max",
    )
845-
846-
def test_posterize(self):
    """Posterize (uint8 only): tensor vs PIL vs scripted."""
    self._test_adjust_fn(
        fn=F.posterize,
        fn_pil=F_pil.posterize,
        fn_t=F_t.posterize,
        configs=[{"bits": bits} for bits in range(8)],
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )
856-
857-
def test_solarize(self):
    """Solarize: uint8 thresholds, then float thresholds scaled to PIL's [0, 255]."""
    # uint8 path: thresholds are absolute pixel values.
    self._test_adjust_fn(
        fn=F.solarize,
        fn_pil=F_pil.solarize,
        fn_t=F_t.solarize,
        configs=[{"threshold": t} for t in (0, 64, 128, 192, 255)],
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )
    # float path: PIL still expects [0, 255], so scale the threshold for it.
    self._test_adjust_fn(
        fn=F.solarize,
        fn_pil=lambda img, threshold: F_pil.solarize(img, 255 * threshold),
        fn_t=F_t.solarize,
        configs=[{"threshold": t} for t in (0.0, 0.25, 0.5, 0.75, 1.0)],
        tol=1.0,
        agg_method="max",
        dts=(torch.float32, torch.float64),
    )
876-
877-
def test_adjust_sharpness(self):
    """Sharpness adjustment: tensor vs PIL vs scripted."""
    factors = [0.2, 0.5, 1.0, 1.5, 2.0]
    self._test_adjust_fn(
        fn=F.adjust_sharpness,
        fn_pil=F_pil.adjust_sharpness,
        fn_t=F_t.adjust_sharpness,
        configs=[{"sharpness_factor": f} for f in factors],
    )
884-
885-
def test_autocontrast(self):
    """Autocontrast: tensor vs PIL vs scripted."""
    self._test_adjust_fn(
        fn=F.autocontrast,
        fn_pil=F_pil.autocontrast,
        fn_t=F_t.autocontrast,
        configs=[{}],
        tol=1.0,
        agg_method="max",
    )
894-
895-
def test_equalize(self):
    """Histogram equalization (uint8 only): tensor vs PIL vs scripted."""
    # equalize uses ops without a deterministic CUDA implementation.
    torch.set_deterministic(False)
    self._test_adjust_fn(
        fn=F.equalize,
        fn_pil=F_pil.equalize,
        fn_t=F_t.equalize,
        configs=[{}],
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )
906-
907757

908758
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
909759
class CUDATester(Tester):
@@ -1074,5 +924,219 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
1074924
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
1075925

1076926

927+
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
    """Check that the tensor kernel, the PIL kernel and the scripted functional
    agree for one (config, device, dtype) combination.

    Args:
        fn: dispatching functional (torchvision.transforms.functional).
        fn_pil: PIL kernel (functional_pil).
        fn_t: tensor kernel (functional_tensor).
        config: kwargs dict forwarded to all three implementations.
        device: device string the tensors are created on.
        dtype: target dtype for convert_image_dtype, or None to keep uint8.
        tol: max allowed tensor-vs-PIL deviation in [0, 255] range.
        agg_method: aggregation used by approxEqualTensorToPIL.
    """
    tester = Tester()

    script_fn = torch.jit.script(fn)
    torch.manual_seed(15)
    tensor, pil_img = tester._create_data(26, 34, device=device)
    batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)

    if dtype is not None:
        tensor = F.convert_image_dtype(tensor, dtype)
        batch_tensors = F.convert_image_dtype(batch_tensors, dtype)

    out_fn_t = fn_t(tensor, **config)
    out_pil = fn_pil(pil_img, **config)
    out_scripted = script_fn(tensor, **config)
    assert out_fn_t.dtype == out_scripted.dtype
    assert out_fn_t.size()[1:] == out_pil.size[::-1]

    # Fixed local-name typo: was "rbg_tensor".
    rgb_tensor = out_fn_t

    if out_fn_t.dtype != torch.uint8:
        rgb_tensor = F.convert_image_dtype(out_fn_t, torch.uint8)

    # Check that max difference does not exceed 2 in [0, 255] range.
    # Exact matching is not possible due to incompatibility of
    # convert_image_dtype and PIL results.
    tester.approxEqualTensorToPIL(rgb_tensor.float(), out_pil, tol=tol, agg_method=agg_method)

    atol = 1e-6
    # uint8 on CUDA rounds differently from the scripted path, so relax tolerance.
    if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
        atol = 1.0
    assert out_fn_t.allclose(out_scripted, atol=atol)

    # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
    tester._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
962+
963+
964+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
def test_adjust_brightness(device, dtype, config):
    """adjust_brightness: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_brightness,
        fn_pil=F_pil.adjust_brightness,
        fn_t=F_t.adjust_brightness,
        config=config,
        device=device,
        dtype=dtype,
    )
976+
977+
978+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_invert(device, dtype):
    """invert: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.invert,
        fn_pil=F_pil.invert,
        fn_t=F_t.invert,
        config={},
        device=device,
        dtype=dtype,
        tol=1.0,
        agg_method="max",
    )
991+
992+
993+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)])
def test_posterize(device, config):
    """posterize (uint8 only): tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.posterize,
        fn_pil=F_pil.posterize,
        fn_t=F_t.posterize,
        config=config,
        device=device,
        dtype=None,
        tol=1.0,
        agg_method="max",
    )
1006+
1007+
1008+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
def test_solarize1(device, config):
    """solarize with uint8 inputs: thresholds are absolute pixel values."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.solarize,
        fn_pil=F_pil.solarize,
        fn_t=F_t.solarize,
        config=config,
        device=device,
        dtype=None,
        tol=1.0,
        agg_method="max",
    )
1021+
1022+
1023+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
def test_solarize2(device, dtype, config):
    """solarize with float inputs: PIL expects [0, 255], so scale its threshold."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.solarize,
        fn_pil=lambda img, threshold: F_pil.solarize(img, 255 * threshold),
        fn_t=F_t.solarize,
        config=config,
        device=device,
        dtype=dtype,
        tol=1.0,
        agg_method="max",
    )
1037+
1038+
1039+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_sharpness(device, dtype, config):
    """adjust_sharpness: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_sharpness,
        fn_pil=F_pil.adjust_sharpness,
        fn_t=F_t.adjust_sharpness,
        config=config,
        device=device,
        dtype=dtype,
    )
1051+
1052+
1053+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
def test_autocontrast(device, dtype):
    """autocontrast: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.autocontrast,
        fn_pil=F_pil.autocontrast,
        fn_t=F_t.autocontrast,
        config={},
        device=device,
        dtype=dtype,
        tol=1.0,
        agg_method="max",
    )
1066+
1067+
1068+
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_equalize(device):
    """equalize (uint8 only): tensor vs PIL vs scripted agreement."""
    # equalize uses ops without a deterministic CUDA implementation.
    torch.set_deterministic(False)
    check_functional_vs_PIL_vs_scripted(
        fn=F.equalize,
        fn_pil=F_pil.equalize,
        fn_t=F_t.equalize,
        config={},
        device=device,
        dtype=None,
        tol=1.0,
        agg_method="max",
    )
1081+
1082+
1083+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
def test_adjust_contrast(device, dtype, config):
    """adjust_contrast: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_contrast,
        fn_pil=F_pil.adjust_contrast,
        fn_t=F_t.adjust_contrast,
        config=config,
        device=device,
        dtype=dtype,
    )
1095+
1096+
1097+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
def test_adjust_saturation(device, dtype, config):
    """adjust_saturation: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_saturation,
        fn_pil=F_pil.adjust_saturation,
        fn_t=F_t.adjust_saturation,
        config=config,
        device=device,
        dtype=dtype,
    )
1109+
1110+
1111+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
def test_adjust_hue(device, dtype, config):
    """adjust_hue: tensor vs PIL vs scripted agreement (wider tolerance)."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_hue,
        fn_pil=F_pil.adjust_hue,
        fn_t=F_t.adjust_hue,
        config=config,
        device=device,
        dtype=dtype,
        tol=16.1,
        agg_method="max",
    )
1125+
1126+
1127+
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
def test_adjust_gamma(device, dtype, config):
    """adjust_gamma: tensor vs PIL vs scripted agreement."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_gamma,
        fn_pil=F_pil.adjust_gamma,
        fn_t=F_t.adjust_gamma,
        config=config,
        device=device,
        dtype=dtype,
    )
1139+
1140+
10771141
# Remaining unittest-style tests still run through the unittest entry point.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)