@@ -308,22 +308,28 @@ def __init__(
308
308
ArgsKwargs (brightness = 0.1 , contrast = 0.4 , saturation = 0.7 , hue = 0.3 ),
309
309
],
310
310
),
311
- ConsistencyConfig (
312
- prototype_transforms .ElasticTransform ,
313
- legacy_transforms .ElasticTransform ,
314
- [
315
- ArgsKwargs (),
316
- ArgsKwargs (alpha = 20.0 ),
317
- ArgsKwargs (alpha = (15.3 , 27.2 )),
318
- ArgsKwargs (sigma = 3.0 ),
319
- ArgsKwargs (sigma = (2.5 , 3.9 )),
320
- ArgsKwargs (interpolation = prototype_transforms .InterpolationMode .NEAREST ),
321
- ArgsKwargs (interpolation = prototype_transforms .InterpolationMode .BICUBIC ),
322
- ArgsKwargs (fill = 1 ),
323
- ],
324
- # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
325
- make_images_kwargs = dict (DEFAULT_MAKE_IMAGES_KWARGS , sizes = [(163 , 163 ), (72 , 333 ), (313 , 95 )]),
326
- ),
311
+ * [
312
+ ConsistencyConfig (
313
+ prototype_transforms .ElasticTransform ,
314
+ legacy_transforms .ElasticTransform ,
315
+ [
316
+ ArgsKwargs (),
317
+ ArgsKwargs (alpha = 20.0 ),
318
+ ArgsKwargs (alpha = (15.3 , 27.2 )),
319
+ ArgsKwargs (sigma = 3.0 ),
320
+ ArgsKwargs (sigma = (2.5 , 3.9 )),
321
+ ArgsKwargs (interpolation = prototype_transforms .InterpolationMode .NEAREST ),
322
+ ArgsKwargs (interpolation = prototype_transforms .InterpolationMode .BICUBIC ),
323
+ ArgsKwargs (fill = 1 ),
324
+ ],
325
+ # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
326
+ make_images_kwargs = dict (DEFAULT_MAKE_IMAGES_KWARGS , sizes = [(163 , 163 ), (72 , 333 ), (313 , 95 )], dtypes = [dt ]),
327
+ # We updated gaussian blur kernel generation with a faster and numerically more stable version
328
+ # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
329
+ closeness_kwargs = ckw ,
330
+ )
331
+ for dt , ckw in [(torch .uint8 , {"rtol" : 1e-1 , "atol" : 1 }), (torch .float32 , {"rtol" : 1e-2 , "atol" : 1e-3 })]
332
+ ],
327
333
ConsistencyConfig (
328
334
prototype_transforms .GaussianBlur ,
329
335
legacy_transforms .GaussianBlur ,
@@ -333,6 +339,7 @@ def __init__(
333
339
ArgsKwargs (kernel_size = 3 , sigma = 0.7 ),
334
340
ArgsKwargs (kernel_size = 5 , sigma = (0.3 , 1.4 )),
335
341
],
342
+ closeness_kwargs = {"rtol" : 1e-5 , "atol" : 1e-5 },
336
343
),
337
344
ConsistencyConfig (
338
345
prototype_transforms .RandomAffine ,
@@ -506,7 +513,6 @@ def check_call_consistency(
506
513
image_repr = f"[{ tuple (image .shape )} , { str (image .dtype ).rsplit ('.' )[- 1 ]} ]"
507
514
508
515
image_tensor = torch .Tensor (image )
509
-
510
516
try :
511
517
torch .manual_seed (0 )
512
518
output_legacy_tensor = legacy_transform (image_tensor )
0 commit comments