@@ -324,85 +324,6 @@ def test_pad(self):
324
324
325
325
self ._test_fn_on_batch (batch_tensors , F .pad , padding = script_pad , ** kwargs )
326
326
327
- def _test_adjust_fn (self , fn , fn_pil , fn_t , configs , tol = 2.0 + 1e-10 , agg_method = "max" ,
328
- dts = (None , torch .float32 , torch .float64 )):
329
- script_fn = torch .jit .script (fn )
330
- torch .manual_seed (15 )
331
- tensor , pil_img = self ._create_data (26 , 34 , device = self .device )
332
- batch_tensors = self ._create_data_batch (16 , 18 , num_samples = 4 , device = self .device )
333
-
334
- for dt in dts :
335
-
336
- if dt is not None :
337
- tensor = F .convert_image_dtype (tensor , dt )
338
- batch_tensors = F .convert_image_dtype (batch_tensors , dt )
339
-
340
- for config in configs :
341
- adjusted_tensor = fn_t (tensor , ** config )
342
- adjusted_pil = fn_pil (pil_img , ** config )
343
- scripted_result = script_fn (tensor , ** config )
344
- msg = "{}, {}" .format (dt , config )
345
- self .assertEqual (adjusted_tensor .dtype , scripted_result .dtype , msg = msg )
346
- self .assertEqual (adjusted_tensor .size ()[1 :], adjusted_pil .size [::- 1 ], msg = msg )
347
-
348
- rbg_tensor = adjusted_tensor
349
-
350
- if adjusted_tensor .dtype != torch .uint8 :
351
- rbg_tensor = F .convert_image_dtype (adjusted_tensor , torch .uint8 )
352
-
353
- # Check that max difference does not exceed 2 in [0, 255] range
354
- # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
355
- self .approxEqualTensorToPIL (rbg_tensor .float (), adjusted_pil , tol = tol , msg = msg , agg_method = agg_method )
356
-
357
- atol = 1e-6
358
- if adjusted_tensor .dtype == torch .uint8 and "cuda" in torch .device (self .device ).type :
359
- atol = 1.0
360
- self .assertTrue (adjusted_tensor .allclose (scripted_result , atol = atol ), msg = msg )
361
-
362
- self ._test_fn_on_batch (batch_tensors , fn , scripted_fn_atol = atol , ** config )
363
-
364
- def test_adjust_brightness (self ):
365
- self ._test_adjust_fn (
366
- F .adjust_brightness ,
367
- F_pil .adjust_brightness ,
368
- F_t .adjust_brightness ,
369
- [{"brightness_factor" : f } for f in [0.1 , 0.5 , 1.0 , 1.34 , 2.5 ]]
370
- )
371
-
372
- def test_adjust_contrast (self ):
373
- self ._test_adjust_fn (
374
- F .adjust_contrast ,
375
- F_pil .adjust_contrast ,
376
- F_t .adjust_contrast ,
377
- [{"contrast_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]]
378
- )
379
-
380
- def test_adjust_saturation (self ):
381
- self ._test_adjust_fn (
382
- F .adjust_saturation ,
383
- F_pil .adjust_saturation ,
384
- F_t .adjust_saturation ,
385
- [{"saturation_factor" : f } for f in [0.5 , 0.75 , 1.0 , 1.5 , 2.0 ]]
386
- )
387
-
388
- def test_adjust_hue (self ):
389
- self ._test_adjust_fn (
390
- F .adjust_hue ,
391
- F_pil .adjust_hue ,
392
- F_t .adjust_hue ,
393
- [{"hue_factor" : f } for f in [- 0.45 , - 0.25 , 0.0 , 0.25 , 0.45 ]],
394
- tol = 16.1 ,
395
- agg_method = "max"
396
- )
397
-
398
- def test_adjust_gamma (self ):
399
- self ._test_adjust_fn (
400
- F .adjust_gamma ,
401
- F_pil .adjust_gamma ,
402
- F_t .adjust_gamma ,
403
- [{"gamma" : g1 , "gain" : g2 } for g1 , g2 in zip ([0.8 , 1.0 , 1.2 ], [0.7 , 1.0 , 1.3 ])]
404
- )
405
-
406
327
def test_resize (self ):
407
328
script_fn = torch .jit .script (F .resize )
408
329
tensor , pil_img = self ._create_data (26 , 36 , device = self .device )
@@ -833,77 +754,6 @@ def test_gaussian_blur(self):
833
754
msg = "{}, {}" .format (ksize , sigma )
834
755
)
835
756
836
- def test_invert (self ):
837
- self ._test_adjust_fn (
838
- F .invert ,
839
- F_pil .invert ,
840
- F_t .invert ,
841
- [{}],
842
- tol = 1.0 ,
843
- agg_method = "max"
844
- )
845
-
846
- def test_posterize (self ):
847
- self ._test_adjust_fn (
848
- F .posterize ,
849
- F_pil .posterize ,
850
- F_t .posterize ,
851
- [{"bits" : bits } for bits in range (0 , 8 )],
852
- tol = 1.0 ,
853
- agg_method = "max" ,
854
- dts = (None ,)
855
- )
856
-
857
- def test_solarize (self ):
858
- self ._test_adjust_fn (
859
- F .solarize ,
860
- F_pil .solarize ,
861
- F_t .solarize ,
862
- [{"threshold" : threshold } for threshold in [0 , 64 , 128 , 192 , 255 ]],
863
- tol = 1.0 ,
864
- agg_method = "max" ,
865
- dts = (None ,)
866
- )
867
- self ._test_adjust_fn (
868
- F .solarize ,
869
- lambda img , threshold : F_pil .solarize (img , 255 * threshold ),
870
- F_t .solarize ,
871
- [{"threshold" : threshold } for threshold in [0.0 , 0.25 , 0.5 , 0.75 , 1.0 ]],
872
- tol = 1.0 ,
873
- agg_method = "max" ,
874
- dts = (torch .float32 , torch .float64 )
875
- )
876
-
877
- def test_adjust_sharpness (self ):
878
- self ._test_adjust_fn (
879
- F .adjust_sharpness ,
880
- F_pil .adjust_sharpness ,
881
- F_t .adjust_sharpness ,
882
- [{"sharpness_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]]
883
- )
884
-
885
- def test_autocontrast (self ):
886
- self ._test_adjust_fn (
887
- F .autocontrast ,
888
- F_pil .autocontrast ,
889
- F_t .autocontrast ,
890
- [{}],
891
- tol = 1.0 ,
892
- agg_method = "max"
893
- )
894
-
895
- def test_equalize (self ):
896
- torch .set_deterministic (False )
897
- self ._test_adjust_fn (
898
- F .equalize ,
899
- F_pil .equalize ,
900
- F_t .equalize ,
901
- [{}],
902
- tol = 1.0 ,
903
- agg_method = "max" ,
904
- dts = (None ,)
905
- )
906
-
907
757
908
758
@unittest .skipIf (not torch .cuda .is_available (), reason = "Skip if no CUDA device" )
909
759
class CUDATester (Tester ):
@@ -1074,5 +924,219 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
1074
924
tester .assertTrue (resized_tensor .equal (resize_result ), msg = f"{ size } , { interpolation } , { dt } " )
1075
925
1076
926
927
+ def check_functional_vs_PIL_vs_scripted (fn , fn_pil , fn_t , config , device , dtype , tol = 2.0 + 1e-10 , agg_method = "max" ):
928
+
929
+ tester = Tester ()
930
+
931
+ script_fn = torch .jit .script (fn )
932
+ torch .manual_seed (15 )
933
+ tensor , pil_img = tester ._create_data (26 , 34 , device = device )
934
+ batch_tensors = tester ._create_data_batch (16 , 18 , num_samples = 4 , device = device )
935
+
936
+ if dtype is not None :
937
+ tensor = F .convert_image_dtype (tensor , dtype )
938
+ batch_tensors = F .convert_image_dtype (batch_tensors , dtype )
939
+
940
+ out_fn_t = fn_t (tensor , ** config )
941
+ out_pil = fn_pil (pil_img , ** config )
942
+ out_scripted = script_fn (tensor , ** config )
943
+ assert out_fn_t .dtype == out_scripted .dtype
944
+ assert out_fn_t .size ()[1 :] == out_pil .size [::- 1 ]
945
+
946
+ rbg_tensor = out_fn_t
947
+
948
+ if out_fn_t .dtype != torch .uint8 :
949
+ rbg_tensor = F .convert_image_dtype (out_fn_t , torch .uint8 )
950
+
951
+ # Check that max difference does not exceed 2 in [0, 255] range
952
+ # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
953
+ tester .approxEqualTensorToPIL (rbg_tensor .float (), out_pil , tol = tol , agg_method = agg_method )
954
+
955
+ atol = 1e-6
956
+ if out_fn_t .dtype == torch .uint8 and "cuda" in torch .device (device ).type :
957
+ atol = 1.0
958
+ assert out_fn_t .allclose (out_scripted , atol = atol )
959
+
960
+ # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
961
+ tester ._test_fn_on_batch (batch_tensors , fn , scripted_fn_atol = atol , ** config )
962
+
963
+
964
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
965
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
966
+ @pytest .mark .parametrize ('config' , [{"brightness_factor" : f } for f in (0.1 , 0.5 , 1.0 , 1.34 , 2.5 )])
967
+ def test_adjust_brightness (device , dtype , config ):
968
+ check_functional_vs_PIL_vs_scripted (
969
+ F .adjust_brightness ,
970
+ F_pil .adjust_brightness ,
971
+ F_t .adjust_brightness ,
972
+ config ,
973
+ device ,
974
+ dtype ,
975
+ )
976
+
977
+
978
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
979
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
980
+ def test_invert (device , dtype ):
981
+ check_functional_vs_PIL_vs_scripted (
982
+ F .invert ,
983
+ F_pil .invert ,
984
+ F_t .invert ,
985
+ {},
986
+ device ,
987
+ dtype ,
988
+ tol = 1.0 ,
989
+ agg_method = "max"
990
+ )
991
+
992
+
993
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
994
+ @pytest .mark .parametrize ('config' , [{"bits" : bits } for bits in range (0 , 8 )])
995
+ def test_posterize (device , config ):
996
+ check_functional_vs_PIL_vs_scripted (
997
+ F .posterize ,
998
+ F_pil .posterize ,
999
+ F_t .posterize ,
1000
+ config ,
1001
+ device ,
1002
+ dtype = None ,
1003
+ tol = 1.0 ,
1004
+ agg_method = "max" ,
1005
+ )
1006
+
1007
+
1008
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1009
+ @pytest .mark .parametrize ('config' , [{"threshold" : threshold } for threshold in [0 , 64 , 128 , 192 , 255 ]])
1010
+ def test_solarize1 (device , config ):
1011
+ check_functional_vs_PIL_vs_scripted (
1012
+ F .solarize ,
1013
+ F_pil .solarize ,
1014
+ F_t .solarize ,
1015
+ config ,
1016
+ device ,
1017
+ dtype = None ,
1018
+ tol = 1.0 ,
1019
+ agg_method = "max" ,
1020
+ )
1021
+
1022
+
1023
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1024
+ @pytest .mark .parametrize ('dtype' , (torch .float32 , torch .float64 ))
1025
+ @pytest .mark .parametrize ('config' , [{"threshold" : threshold } for threshold in [0.0 , 0.25 , 0.5 , 0.75 , 1.0 ]])
1026
+ def test_solarize2 (device , dtype , config ):
1027
+ check_functional_vs_PIL_vs_scripted (
1028
+ F .solarize ,
1029
+ lambda img , threshold : F_pil .solarize (img , 255 * threshold ),
1030
+ F_t .solarize ,
1031
+ config ,
1032
+ device ,
1033
+ dtype ,
1034
+ tol = 1.0 ,
1035
+ agg_method = "max" ,
1036
+ )
1037
+
1038
+
1039
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1040
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1041
+ @pytest .mark .parametrize ('config' , [{"sharpness_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]])
1042
+ def test_adjust_sharpness (device , dtype , config ):
1043
+ check_functional_vs_PIL_vs_scripted (
1044
+ F .adjust_sharpness ,
1045
+ F_pil .adjust_sharpness ,
1046
+ F_t .adjust_sharpness ,
1047
+ config ,
1048
+ device ,
1049
+ dtype ,
1050
+ )
1051
+
1052
+
1053
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1054
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1055
+ def test_autocontrast (device , dtype ):
1056
+ check_functional_vs_PIL_vs_scripted (
1057
+ F .autocontrast ,
1058
+ F_pil .autocontrast ,
1059
+ F_t .autocontrast ,
1060
+ {},
1061
+ device ,
1062
+ dtype ,
1063
+ tol = 1.0 ,
1064
+ agg_method = "max"
1065
+ )
1066
+
1067
+
1068
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1069
+ def test_equalize (device ):
1070
+ torch .set_deterministic (False )
1071
+ check_functional_vs_PIL_vs_scripted (
1072
+ F .equalize ,
1073
+ F_pil .equalize ,
1074
+ F_t .equalize ,
1075
+ {},
1076
+ device ,
1077
+ dtype = None ,
1078
+ tol = 1.0 ,
1079
+ agg_method = "max" ,
1080
+ )
1081
+
1082
+
1083
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1084
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1085
+ @pytest .mark .parametrize ('config' , [{"contrast_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]])
1086
+ def test_adjust_contrast (device , dtype , config ):
1087
+ check_functional_vs_PIL_vs_scripted (
1088
+ F .adjust_contrast ,
1089
+ F_pil .adjust_contrast ,
1090
+ F_t .adjust_contrast ,
1091
+ config ,
1092
+ device ,
1093
+ dtype
1094
+ )
1095
+
1096
+
1097
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1098
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1099
+ @pytest .mark .parametrize ('config' , [{"saturation_factor" : f } for f in [0.5 , 0.75 , 1.0 , 1.5 , 2.0 ]])
1100
+ def test_adjust_saturation (device , dtype , config ):
1101
+ check_functional_vs_PIL_vs_scripted (
1102
+ F .adjust_saturation ,
1103
+ F_pil .adjust_saturation ,
1104
+ F_t .adjust_saturation ,
1105
+ config ,
1106
+ device ,
1107
+ dtype
1108
+ )
1109
+
1110
+
1111
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1112
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1113
+ @pytest .mark .parametrize ('config' , [{"hue_factor" : f } for f in [- 0.45 , - 0.25 , 0.0 , 0.25 , 0.45 ]])
1114
+ def test_adjust_hue (device , dtype , config ):
1115
+ check_functional_vs_PIL_vs_scripted (
1116
+ F .adjust_hue ,
1117
+ F_pil .adjust_hue ,
1118
+ F_t .adjust_hue ,
1119
+ config ,
1120
+ device ,
1121
+ dtype ,
1122
+ tol = 16.1 ,
1123
+ agg_method = "max"
1124
+ )
1125
+
1126
+
1127
+ @pytest .mark .parametrize ('device' , cpu_and_gpu ())
1128
+ @pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
1129
+ @pytest .mark .parametrize ('config' , [{"gamma" : g1 , "gain" : g2 } for g1 , g2 in zip ([0.8 , 1.0 , 1.2 ], [0.7 , 1.0 , 1.3 ])])
1130
+ def test_adjust_gamma (device , dtype , config ):
1131
+ check_functional_vs_PIL_vs_scripted (
1132
+ F .adjust_gamma ,
1133
+ F_pil .adjust_gamma ,
1134
+ F_t .adjust_gamma ,
1135
+ config ,
1136
+ device ,
1137
+ dtype ,
1138
+ )
1139
+
1140
+
1077
1141
if __name__ == '__main__' :
1078
1142
unittest .main ()
0 commit comments