diff --git a/xarray/core/common.py b/xarray/core/common.py index cf2b4063202..7b9a049c662 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -16,7 +16,6 @@ from xarray.core.utils import ( Frozen, either_dict_or_kwargs, - emit_user_level_warning, is_scalar, ) from xarray.namedarray.core import _raise_if_any_duplicate_dimensions @@ -1050,8 +1049,7 @@ def _resample( # TODO support non-string indexer after removing the old API. from xarray.core.dataarray import DataArray - from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper - from xarray.core.pdcompat import _convert_base_to_offset + from xarray.core.groupby import ResolvedGrouper, TimeResampler from xarray.core.resample import RESAMPLE_DIM # note: the second argument (now 'skipna') use to be 'dim' @@ -1074,44 +1072,24 @@ def _resample( dim_name: Hashable = dim dim_coord = self[dim] - if loffset is not None: - emit_user_level_warning( - "Following pandas, the `loffset` parameter to resample is deprecated. " - "Switch to updating the resampled dataset time coordinate using " - "time offset arithmetic. For example:\n" - " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" - ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', - FutureWarning, - ) - - if base is not None: - emit_user_level_warning( - "Following pandas, the `base` parameter to resample will be deprecated in " - "a future version of xarray. Switch to using `origin` or `offset` instead.", - FutureWarning, - ) - - if base is not None and offset is not None: - raise ValueError("base and offset cannot be present at the same time") - - if base is not None: - index = self._indexes[dim_name].to_pandas_index() - offset = _convert_base_to_offset(base, freq, index) + group = DataArray( + dim_coord, + coords=dim_coord.coords, + dims=dim_coord.dims, + name=RESAMPLE_DIM, + ) - grouper = TimeResampleGrouper( + grouper = TimeResampler( freq=freq, closed=closed, label=label, origin=origin, offset=offset, loffset=loffset, + base=base, ) - group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM - ) - - rgrouper = ResolvedTimeResampleGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return resample_cls( self, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index aeb6b2217c3..7a0bdbc4d4c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6704,13 +6704,13 @@ def groupby( """ from xarray.core.groupby import ( DataArrayGroupBy, - ResolvedUniqueGrouper, + ResolvedGrouper, UniqueGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DataArrayGroupBy( self, (rgrouper,), @@ -6789,7 +6789,7 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DataArrayGroupBy, - ResolvedBinGrouper, + ResolvedGrouper, _validate_groupby_squeeze, ) @@ -6803,7 +6803,7 @@ def groupby_bins( "include_lowest": include_lowest, }, ) - rgrouper = ResolvedBinGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return DataArrayGroupBy( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e1fd9e025fb..895a357a240 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10179,13 +10179,13 @@ def groupby( """ from xarray.core.groupby import ( DatasetGroupBy, - ResolvedUniqueGrouper, + ResolvedGrouper, UniqueGrouper, _validate_groupby_squeeze, ) _validate_groupby_squeeze(squeeze) - rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DatasetGroupBy( self, @@ -10265,7 +10265,7 @@ def groupby_bins( from xarray.core.groupby import ( BinGrouper, DatasetGroupBy, - ResolvedBinGrouper, + ResolvedGrouper, _validate_groupby_squeeze, ) @@ -10279,7 +10279,7 @@ def groupby_bins( "include_lowest": include_lowest, }, ) - rgrouper = ResolvedBinGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self) return DatasetGroupBy( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ed6c74bc262..d34b94e9f33 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import datetime import warnings from abc import ABC, abstractmethod @@ -28,7 +29,13 @@ safe_cast_to_index, ) from xarray.core.options import _get_keep_attrs -from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray +from xarray.core.types import ( + Dims, + QuantileMethods, + T_DataArray, + T_DataWithCoords, + T_Xarray, +) from xarray.core.utils import ( FrozenMappingWarningOnValuesAccess, either_dict_or_kwargs, @@ -54,7 +61,7 @@ GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index + DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index ] @@ -73,7 +80,9 @@ def check_reduce_dims(reduce_dims, dimensions): def _maybe_squeeze_indices( indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool ): - if squeeze in [None, True] and grouper.can_squeeze: + is_unique_grouper = isinstance(grouper.grouper, UniqueGrouper) + can_squeeze = is_unique_grouper and grouper.grouper.can_squeeze + if squeeze in [None, True] and can_squeeze: if isinstance(indices, slice): if indices.stop - indices.start == 1: if (squeeze is None and warn) or squeeze is True: @@ -273,9 +282,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_Xarray) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, - T_Xarray, + T_DataWithCoords, Hashable | None, list[Hashable], ]: @@ -337,13 +346,56 @@ def _apply_loffset( result.index = result.index + loffset +class Grouper(ABC): + """Base class for Grouper objects that allow specializing GroupBy instructions.""" + + @property + def can_squeeze(self) -> bool: + """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` + should override it.""" + return False + + @abstractmethod + def factorize(self, group) -> T_FactorizeOut: + """ + Takes the group, and creates intermediates necessary for GroupBy. + These intermediates are + 1. codes - Same shape as `group` containing a unique integer code for each group. + 2. group_indices - Indexes that let us index out the members of each group. + 3. unique_coord - Unique groups present in the dataset. + 4. full_index - Unique groups in the output. This differs from `unique_coord` in the + case of resampling and binning, where certain groups in the output are not present in + the input. + """ + pass + + +class Resampler(Grouper): + """Base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + Currently only used for TimeResampler, but could be used for SpaceResampler in the future. + """ + + pass + + @dataclass -class ResolvedGrouper(ABC, Generic[T_Xarray]): +class ResolvedGrouper(Generic[T_DataWithCoords]): + """ + Wrapper around a Grouper object. + + The Grouper object represents an abstract instruction to group an object. + The ResolvedGrouper object is a concrete version that contains all the common + logic necessary for a GroupBy problem including the intermediates necessary for + executing a GroupBy calculation. Specialization to the grouping problem at hand, + is accomplished by calling the `factorize` method on the encapsulated Grouper + object. + + This class is private API, while Groupers are public. + """ + grouper: Grouper group: T_Group - obj: T_Xarray - - _group_as_index: pd.Index | None = field(default=None, init=False) + obj: T_DataWithCoords # Defined by factorize: codes: DataArray = field(init=False) @@ -353,11 +405,18 @@ class ResolvedGrouper(ABC, Generic[T_Xarray]): # _ensure_1d: group1d: T_Group = field(init=False) - stacked_obj: T_Xarray = field(init=False) + stacked_obj: T_DataWithCoords = field(init=False) stacked_dim: Hashable | None = field(init=False) inserted_dims: list[Hashable] = field(init=False) def __post_init__(self) -> None: + # This copy allows the BinGrouper.factorize() method + # to update BinGrouper.bins when provided as int, using the output + # of pd.cut + # We do not want to modify the original object, since the same grouper + # might be used multiple times. + self.grouper = copy.deepcopy(self.grouper) + self.group: T_Group = _resolve_group(self.obj, self.group) ( @@ -367,35 +426,38 @@ def __post_init__(self) -> None: self.inserted_dims, ) = _ensure_1d(group=self.group, obj=self.obj) + self.factorize() + @property def name(self) -> Hashable: - return self.group1d.name + # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper + return self.unique_coord.name @property def size(self) -> int: return len(self) def __len__(self) -> int: - return len(self.full_index) # TODO: full_index not def, abstractmethod? + return len(self.full_index) @property def dims(self): return self.group1d.dims - @abstractmethod - def factorize(self) -> T_FactorizeOut: - raise NotImplementedError - - def _factorize(self) -> None: - # This design makes it clear to mypy that - # codes, group_indices, unique_coord, and full_index - # are set by the factorize method on the derived class. + def factorize(self) -> None: ( self.codes, self.group_indices, self.unique_coord, self.full_index, - ) = self.factorize() + ) = self.grouper.factorize(self.group1d) + + +@dataclass +class UniqueGrouper(Grouper): + """Grouper object for grouping by a categorical variable.""" + + _group_as_index: pd.Index | None = None @property def is_unique_and_monotonic(self) -> bool: @@ -407,21 +469,17 @@ def is_unique_and_monotonic(self) -> bool: @property def group_as_index(self) -> pd.Index: if self._group_as_index is None: - self._group_as_index = self.group1d.to_index() + self._group_as_index = self.group.to_index() return self._group_as_index @property def can_squeeze(self) -> bool: - is_resampler = isinstance(self.grouper, TimeResampleGrouper) is_dimension = self.group.dims == (self.group.name,) - return not is_resampler and is_dimension and self.is_unique_and_monotonic - + return is_dimension and self.is_unique_and_monotonic -@dataclass -class ResolvedUniqueGrouper(ResolvedGrouper): - grouper: UniqueGrouper + def factorize(self, group1d) -> T_FactorizeOut: + self.group = group1d - def factorize(self) -> T_FactorizeOut: if self.can_squeeze: return self._factorize_dummy() else: @@ -437,7 +495,7 @@ def _factorize_unique(self) -> T_FactorizeOut: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) - codes = self.group1d.copy(data=codes_) + codes = self.group.copy(data=codes_) group_indices = group_indices unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs @@ -458,25 +516,40 @@ def _factorize_dummy(self) -> T_FactorizeOut: else: codes = self.group.copy(data=size_range) unique_coord = self.group - full_index = IndexVariable(self.name, unique_coord.values, self.group.attrs) + full_index = IndexVariable( + self.group.name, unique_coord.values, self.group.attrs + ) return codes, group_indices, unique_coord, full_index @dataclass -class ResolvedBinGrouper(ResolvedGrouper): - grouper: BinGrouper +class BinGrouper(Grouper): + """Grouper object for binning numeric data.""" + + bins: Any # TODO: What is the typing? + cut_kwargs: Mapping = field(default_factory=dict) + binned: Any = None + name: Any = None - def factorize(self) -> T_FactorizeOut: + def __post_init__(self) -> None: + if duck_array_ops.isnull(self.bins).all(): + raise ValueError("All bin edges are NaN.") + + def factorize(self, group) -> T_FactorizeOut: from xarray.core.dataarray import DataArray - data = self.group1d.values - binned, bins = pd.cut( - data, self.grouper.bins, **self.grouper.cut_kwargs, retbins=True - ) + data = group.data + + binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + binned_codes = binned.codes if (binned_codes == -1).all(): - raise ValueError(f"None of the data falls within bins with edges {bins!r}") + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) + + new_dim_name = f"{group.name}_bins" full_index = binned.categories uniques = np.sort(pd.unique(binned_codes)) @@ -486,58 +559,91 @@ def factorize(self) -> T_FactorizeOut: ] if len(group_indices) == 0: - raise ValueError(f"None of the data falls within bins with edges {bins!r}") + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) - new_dim_name = str(self.group.name) + "_bins" - self.group1d = DataArray( - binned, getattr(self.group1d, "coords", None), name=new_dim_name + codes = DataArray( + binned_codes, getattr(group, "coords", None), name=new_dim_name ) - unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs) - codes = self.group1d.copy(data=binned_codes) - # TODO: support IntervalIndex in IndexVariable - + unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) return codes, group_indices, unique_coord, full_index @dataclass -class ResolvedTimeResampleGrouper(ResolvedGrouper): - grouper: TimeResampleGrouper +class TimeResampler(Resampler): + """Grouper object specialized to resampling the time coordinate.""" + + freq: str + closed: SideOptions | None = field(default=None) + label: SideOptions | None = field(default=None) + origin: str | DatetimeLike = field(default="start_day") + offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) + loffset: datetime.timedelta | str | None = field(default=None) + base: int | None = field(default=None) + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) + group_as_index: pd.Index = field(init=False) + + def __post_init__(self): + if self.loffset is not None: + emit_user_level_warning( + "Following pandas, the `loffset` parameter to resample is deprecated. " + "Switch to updating the resampled dataset time coordinate using " + "time offset arithmetic. For example:\n" + " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" + ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', + FutureWarning, + ) - def __post_init__(self) -> None: - super().__post_init__() + if self.base is not None: + emit_user_level_warning( + "Following pandas, the `base` parameter to resample will be deprecated in " + "a future version of xarray. Switch to using `origin` or `offset` instead.", + FutureWarning, + ) + + if self.base is not None and self.offset is not None: + raise ValueError("base and offset cannot be present at the same time") + def _init_properties(self, group: T_Group) -> None: from xarray import CFTimeIndex + from xarray.core.pdcompat import _convert_base_to_offset - group_as_index = safe_cast_to_index(self.group) - self._group_as_index = group_as_index + group_as_index = safe_cast_to_index(group) + + if self.base is not None: + # grouper constructor verifies that grouper.offset is None at this point + offset = _convert_base_to_offset(self.base, self.freq, group_as_index) + else: + offset = self.offset if not group_as_index.is_monotonic_increasing: # TODO: sort instead of raising an error raise ValueError("index must be monotonic for resampling") - grouper = self.grouper if isinstance(group_as_index, CFTimeIndex): from xarray.core.resample_cftime import CFTimeGrouper index_grouper = CFTimeGrouper( - freq=grouper.freq, - closed=grouper.closed, - label=grouper.label, - origin=grouper.origin, - offset=grouper.offset, - loffset=grouper.loffset, + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + loffset=self.loffset, ) else: index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 - freq=_new_to_legacy_freq(grouper.freq), - closed=grouper.closed, - label=grouper.label, - origin=grouper.origin, - offset=grouper.offset, + freq=_new_to_legacy_freq(self.freq), + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, ) self.index_grouper = index_grouper + self.group_as_index = group_as_index def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() @@ -562,11 +668,12 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: # So for _flox_reduce we avoid one reindex and copy by avoiding # _maybe_restore_empty_groups codes = np.repeat(np.arange(len(first_items)), counts) - if self.grouper.loffset is not None: - _apply_loffset(self.grouper.loffset, first_items) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) return first_items, codes - def factorize(self) -> T_FactorizeOut: + def factorize(self, group) -> T_FactorizeOut: + self._init_properties(group) full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) group_indices: T_GroupIndices = [ @@ -574,43 +681,12 @@ def factorize(self) -> T_FactorizeOut: ] group_indices += [slice(sbins[-1], None)] - unique_coord = IndexVariable( - self.group.name, first_items.index, self.group.attrs - ) - codes = self.group.copy(data=codes_) + unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + codes = group.copy(data=codes_) return codes, group_indices, unique_coord, full_index -class Grouper(ABC): - pass - - -@dataclass -class UniqueGrouper(Grouper): - pass - - -@dataclass -class BinGrouper(Grouper): - bins: Any # TODO: What is the typing? - cut_kwargs: Mapping = field(default_factory=dict) - - def __post_init__(self) -> None: - if duck_array_ops.isnull(self.bins).all(): - raise ValueError("All bin edges are NaN.") - - -@dataclass -class TimeResampleGrouper(Grouper): - freq: str - closed: SideOptions | None - label: SideOptions | None - origin: str | DatetimeLike | None - offset: pd.Timedelta | datetime.timedelta | str | None - loffset: datetime.timedelta | str | None - - def _validate_groupby_squeeze(squeeze: bool | None) -> None: # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -624,7 +700,7 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None: ) -def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: +def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: from xarray.core.dataarray import DataArray error_msg = ( @@ -662,12 +738,12 @@ def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group = obj[group] - if group.name not in obj._indexes and group.name in obj.dims: + group_da: DataArray = obj[group] + if group_da.name not in obj._indexes and group_da.name in obj.dims: # DummyGroups should not appear on groupby results - newgroup = _DummyGroup(obj, group.name, group.coords) + newgroup = _DummyGroup(obj, group_da.name, group_da.coords) else: - newgroup = group + newgroup = group_da if newgroup.size == 0: raise ValueError(f"{newgroup.name} must not be empty") @@ -751,9 +827,6 @@ def __init__( self._original_obj = obj - for grouper_ in self.groupers: - grouper_._factorize() - (grouper,) = self.groupers self._original_group = grouper.group @@ -974,7 +1047,7 @@ def _maybe_restore_empty_groups(self, combined): """ (grouper,) = self.groupers if ( - isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampleGrouper)) + isinstance(grouper.grouper, (BinGrouper, TimeResampler)) and grouper.name in combined.dims ): indexers = {grouper.name: grouper.full_index} @@ -1009,7 +1082,7 @@ def _flox_reduce( obj = self._original_obj (grouper,) = self.groupers - isbin = isinstance(grouper, ResolvedBinGrouper) + isbin = isinstance(grouper.grouper, BinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -1391,8 +1464,14 @@ def _restore_dim_order(self, stacked: DataArray) -> DataArray: (grouper,) = self.groupers group = grouper.group1d + groupby_coord = ( + f"{group.name}_bins" + if isinstance(grouper.grouper, BinGrouper) + else group.name + ) + def lookup_order(dimension): - if dimension == group.name: + if dimension == groupby_coord: (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension)