From a17f1726630afa75f19db2e8a346138856d4e5cb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 15 Feb 2023 14:34:33 +0100 Subject: [PATCH] allow integer parameters in ColorJitter --- torchvision/prototype/transforms/_color.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 09e313e5bed..8ac0d857753 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -80,14 +80,16 @@ def _check_input( if value is None: return None - if isinstance(value, float): + if isinstance(value, (int, float)): if value < 0: raise ValueError(f"If {name} is a single number, it must be non negative.") value = [center - value, center + value] if clip_first_on_zero: value[0] = max(value[0], 0.0) - elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2): - raise TypeError(f"{name} should be a single number or a sequence with length 2.") + elif isinstance(value, collections.abc.Sequence) and len(value) == 2: + value = [float(v) for v in value] + else: + raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.") if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError(f"{name} values should be between {bound}, but got {value}.")