From 43b3c6af5da9b49dff3465116105c2cdc4a78445 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 3 Feb 2023 16:55:48 +0100 Subject: [PATCH 1/2] remove default value from LabelToOneHot --- torchvision/prototype/transforms/_type_conversion.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index c84aee62afe..7423c2d085f 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -15,15 +15,12 @@ class LabelToOneHot(Transform): _transformed_types = (datapoints.Label,) - def __init__(self, num_categories: int = -1): + def __init__(self, num_categories: int): super().__init__() self.num_categories = num_categories def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel: - num_categories = self.num_categories - if num_categories == -1 and inpt.categories is not None: - num_categories = len(inpt.categories) - output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) + output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=self.num_categories) return datapoints.OneHotLabel(output, categories=inpt.categories) def extra_repr(self) -> str: From 31c7242eae912ac3623720aa373ed2e4927e2413 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 16:32:18 +0100 Subject: [PATCH 2/2] allow None and disallow non-positive integers --- .../prototype/transforms/_type_conversion.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 7423c2d085f..18a7e51540e 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, Optional, Union import numpy as np @@ -15,12 +16,33 @@ class LabelToOneHot(Transform): _transformed_types = (datapoints.Label,) - def __init__(self, num_categories: int): + def __init__(self, num_categories: Optional[int] = None): super().__init__() + if not ((isinstance(num_categories, int) and num_categories > 0) or num_categories is None): + raise ValueError( + f"`num_categories` can either be a positive integer or `None`, but got {num_categories} instead." + ) self.num_categories = num_categories def _transform(self, inpt: datapoints.Label, params: Dict[str, Any]) -> datapoints.OneHotLabel: - output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=self.num_categories) + if self.num_categories is None and inpt.categories is None: + raise RuntimeError( + "Can't determine the number of categories, " + "since neither `num_categories` on this transform, nor the `.categories` attribute on the label is set!" + ) + elif inpt.categories is None: + num_categories = self.num_categories + elif self.num_categories is None: + num_categories = len(inpt.categories) + else: + num_categories = self.num_categories + if num_categories != len(inpt.categories): + warnings.warn( + f"`num_categories` set on this transform mismatches the `.categories` attribute on the label: " + f"{num_categories} != {len(inpt.categories)}" + ) + + output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) return datapoints.OneHotLabel(output, categories=inpt.categories) def extra_repr(self) -> str: