From 97a5e4619c064b51a8cc503621338cb3707cdb5b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 11 Nov 2023 20:09:31 -0700 Subject: [PATCH 1/8] Refactor resampling. 1. Rename to Resampler from ResampleGrouper 2. Move code from common.resample to TimeResampler --- xarray/core/common.py | 42 ++++------------ xarray/core/groupby.py | 108 +++++++++++++++++++++++++++++------------ 2 files changed, 86 insertions(+), 64 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index cf2b4063202..0d71f1a7d55 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -16,7 +16,6 @@ from xarray.core.utils import ( Frozen, either_dict_or_kwargs, - emit_user_level_warning, is_scalar, ) from xarray.namedarray.core import _raise_if_any_duplicate_dimensions @@ -1050,8 +1049,7 @@ def _resample( # TODO support non-string indexer after removing the old API. from xarray.core.dataarray import DataArray - from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper - from xarray.core.pdcompat import _convert_base_to_offset + from xarray.core.groupby import ResolvedTimeResampler, TimeResampler from xarray.core.resample import RESAMPLE_DIM # note: the second argument (now 'skipna') use to be 'dim' @@ -1074,44 +1072,24 @@ def _resample( dim_name: Hashable = dim dim_coord = self[dim] - if loffset is not None: - emit_user_level_warning( - "Following pandas, the `loffset` parameter to resample is deprecated. " - "Switch to updating the resampled dataset time coordinate using " - "time offset arithmetic. For example:\n" - " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" - ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', - FutureWarning, - ) - - if base is not None: - emit_user_level_warning( - "Following pandas, the `base` parameter to resample will be deprecated in " - "a future version of xarray. Switch to using `origin` or `offset` instead.", - FutureWarning, - ) - - if base is not None and offset is not None: - raise ValueError("base and offset cannot be present at the same time") - - if base is not None: - index = self._indexes[dim_name].to_pandas_index() - offset = _convert_base_to_offset(base, freq, index) + group = DataArray( + dim_coord, + coords=dim_coord.coords, + dims=dim_coord.dims, + name=RESAMPLE_DIM, + ) - grouper = TimeResampleGrouper( + grouper = TimeResampler( freq=freq, closed=closed, label=label, origin=origin, offset=offset, loffset=loffset, + base=base, ) - group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM - ) - - rgrouper = ResolvedTimeResampleGrouper(grouper, group, self) + rgrouper = ResolvedTimeResampler(grouper, group, self) return resample_cls( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ed6c74bc262..172b70e2497 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -500,44 +500,70 @@ def factorize(self) -> T_FactorizeOut: @dataclass -class ResolvedTimeResampleGrouper(ResolvedGrouper): - grouper: TimeResampleGrouper +class ResolvedTimeResampler(ResolvedGrouper): + grouper: TimeResampler index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) + group_as_index: pd.Index = field(init=False) + + def __post_init__(self): + if self.loffset is not None: + emit_user_level_warning( + "Following pandas, the `loffset` parameter to resample is deprecated. " + "Switch to updating the resampled dataset time coordinate using " + "time offset arithmetic. 
For example:\n" + " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" + ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', + FutureWarning, + ) - def __post_init__(self) -> None: - super().__post_init__() + if self.base is not None: + emit_user_level_warning( + "Following pandas, the `base` parameter to resample will be deprecated in " + "a future version of xarray. Switch to using `origin` or `offset` instead.", + FutureWarning, + ) + + if self.base is not None and self.offset is not None: + raise ValueError("base and offset cannot be present at the same time") + def _init_properties(self, group): from xarray import CFTimeIndex + from xarray.core.pdcompat import _convert_base_to_offset - group_as_index = safe_cast_to_index(self.group) - self._group_as_index = group_as_index + group_as_index = safe_cast_to_index(group) + + if self.base is not None: + # grouper constructor verifies that grouper.offset is None at this point + offset = _convert_base_to_offset(self.base, self.freq, group_as_index) + else: + offset = self.offset if not group_as_index.is_monotonic_increasing: # TODO: sort instead of raising an error raise ValueError("index must be monotonic for resampling") - grouper = self.grouper if isinstance(group_as_index, CFTimeIndex): from xarray.core.resample_cftime import CFTimeGrouper index_grouper = CFTimeGrouper( - freq=grouper.freq, - closed=grouper.closed, - label=grouper.label, - origin=grouper.origin, - offset=grouper.offset, - loffset=grouper.loffset, + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + loffset=self.loffset, ) else: index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 - freq=_new_to_legacy_freq(grouper.freq), - closed=grouper.closed, - label=grouper.label, - origin=grouper.origin, - offset=grouper.offset, + freq=_new_to_legacy_freq(self.freq), + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, ) self.index_grouper = index_grouper + self.group_as_index = group_as_index def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() @@ -562,11 +588,12 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: # So for _flox_reduce we avoid one reindex and copy by avoiding # _maybe_restore_empty_groups codes = np.repeat(np.arange(len(first_items)), counts) - if self.grouper.loffset is not None: - _apply_loffset(self.grouper.loffset, first_items) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) return first_items, codes - def factorize(self) -> T_FactorizeOut: + def _factorize(self, group) -> T_FactorizeOut: + self._init_properties(group) full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) group_indices: T_GroupIndices = [ @@ -574,10 +601,8 @@ def factorize(self) -> T_FactorizeOut: ] group_indices += [slice(sbins[-1], None)] - unique_coord = IndexVariable( - self.group.name, first_items.index, self.group.attrs - ) - codes = self.group.copy(data=codes_) + unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + codes = group.copy(data=codes_) return codes, group_indices, unique_coord, full_index @@ -602,13 +627,32 @@ def __post_init__(self) -> None: @dataclass -class TimeResampleGrouper(Grouper): +class TimeResampler(Grouper): freq: str - closed: SideOptions | None - label: SideOptions | None - origin: str | DatetimeLike | None - offset: pd.Timedelta | datetime.timedelta | str | None - loffset: 
datetime.timedelta | str | None + closed: SideOptions | None = field(default=None) + label: SideOptions | None = field(default=None) + origin: str | DatetimeLike = field(default="start_day") + offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) + loffset: datetime.timedelta | str | None = field(default=None) + base: str | None = field(default=None) + + def __post_init__(self): + if self.loffset is not None: + emit_user_level_warning( + "Following pandas, the `loffset` parameter to resample will be deprecated " + "in a future version of xarray. Switch to using time offset arithmetic.", + FutureWarning, + ) + + if self.base is not None: + emit_user_level_warning( + "Following pandas, the `base` parameter to resample will be deprecated in " + "a future version of xarray. Switch to using `origin` or `offset` instead.", + FutureWarning, + ) + + if self.base is not None and self.offset is not None: + raise ValueError("base and offset cannot be present at the same time") def _validate_groupby_squeeze(squeeze: bool | None) -> None: @@ -974,7 +1018,7 @@ def _maybe_restore_empty_groups(self, combined): """ (grouper,) = self.groupers if ( - isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampleGrouper)) + isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampler)) and grouper.name in combined.dims ): indexers = {grouper.name: grouper.full_index} From 589e897cec6bc0bb26afc0b835cafa8e5dd30979 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 9 Nov 2023 20:44:38 -0700 Subject: [PATCH 2/8] Single Grouper class --- xarray/core/common.py | 4 +- xarray/core/dataarray.py | 8 +- xarray/core/dataset.py | 8 +- xarray/core/groupby.py | 183 +++++++++++++++++++-------------------- 4 files changed, 97 insertions(+), 106 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 0d71f1a7d55..7b9a049c662 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1049,7 +1049,7 @@ def _resample( # TODO support non-string indexer after removing the old API. 
from xarray.core.dataarray import DataArray - from xarray.core.groupby import ResolvedTimeResampler, TimeResampler + from xarray.core.groupby import ResolvedGrouper, TimeResampler from xarray.core.resample import RESAMPLE_DIM # note: the second argument (now 'skipna') use to be 'dim' @@ -1089,7 +1089,7 @@ def _resample( base=base, ) - rgrouper = ResolvedTimeResampler(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return resample_cls( self, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c00fe1a9e67..2907096d9dd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6704,13 +6704,13 @@ def groupby( """ from xarray.core.groupby import ( DataArrayGroupBy, - ResolvedUniqueGrouper, + ResolvedGrouper, UniqueGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DataArrayGroupBy( self, (rgrouper,), @@ -6789,7 +6789,7 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DataArrayGroupBy, - ResolvedBinGrouper, + ResolvedGrouper, _validate_groupby_squeeze, ) @@ -6803,7 +6803,7 @@ def groupby_bins( "include_lowest": include_lowest, }, ) - rgrouper = ResolvedBinGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return DataArrayGroupBy( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 884e302b8be..418835e815d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10179,13 +10179,13 @@ def groupby( """ from xarray.core.groupby import ( DatasetGroupBy, - ResolvedUniqueGrouper, + ResolvedGrouper, UniqueGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DatasetGroupBy( self, @@ -10265,7 +10265,7 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DatasetGroupBy, - ResolvedBinGrouper, + ResolvedGrouper, _validate_groupby_squeeze, ) @@ -10279,7 +10279,7 @@ def groupby_bins( "include_lowest": include_lowest, }, ) - rgrouper = ResolvedBinGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return DatasetGroupBy( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 172b70e2497..54562b6923f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import datetime import warnings from abc import ABC, abstractmethod @@ -54,7 +55,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index + DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index, DataArray ] @@ -73,7 +74,9 @@ def check_reduce_dims(reduce_dims, dimensions): def _maybe_squeeze_indices( indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool ): - if squeeze in [None, True] and grouper.can_squeeze: + is_unique_grouper = isinstance(grouper.grouper, UniqueGrouper) + can_squeeze = is_unique_grouper and grouper.grouper.can_squeeze + if squeeze in [None, True] and can_squeeze: if isinstance(indices, slice): if indices.stop - indices.start == 1: if (squeeze is None and warn) or squeeze is True: @@ -338,13 +341,11 @@ def _apply_loffset( @dataclass -class ResolvedGrouper(ABC, Generic[T_Xarray]): +class ResolvedGrouper: grouper: Grouper group: 
T_Group obj: T_Xarray - _group_as_index: pd.Index | None = field(default=None, init=False) - # Defined by factorize: codes: DataArray = field(init=False) group_indices: T_GroupIndices = field(init=False) @@ -358,6 +359,13 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): inserted_dims: list[Hashable] = field(init=False) def __post_init__(self) -> None: + # This copy allows the BinGrouper.factorize() method + # to update BinGrouper.bins when provided as int, using the output + # of pd.cut + # We do not want to modify the original object, since the same grouper + # might be used multiple times. + self.grouper = copy.deepcopy(self.grouper) + self.group: T_Group = _resolve_group(self.obj, self.group) ( @@ -367,26 +375,25 @@ def __post_init__(self) -> None: self.inserted_dims, ) = _ensure_1d(group=self.group, obj=self.obj) + self.factorize() + @property def name(self) -> Hashable: - return self.group1d.name + # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper + return self.unique_coord.name @property def size(self) -> int: return len(self) def __len__(self) -> int: - return len(self.full_index) # TODO: full_index not def, abstractmethod? + return len(self.full_index) @property def dims(self): return self.group1d.dims - @abstractmethod - def factorize(self) -> T_FactorizeOut: - raise NotImplementedError - - def _factorize(self) -> None: + def factorize(self) -> None: # This design makes it clear to mypy that # codes, group_indices, unique_coord, and full_index # are set by the factorize method on the derived class. @@ -395,7 +402,22 @@ def _factorize(self) -> None: self.group_indices, self.unique_coord, self.full_index, - ) = self.factorize() + ) = self.grouper.factorize(self.group1d) + + +class Grouper(ABC): + @abstractmethod + def factorize(self, group) -> T_FactorizeOut: + pass + + +class Resampler(Grouper): + pass + + +@dataclass +class UniqueGrouper(Grouper): + _group_as_index: pd.Index | None = None @property def is_unique_and_monotonic(self) -> bool: @@ -407,21 +429,17 @@ def is_unique_and_monotonic(self) -> bool: @property def group_as_index(self) -> pd.Index: if self._group_as_index is None: - self._group_as_index = self.group1d.to_index() + self._group_as_index = self.group.to_index() return self._group_as_index @property def can_squeeze(self) -> bool: - is_resampler = isinstance(self.grouper, TimeResampleGrouper) is_dimension = self.group.dims == (self.group.name,) - return not is_resampler and is_dimension and self.is_unique_and_monotonic + return is_dimension and self.is_unique_and_monotonic + def factorize(self, group1d) -> T_FactorizeOut: + self.group = group1d -@dataclass -class ResolvedUniqueGrouper(ResolvedGrouper): - grouper: UniqueGrouper - - def factorize(self) -> T_FactorizeOut: if self.can_squeeze: return self._factorize_dummy() else: @@ -437,7 +455,7 @@ def _factorize_unique(self) -> T_FactorizeOut: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" 
) - codes = self.group1d.copy(data=codes_) + codes = self.group.copy(data=codes_) group_indices = group_indices unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs @@ -458,25 +476,38 @@ def _factorize_dummy(self) -> T_FactorizeOut: else: codes = self.group.copy(data=size_range) unique_coord = self.group - full_index = IndexVariable(self.name, unique_coord.values, self.group.attrs) + full_index = IndexVariable( + self.group.name, unique_coord.values, self.group.attrs + ) return codes, group_indices, unique_coord, full_index @dataclass -class ResolvedBinGrouper(ResolvedGrouper): - grouper: BinGrouper +class BinGrouper(Grouper): + bins: Any # TODO: What is the typing? + cut_kwargs: Mapping = field(default_factory=dict) + binned: Any = None + name: Any = None - def factorize(self) -> T_FactorizeOut: + def __post_init__(self) -> None: + if duck_array_ops.isnull(self.bins).all(): + raise ValueError("All bin edges are NaN.") + + def factorize(self, group) -> T_FactorizeOut: from xarray.core.dataarray import DataArray - data = self.group1d.values - binned, bins = pd.cut( - data, self.grouper.bins, **self.grouper.cut_kwargs, retbins=True - ) + data = group.data + + binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + binned_codes = binned.codes if (binned_codes == -1).all(): - raise ValueError(f"None of the data falls within bins with edges {bins!r}") + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) + + new_dim_name = f"{group.name}_bins" full_index = binned.categories uniques = np.sort(pd.unique(binned_codes)) @@ -486,22 +517,27 @@ def factorize(self) -> T_FactorizeOut: ] if len(group_indices) == 0: - raise ValueError(f"None of the data falls within bins with edges {bins!r}") + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) - new_dim_name = str(self.group.name) + "_bins" - self.group1d = DataArray( - binned, getattr(self.group1d, "coords", None), name=new_dim_name + codes = DataArray( + binned_codes, getattr(group, "coords", None), name=new_dim_name ) - unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) - codes = self.group1d.copy(data=binned_codes) - # TODO: support IntervalIndex in IndexVariable - + unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) return codes, group_indices, unique_coord, full_index @dataclass -class ResolvedTimeResampler(ResolvedGrouper): - grouper: TimeResampler +class TimeResampler(Resampler): + freq: str + closed: SideOptions | None = field(default=None) + label: SideOptions | None = field(default=None) + origin: str | DatetimeLike = field(default="start_day") + offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) + loffset: datetime.timedelta | str | None = field(default=None) + base: int | None = field(default=None) + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) group_as_index: pd.Index = field(init=False) @@ -592,7 +628,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: _apply_loffset(self.loffset, first_items) return first_items, codes - def _factorize(self, group) -> T_FactorizeOut: + def factorize(self, group) -> T_FactorizeOut: self._init_properties(group) full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) @@ -607,54 +643,6 @@ def _factorize(self, group) -> T_FactorizeOut: return codes, group_indices, unique_coord, full_index -class Grouper(ABC): - pass - - -@dataclass 
-class UniqueGrouper(Grouper): - pass - - -@dataclass -class BinGrouper(Grouper): - bins: Any # TODO: What is the typing? - cut_kwargs: Mapping = field(default_factory=dict) - - def __post_init__(self) -> None: - if duck_array_ops.isnull(self.bins).all(): - raise ValueError("All bin edges are NaN.") - - -@dataclass -class TimeResampler(Grouper): - freq: str - closed: SideOptions | None = field(default=None) - label: SideOptions | None = field(default=None) - origin: str | DatetimeLike = field(default="start_day") - offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) - loffset: datetime.timedelta | str | None = field(default=None) - base: str | None = field(default=None) - - def __post_init__(self): - if self.loffset is not None: - emit_user_level_warning( - "Following pandas, the `loffset` parameter to resample will be deprecated " - "in a future version of xarray. Switch to using time offset arithmetic.", - FutureWarning, - ) - - if self.base is not None: - emit_user_level_warning( - "Following pandas, the `base` parameter to resample will be deprecated in " - "a future version of xarray. Switch to using `origin` or `offset` instead.", - FutureWarning, - ) - - if self.base is not None and self.offset is not None: - raise ValueError("base and offset cannot be present at the same time") - - def _validate_groupby_squeeze(squeeze: bool | None) -> None: # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -795,9 +783,6 @@ def __init__( self._original_obj = obj - for grouper_ in self.groupers: - grouper_._factorize() - (grouper,) = self.groupers self._original_group = grouper.group @@ -1018,7 +1003,7 @@ def _maybe_restore_empty_groups(self, combined): """ (grouper,) = self.groupers if ( - isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampler)) + isinstance(grouper.grouper, (BinGrouper, TimeResampler)) and grouper.name in combined.dims ): indexers = {grouper.name: grouper.full_index} @@ -1053,7 +1038,7 @@ def _flox_reduce( obj = self._original_obj (grouper,) = self.groupers - isbin = isinstance(grouper, ResolvedBinGrouper) + isbin = isinstance(grouper.grouper, BinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -1435,8 +1420,14 @@ def _restore_dim_order(self, stacked: DataArray) -> DataArray: (grouper,) = self.groupers group = grouper.group1d + groupby_coord = ( + f"{group.name}_bins" + if isinstance(grouper.grouper, BinGrouper) + else group.name + ) + def lookup_order(dimension): - if dimension == group.name: + if dimension == groupby_coord: (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension) From 12a44fe4692940bf88c73c7e30964085dbd5fa0c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 22 Feb 2024 20:10:49 -0700 Subject: [PATCH 3/8] Fixes --- xarray/core/groupby.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 54562b6923f..8c6ffd5a797 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -55,7 +55,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index, DataArray + DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index ] @@ -406,6 +406,10 @@ def factorize(self) -> None: class Grouper(ABC): + @property + def can_squeeze(self) -> bool: + return False + @abstractmethod def 
factorize(self, group) -> T_FactorizeOut: pass From 2e108c747fbbe36777dd6c0ee512e9f91895a002 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 22 Feb 2024 20:18:43 -0700 Subject: [PATCH 4/8] Fixes --- xarray/core/groupby.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8c6ffd5a797..d656cf58a6a 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -29,7 +29,13 @@ safe_cast_to_index, ) from xarray.core.options import _get_keep_attrs -from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray +from xarray.core.types import ( + Dims, + QuantileMethods, + T_DataArray, + T_DataWithCoords, + T_Xarray, +) from xarray.core.utils import ( FrozenMappingWarningOnValuesAccess, either_dict_or_kwargs, @@ -276,9 +282,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_Xarray) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, - T_Xarray, + T_DataWithCoords, Hashable | None, list[Hashable], ]: @@ -341,10 +347,10 @@ def _apply_loffset( @dataclass -class ResolvedGrouper: +class ResolvedGrouper(Generic[T_DataWithCoords]): grouper: Grouper group: T_Group - obj: T_Xarray + obj: T_DataWithCoords # Defined by factorize: codes: DataArray = field(init=False) @@ -354,7 +360,7 @@ class ResolvedGrouper: # _ensure_1d: group1d: T_Group = field(init=False) - stacked_obj: T_Xarray = field(init=False) + stacked_obj: T_DataWithCoords = field(init=False) stacked_dim: Hashable | None = field(init=False) inserted_dims: list[Hashable] = field(init=False) @@ -660,7 +666,7 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None: ) -def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: +def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray error_msg = ( @@ -698,12 +704,12 @@ def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group = obj[group] - if group.name not in obj._indexes and group.name in obj.dims: + group_da: DataArray = obj[group] + if group_da.name not in obj._indexes and group_da.name in obj.dims: # DummyGroups should not appear on groupby results - newgroup = _DummyGroup(obj, group.name, group.coords) + newgroup = _DummyGroup(obj, group_da.name, group_da.coords) else: - newgroup = group + newgroup = group_da if newgroup.size == 0: raise ValueError(f"{newgroup.name} must not be empty") From 1dfcd132d30bfb16cb2b9b68c2723619010f6293 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 22 Feb 2024 20:32:45 -0700 Subject: [PATCH 5/8] Add comments. --- xarray/core/groupby.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d656cf58a6a..dee102ca09b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -348,6 +348,19 @@ def _apply_loffset( @dataclass class ResolvedGrouper(Generic[T_DataWithCoords]): + """ + Wrapper around a Grouper object. + + The Grouper object represents an abstract instruction to group an object. + The ResovledGrouper object is a concrete version that contains all the common + logic necessary for a GroupBy problem including the intermediates necessary for + executing a GroupBy calculation. 
Specialization to the grouping problem at hand, + is accomplished by calling the `factorize` method on the encapsulated Grouper + object. + + This class is private API, while Groupers are public. + """ + grouper: Grouper group: T_Group obj: T_DataWithCoords @@ -400,9 +413,6 @@ def dims(self): return self.group1d.dims def factorize(self) -> None: - # This design makes it clear to mypy that - # codes, group_indices, unique_coord, and full_index - # are set by the factorize method on the derived class. ( self.codes, self.group_indices, @@ -414,10 +424,22 @@ def factorize(self) -> None: class Grouper(ABC): @property def can_squeeze(self) -> bool: + """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` + should override it.""" return False @abstractmethod def factorize(self, group) -> T_FactorizeOut: + """ + Takes the group, and creates intermediates necessary for GroupBy. + These intermediates are + 1. codes - Same shape as `group` containing a unique integer code for each group. + 2. group_indices - Indexes that let us index out the members of each group. + 3. unique_coord - Unique groups present in the dataset. + 4. full_index - Unique groups in the output. This differes from `unique_coord` in the + case of resampling and binning, where certain groups in the output are not present in + the input. + """ pass @@ -427,6 +449,8 @@ class Resampler(Grouper): @dataclass class UniqueGrouper(Grouper): + """Grouper object for grouping by a categorical variable.""" + _group_as_index: pd.Index | None = None @property @@ -495,6 +519,8 @@ def _factorize_dummy(self) -> T_FactorizeOut: @dataclass class BinGrouper(Grouper): + """Grouper object for binning numeric data.""" + bins: Any # TODO: What is the typing? cut_kwargs: Mapping = field(default_factory=dict) binned: Any = None @@ -540,6 +566,8 @@ def factorize(self, group) -> T_FactorizeOut: @dataclass class TimeResampler(Resampler): + """Grouper object specialized to resampling the time coordinate.""" + freq: str closed: SideOptions | None = field(default=None) label: SideOptions | None = field(default=None) From 19cec62a49d56d8311379a03369f383b158e9683 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 3 Mar 2024 14:19:45 -0700 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index dee102ca09b..3ebde5b7c8b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -352,7 +352,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): Wrapper around a Grouper object. The Grouper object represents an abstract instruction to group an object. - The ResovledGrouper object is a concrete version that contains all the common + The ResolvedGrouper object is a concrete version that contains all the common logic necessary for a GroupBy problem including the intermediates necessary for executing a GroupBy calculation. Specialization to the grouping problem at hand, is accomplished by calling the `factorize` method on the encapsulated Grouper @@ -436,7 +436,7 @@ def factorize(self, group) -> T_FactorizeOut: 1. codes - Same shape as `group` containing a unique integer code for each group. 2. group_indices - Indexes that let us index out the members of each group. 3. unique_coord - Unique groups present in the dataset. - 4. full_index - Unique groups in the output. This differes from `unique_coord` in the + 4. 
full_index - Unique groups in the output. This differs from `unique_coord` in the case of resampling and binning, where certain groups in the output are not present in the input. """ From 532c1d9025b5a8990493c87b5528bb76b8b918a3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 3 Mar 2024 14:19:56 -0700 Subject: [PATCH 7/8] Review feedback --- xarray/core/groupby.py | 62 +++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3ebde5b7c8b..cb5fb426496 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -346,6 +346,38 @@ def _apply_loffset( result.index = result.index + loffset +class Grouper(ABC): + """Base class for Grouper objects that allow specializing GroupBy instructions.""" + + @property + def can_squeeze(self) -> bool: + """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` + should override it.""" + return False + + @abstractmethod + def factorize(self, group) -> T_FactorizeOut: + """ + Takes the group, and creates intermediates necessary for GroupBy. + These intermediates are + 1. codes - Same shape as `group` containing a unique integer code for each group. + 2. group_indices - Indexes that let us index out the members of each group. + 3. unique_coord - Unique groups present in the dataset. + 4. full_index - Unique groups in the output. This differs from `unique_coord` in the + case of resampling and binning, where certain groups in the output are not present in + the input. + """ + pass + + +class Resampler(Grouper): + """Base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + Currently only used for TimeResampler, but could be used for SpaceResampler in the future. + """ + + pass + + @dataclass class ResolvedGrouper(Generic[T_DataWithCoords]): """ @@ -369,7 +401,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): codes: DataArray = field(init=False) group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) - full_index: pd.Index = field(init=False) + full_index: pd.Index = field(init=pd.Index()) # _ensure_1d: group1d: T_Group = field(init=False) @@ -421,32 +453,6 @@ def factorize(self) -> None: ) = self.grouper.factorize(self.group1d) -class Grouper(ABC): - @property - def can_squeeze(self) -> bool: - """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` - should override it.""" - return False - - @abstractmethod - def factorize(self, group) -> T_FactorizeOut: - """ - Takes the group, and creates intermediates necessary for GroupBy. - These intermediates are - 1. codes - Same shape as `group` containing a unique integer code for each group. - 2. group_indices - Indexes that let us index out the members of each group. - 3. unique_coord - Unique groups present in the dataset. - 4. full_index - Unique groups in the output. This differs from `unique_coord` in the - case of resampling and binning, where certain groups in the output are not present in - the input. 
- """ - pass - - -class Resampler(Grouper): - pass - - @dataclass class UniqueGrouper(Grouper): """Grouper object for grouping by a categorical variable.""" @@ -600,7 +606,7 @@ def __post_init__(self): if self.base is not None and self.offset is not None: raise ValueError("base and offset cannot be present at the same time") - def _init_properties(self, group): + def _init_properties(self, group: T_Group) -> None: from xarray import CFTimeIndex from xarray.core.pdcompat import _convert_base_to_offset From 51ff4053e237a8920efe53ca594a4db4678f2958 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 3 Mar 2024 14:25:44 -0700 Subject: [PATCH 8/8] fix --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cb5fb426496..d34b94e9f33 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -401,7 +401,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): codes: DataArray = field(init=False) group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) - full_index: pd.Index = field(init=pd.Index()) + full_index: pd.Index = field(init=False) # _ensure_1d: group1d: T_Group = field(init=False)