diff --git a/xarray/core/common.py b/xarray/core/common.py
index af935ae15d2..f6abcba1ff0 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -949,7 +949,7 @@ def _resample(
         # TODO support non-string indexer after removing the old API.

         from xarray.core.dataarray import DataArray
-        from xarray.core.groupby import TimeResampleGrouper
+        from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper
         from xarray.core.resample import RESAMPLE_DIM

         if keep_attrs is not None:
@@ -1012,11 +1012,13 @@ def _resample(
         group = DataArray(
             dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM
         )
+
+        rgrouper = ResolvedTimeResampleGrouper(grouper, group, self)
+
         return resample_cls(
             self,
-            group=group,
+            (rgrouper,),
             dim=dim_name,
-            grouper=grouper,
             resample_dim=RESAMPLE_DIM,
             restore_coord_dims=restore_coord_dims,
         )
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 9af7fcd89a4..356f1029192 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -515,15 +515,16 @@ def apply_groupby_func(func, *args):
     groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
     assert groupbys, "must have at least one groupby to iterate over"
     first_groupby = groupbys[0]
-    if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
+    (grouper,) = first_groupby.groupers
+    if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]):
         raise ValueError(
             "apply_ufunc can only perform operations over "
             "multiple GroupBy objects at once if they are all "
             "grouped the same way"
         )
-    grouped_dim = first_groupby._group.name
-    unique_values = first_groupby._unique_coord.values
+    grouped_dim = grouper.name
+    unique_values = grouper.unique_coord.values

     iterators = []
     for arg in args:
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 7ff3a7765c9..2f663c4936a 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -6478,21 +6478,20 @@ def groupby(
         core.groupby.DataArrayGroupBy
         pandas.DataFrame.groupby
         """
-        from xarray.core.groupby import DataArrayGroupBy
-
-        # While we don't generally check the type of every arg, passing
-        # multiple dimensions as multiple arguments is common enough, and the
-        # consequences hidden enough (strings evaluate as true) to warrant
-        # checking here.
-        # A future version could make squeeze kwarg only, but would face
-        # backward-compat issues.
-        if not isinstance(squeeze, bool):
-            raise TypeError(
-                f"`squeeze` must be True or False, but {squeeze} was supplied"
-            )
+        from xarray.core.groupby import (
+            DataArrayGroupBy,
+            ResolvedUniqueGrouper,
+            UniqueGrouper,
+            _validate_groupby_squeeze,
+        )

+        _validate_groupby_squeeze(squeeze)
+        rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self)

         return DataArrayGroupBy(
-            self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
+            self,
+            (rgrouper,),
+            squeeze=squeeze,
+            restore_coord_dims=restore_coord_dims,
         )

     def groupby_bins(
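Every entry point in this patch follows the same two-step wiring: validate `squeeze`, build a stateless `Grouper` that describes *how* to group, bind it to the grouped object as a `ResolvedGrouper`, and hand the one-element tuple to the `GroupBy` constructor. A rough standalone model of that split — hypothetical names that only mirror the patch, not xarray's internal classes:

    from dataclasses import dataclass, field
    import numpy as np
    import pandas as pd

    @dataclass
    class UniqueGrouperSpec:        # "how to group": carries no data
        pass

    @dataclass
    class ResolvedGrouperModel:     # the spec bound to one concrete array
        spec: UniqueGrouperSpec
        group: np.ndarray
        codes: np.ndarray = field(init=False)
        unique_coord: np.ndarray = field(init=False)

        def factorize(self) -> None:
            # pd.factorize plays the role of _factorize_unique here
            self.codes, self.unique_coord = pd.factorize(self.group, sort=True)

    rg = ResolvedGrouperModel(UniqueGrouperSpec(), np.array(["a", "b", "a"]))
    rg.factorize()
    print(rg.codes, rg.unique_coord)  # [0 1 0] ['a' 'b']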
@@ -6563,14 +6562,16 @@ def groupby_bins(
         References
         ----------
         .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
         """
-        from xarray.core.groupby import DataArrayGroupBy
+        from xarray.core.groupby import (
+            BinGrouper,
+            DataArrayGroupBy,
+            ResolvedBinGrouper,
+            _validate_groupby_squeeze,
+        )

-        return DataArrayGroupBy(
-            self,
-            group,
-            squeeze=squeeze,
+        _validate_groupby_squeeze(squeeze)
+        grouper = BinGrouper(
             bins=bins,
-            restore_coord_dims=restore_coord_dims,
             cut_kwargs={
                 "right": right,
                 "labels": labels,
                 "precision": precision,
                 "include_lowest": include_lowest,
             },
         )
+        rgrouper = ResolvedBinGrouper(grouper, group, self)
+
+        return DataArrayGroupBy(
+            self,
+            (rgrouper,),
+            squeeze=squeeze,
+            restore_coord_dims=restore_coord_dims,
+        )

     def weighted(self, weights: DataArray) -> DataArrayWeighted:
         """
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 7d903d9432d..2336883d0b7 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -8958,21 +8958,21 @@ def groupby(
         Dataset.resample
         DataArray.resample
         """
-        from xarray.core.groupby import DatasetGroupBy
-
-        # While we don't generally check the type of every arg, passing
-        # multiple dimensions as multiple arguments is common enough, and the
-        # consequences hidden enough (strings evaluate as true) to warrant
-        # checking here.
-        # A future version could make squeeze kwarg only, but would face
-        # backward-compat issues.
-        if not isinstance(squeeze, bool):
-            raise TypeError(
-                f"`squeeze` must be True or False, but {squeeze} was supplied"
-            )
+        from xarray.core.groupby import (
+            DatasetGroupBy,
+            ResolvedUniqueGrouper,
+            UniqueGrouper,
+            _validate_groupby_squeeze,
+        )
+
+        _validate_groupby_squeeze(squeeze)
+        rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self)

         return DatasetGroupBy(
-            self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
+            self,
+            (rgrouper,),
+            squeeze=squeeze,
+            restore_coord_dims=restore_coord_dims,
         )

     def groupby_bins(
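Both `groupby_bins` overloads now defer the actual binning to `ResolvedBinGrouper._factorize`, which is a thin wrapper around `pandas.cut`. A minimal sketch of the pandas behavior the patch leans on, assuming numeric bin edges:

    import numpy as np
    import pandas as pd

    data = np.array([0.5, 1.5, 2.5, 99.0])
    binned, bins = pd.cut(data, bins=[0, 1, 2, 3], retbins=True)

    # Integer codes per element; -1 marks values outside every bin.
    print(binned.codes)             # [ 0  1  2 -1]
    # The interval categories become the "<name>_bins" coordinate labels.
    print(list(binned.categories))  # [(0, 1], (1, 2], (2, 3]]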
@@ -9043,14 +9043,16 @@ def groupby_bins(
         References
         ----------
         .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
         """
-        from xarray.core.groupby import DatasetGroupBy
+        from xarray.core.groupby import (
+            BinGrouper,
+            DatasetGroupBy,
+            ResolvedBinGrouper,
+            _validate_groupby_squeeze,
+        )

-        return DatasetGroupBy(
-            self,
-            group,
-            squeeze=squeeze,
+        _validate_groupby_squeeze(squeeze)
+        grouper = BinGrouper(
             bins=bins,
-            restore_coord_dims=restore_coord_dims,
             cut_kwargs={
                 "right": right,
                 "labels": labels,
                 "precision": precision,
                 "include_lowest": include_lowest,
             },
         )
+        rgrouper = ResolvedBinGrouper(grouper, group, self)
+
+        return DatasetGroupBy(
+            self,
+            (rgrouper,),
+            squeeze=squeeze,
+            restore_coord_dims=restore_coord_dims,
+        )

     def weighted(self, weights: DataArray) -> DatasetWeighted:
         """
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 06788505161..55fe103d41e 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -2,8 +2,17 @@

 import datetime
 import warnings
+from abc import ABC, abstractmethod
 from collections.abc import Hashable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, cast
+from dataclasses import dataclass, field
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Literal,
+    Union,
+)

 import numpy as np
 import pandas as pd
@@ -25,7 +34,7 @@
 )
 from xarray.core.options import _get_keep_attrs
 from xarray.core.pycompat import integer_types
-from xarray.core.types import Dims, QuantileMethods, T_Xarray
+from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
 from xarray.core.utils import (
     either_dict_or_kwargs,
     hashable,
@@ -40,13 +49,16 @@
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.resample_cftime import CFTimeGrouper
     from xarray.core.types import DatetimeLike, SideOptions
     from xarray.core.utils import Frozen

     GroupKey = Any
-
-    T_GroupIndicesListInt = list[list[int]]
-    T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray]
+    GroupIndex = Union[int, slice, list[int]]
+    T_GroupIndices = list[GroupIndex]
+    T_FactorizeOut = tuple[
+        DataArray, T_GroupIndices, Union[IndexVariable, "_DummyGroup"], pd.Index
+    ]


 def check_reduce_dims(reduce_dims, dimensions):
@@ -87,8 +99,8 @@ def unique_value_groups(
     return values, groups, inverse


-def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt:
-    groups: T_GroupIndicesListInt = [[] for _ in range(N)]
+def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
+    groups: T_GroupIndices = [[] for _ in range(N)]
     for n, g in enumerate(inverse):
         if g >= 0:
             groups[g].append(n)
@@ -129,13 +141,13 @@ def _dummy_copy(xarray_obj):
     return res


-def _is_one_or_none(obj):
+def _is_one_or_none(obj) -> bool:
     return obj == 1 or obj is None


-def _consolidate_slices(slices):
+def _consolidate_slices(slices: list[slice]) -> list[slice]:
     """Consolidate adjacent slices in a list of slices."""
-    result = []
+    result: list[slice] = []
     last_slice = slice(None)
     for slice_ in slices:
         if not isinstance(slice_, slice):
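For reference, the retyped `_codes_to_groups` is easy to exercise standalone; the list of lists it returns is what the new `T_GroupIndices` alias describes, and negative codes (for example `pd.cut` misses) are simply dropped. A quick check, copying the helper as defined above:

    import numpy as np

    def _codes_to_groups(inverse: np.ndarray, N: int) -> list[list[int]]:
        groups: list[list[int]] = [[] for _ in range(N)]
        for n, g in enumerate(inverse):
            if g >= 0:
                groups[g].append(n)
        return groups

    codes = np.array([1, -1, 0, 1, 0])
    print(_codes_to_groups(codes, 2))  # [[2, 4], [0, 3]]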
""" - __slots__ = ("name", "coords", "size") + __slots__ = ("name", "coords", "size", "dataarray") def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name @@ -208,10 +220,17 @@ def values(self) -> range: def data(self) -> range: return range(self.size) + def __array__(self) -> np.ndarray: + return np.arange(self.size) + @property def shape(self) -> tuple[int]: return (self.size,) + @property + def attrs(self) -> dict: + return {} + def __getitem__(self, key): if isinstance(key, tuple): key = key[0] @@ -220,43 +239,41 @@ def __getitem__(self, key): def copy(self, deep: bool = True, data: Any = None): raise NotImplementedError + def to_dataarray(self) -> DataArray: + from xarray.core.dataarray import DataArray + + return DataArray( + data=self.data, dims=(self.name,), coords=self.coords, name=self.name + ) -T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) + +T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] def _ensure_1d( group: T_Group, obj: T_Xarray -) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: +) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]: # 1D cases: do nothing - from xarray.core.dataarray import DataArray - if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] + from xarray.core.dataarray import DataArray + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] - # the copy is necessary here, otherwise read only array raises error - # in pandas: https://github.com/pydata/pandas/issues/12813 - newgroup = group.stack({stacked_dim: orig_dims}).copy() + newgroup = group.stack({stacked_dim: orig_dims}) newobj = obj.stack({stacked_dim: orig_dims}) - return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims + return newgroup, newobj, stacked_dim, inserted_dims raise TypeError( f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." 
@@ -294,92 +311,336 @@ def _apply_loffset(
     result.index = result.index + loffset


-def _get_index_and_items(index, grouper):
-    first_items, codes = grouper.first_items(index)
-    full_index = first_items.index
-    if first_items.isnull().any():
-        first_items = first_items.dropna()
-    return full_index, first_items, codes
-
-
-def _factorize_grouper(
-    group, grouper
-) -> tuple[
-    DataArray | IndexVariable | _DummyGroup,
-    T_GroupIndices,
-    np.ndarray,
-    pd.Index,
-]:
-    index = safe_cast_to_index(group)
-    if not index.is_monotonic_increasing:
-        # TODO: sort instead of raising an error
-        raise ValueError("index must be monotonic for resampling")
-    full_index, first_items, codes = _get_index_and_items(index, grouper)
-    sbins = first_items.values.astype(np.int64)
-    group_indices: T_GroupIndices = [
-        slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])
-    ] + [slice(sbins[-1], None)]
-    unique_coord = IndexVariable(group.name, first_items.index)
-    return unique_coord, group_indices, codes, full_index
-
-
-def _factorize_bins(
-    group, bins, cut_kwargs: Mapping | None
-) -> tuple[IndexVariable, T_GroupIndices, np.ndarray, pd.IntervalIndex, DataArray]:
-    from xarray.core.dataarray import DataArray
+@dataclass
+class ResolvedGrouper(ABC, Generic[T_Xarray]):
+    grouper: Grouper
+    group: T_Group
+    obj: T_Xarray

-    if cut_kwargs is None:
-        cut_kwargs = {}
-
-    if duck_array_ops.isnull(bins).all():
-        raise ValueError("All bin edges are NaN.")
-    binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True)
-    codes = binned.codes
-    if (codes == -1).all():
-        raise ValueError(f"None of the data falls within bins with edges {bins!r}")
-    full_index = binned.categories
-    uniques = np.sort(pd.unique(codes))
-    unique_values = full_index[uniques[uniques != -1]]
-    group_indices = [g for g in _codes_to_groups(codes, len(full_index)) if g]
-
-    if len(group_indices) == 0:
-        raise ValueError(f"None of the data falls within bins with edges {bins!r}")
-
-    new_dim_name = str(group.name) + "_bins"
-    group_ = DataArray(binned, getattr(group, "coords", None), name=new_dim_name)
-    unique_coord = IndexVariable(new_dim_name, unique_values)
-    return unique_coord, group_indices, codes, full_index, group_
-
-
-def _factorize_rest(
-    group,
-) -> tuple[IndexVariable, T_GroupIndices, np.ndarray]:
-    # look through group to find the unique values
-    group_as_index = safe_cast_to_index(group)
-    sort = not isinstance(group_as_index, pd.MultiIndex)
-    unique_values, group_indices, codes = unique_value_groups(group_as_index, sort=sort)
-    if len(group_indices) == 0:
-        raise ValueError(
-            "Failed to group data. Are you grouping by a variable that is all NaN?"
+    _group_as_index: pd.Index | None = field(default=None, init=False)
+
+    # Defined by factorize:
+    codes: DataArray = field(init=False)
+    group_indices: T_GroupIndices = field(init=False)
+    unique_coord: IndexVariable | _DummyGroup = field(init=False)
+    full_index: pd.Index = field(init=False)
+
+    # _ensure_1d:
+    group1d: T_Group = field(init=False)
+    stacked_obj: T_Xarray = field(init=False)
+    stacked_dim: Hashable | None = field(init=False)
+    inserted_dims: list[Hashable] = field(init=False)
+
+    def __post_init__(self) -> None:
+        self.group: T_Group = _resolve_group(self.obj, self.group)
+
+        (
+            self.group1d,
+            self.stacked_obj,
+            self.stacked_dim,
+            self.inserted_dims,
+        ) = _ensure_1d(group=self.group, obj=self.obj)
+
+    @property
+    def name(self) -> Hashable:
+        return self.group1d.name
+
+    @property
+    def size(self) -> int:
+        return len(self)
+
+    def __len__(self) -> int:
+        # TODO: full_index is not defined on the base class; make this an abstractmethod?
+        return len(self.full_index)
+
+    @property
+    def dims(self):
+        return self.group1d.dims
+
+    @abstractmethod
+    def _factorize(self, squeeze: bool) -> T_FactorizeOut:
+        raise NotImplementedError
+
+    def factorize(self, squeeze: bool) -> None:
+        # This design makes it clear to mypy that
+        # codes, group_indices, unique_coord, and full_index
+        # are set by the factorize method on the derived class.
+        (
+            self.codes,
+            self.group_indices,
+            self.unique_coord,
+            self.full_index,
+        ) = self._factorize(squeeze)
+
+    @property
+    def is_unique_and_monotonic(self) -> bool:
+        if isinstance(self.group, _DummyGroup):
+            return True
+        index = self.group_as_index
+        return index.is_unique and index.is_monotonic_increasing
+
+    @property
+    def group_as_index(self) -> pd.Index:
+        if self._group_as_index is None:
+            self._group_as_index = safe_cast_to_index(self.group1d)
+        return self._group_as_index
+
+
+@dataclass
+class ResolvedUniqueGrouper(ResolvedGrouper):
+    grouper: UniqueGrouper
+
+    def _factorize(self, squeeze) -> T_FactorizeOut:
+        is_dimension = self.group.dims == (self.group.name,)
+        if is_dimension and self.is_unique_and_monotonic:
+            return self._factorize_dummy(squeeze)
+        else:
+            return self._factorize_unique()
+
+    def _factorize_unique(self) -> T_FactorizeOut:
+        # look through group to find the unique values
+        sort = not isinstance(self.group_as_index, pd.MultiIndex)
+        unique_values, group_indices, codes_ = unique_value_groups(
+            self.group_as_index, sort=sort
+        )
+        if len(group_indices) == 0:
+            raise ValueError(
+                "Failed to group data. Are you grouping by a variable that is all NaN?"
+            )
+        codes = self.group1d.copy(data=codes_)
+        group_indices = group_indices
+        unique_coord = IndexVariable(
+            self.group.name, unique_values, attrs=self.group.attrs
+        )
+        full_index = unique_coord
+
+        return codes, group_indices, unique_coord, full_index
+
+    def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
+        size = self.group.size
+        # no need to factorize
+        if not squeeze:
+            # use slices to do views instead of fancy indexing
+            # equivalent to: group_indices = group_indices.reshape(-1, 1)
+            group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
+        else:
+            group_indices = list(range(size))
+        size_range = np.arange(size)
+        if isinstance(self.group, _DummyGroup):
+            codes = self.group.to_dataarray().copy(data=size_range)
+        else:
+            codes = self.group.copy(data=size_range)
+        unique_coord = self.group
+        full_index = IndexVariable(self.name, unique_coord.values, self.group.attrs)
+
+        return codes, group_indices, unique_coord, full_index
+
+
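`_factorize_unique` is the old `_factorize_rest` logic relocated onto the resolved grouper; it still funnels through `unique_value_groups`, which is essentially `pandas.factorize` plus the codes-to-groups step. Roughly, in plain pandas/numpy — a standalone equivalent for illustration, not the xarray helper itself:

    import numpy as np
    import pandas as pd

    labels = np.array(["b", "a", "b", "c"])
    codes, uniques = pd.factorize(labels, sort=True)
    print(codes)    # [1 0 1 2]
    print(uniques)  # ['a' 'b' 'c']

    # group_indices, as _codes_to_groups would build them:
    group_indices = [list(np.flatnonzero(codes == k)) for k in range(len(uniques))]
    print(group_indices)  # [[1], [0, 2], [3]]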
+@dataclass
+class ResolvedBinGrouper(ResolvedGrouper):
+    grouper: BinGrouper
+
+    def _factorize(self, squeeze: bool) -> T_FactorizeOut:
+        from xarray.core.dataarray import DataArray
+
+        data = self.group1d.values
+        binned, bins = pd.cut(
+            data, self.grouper.bins, **self.grouper.cut_kwargs, retbins=True
+        )
+        binned_codes = binned.codes
+        if (binned_codes == -1).all():
+            raise ValueError(f"None of the data falls within bins with edges {bins!r}")
+
+        full_index = binned.categories
+        uniques = np.sort(pd.unique(binned_codes))
+        unique_values = full_index[uniques[uniques != -1]]
+        group_indices = [
+            g for g in _codes_to_groups(binned_codes, len(full_index)) if g
+        ]
+
+        if len(group_indices) == 0:
+            raise ValueError(f"None of the data falls within bins with edges {bins!r}")
+
+        new_dim_name = str(self.group.name) + "_bins"
+        self.group1d = DataArray(
+            binned, getattr(self.group1d, "coords", None), name=new_dim_name
+        )
+        unique_coord = IndexVariable(new_dim_name, unique_values, self.group.attrs)
+        codes = self.group1d.copy(data=binned_codes)
+        # TODO: support IntervalIndex in IndexVariable
+
+        return codes, group_indices, unique_coord, full_index
+
+
+@dataclass
+class ResolvedTimeResampleGrouper(ResolvedGrouper):
+    grouper: TimeResampleGrouper
+    index_grouper: CFTimeGrouper | pd.Grouper = field(init=False)
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+
+        from xarray import CFTimeIndex
+
+        group_as_index = safe_cast_to_index(self.group)
+        self._group_as_index = group_as_index
+
+        if not group_as_index.is_monotonic_increasing:
+            # TODO: sort instead of raising an error
+            raise ValueError("index must be monotonic for resampling")
+
+        grouper = self.grouper
+        if isinstance(group_as_index, CFTimeIndex):
+            from xarray.core.resample_cftime import CFTimeGrouper
+
+            index_grouper = CFTimeGrouper(
+                freq=grouper.freq,
+                closed=grouper.closed,
+                label=grouper.label,
+                origin=grouper.origin,
+                offset=grouper.offset,
+                loffset=grouper.loffset,
+            )
+        else:
+            index_grouper = pd.Grouper(
+                freq=grouper.freq,
+                closed=grouper.closed,
+                label=grouper.label,
+                origin=grouper.origin,
+                offset=grouper.offset,
+            )
+        self.index_grouper = index_grouper
+
+    def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]:
+        first_items, codes = self.first_items()
+        full_index = first_items.index
+        if first_items.isnull().any():
+            first_items = first_items.dropna()
+
+        full_index = full_index.rename("__resample_dim__")
+        return full_index, first_items, codes
+
+    def first_items(self) -> tuple[pd.Series, np.ndarray]:
+        from xarray import CFTimeIndex
+
+        if isinstance(self.group_as_index, CFTimeIndex):
+            return self.index_grouper.first_items(self.group_as_index)
+        else:
+            s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index)
+            grouped = s.groupby(self.index_grouper)
+            first_items = grouped.first()
+            counts = grouped.count()
+            # This way we generate codes for the final output index: full_index.
+            # So for _flox_reduce we avoid one reindex and copy by avoiding
+            # _maybe_restore_empty_groups
+            codes = np.repeat(np.arange(len(first_items)), counts)
+            if self.grouper.loffset is not None:
+                _apply_loffset(self.grouper.loffset, first_items)
+            return first_items, codes
+
+    def _factorize(self, squeeze: bool) -> T_FactorizeOut:
+        full_index, first_items, codes_ = self._get_index_and_items()
+        sbins = first_items.values.astype(np.int64)
+        group_indices: T_GroupIndices = [
+            slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])
+        ]
+        group_indices += [slice(sbins[-1], None)]
+
+        unique_coord = IndexVariable(
+            self.group.name, first_items.index, self.group.attrs
+        )
+        codes = self.group.copy(data=codes_)
+
+        return codes, group_indices, unique_coord, full_index
+
+
+class Grouper(ABC):
+    pass
+
+
+@dataclass
+class UniqueGrouper(Grouper):
+    pass
+
+
+@dataclass
+class BinGrouper(Grouper):
+    bins: Any  # TODO: What is the typing?
+    cut_kwargs: Mapping = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        if duck_array_ops.isnull(self.bins).all():
+            raise ValueError("All bin edges are NaN.")
+
+
+@dataclass
+class TimeResampleGrouper(Grouper):
+    freq: str
+    closed: SideOptions | None
+    label: SideOptions | None
+    origin: str | DatetimeLike | None
+    offset: pd.Timedelta | datetime.timedelta | str | None
+    loffset: datetime.timedelta | str | None
+
+
+def _validate_groupby_squeeze(squeeze: bool) -> None:
+    # While we don't generally check the type of every arg, passing
+    # multiple dimensions as multiple arguments is common enough, and the
+    # consequences hidden enough (strings evaluate as true) to warrant
+    # checking here.
+    # A future version could make squeeze kwarg only, but would face
+    # backward-compat issues.
+    if not isinstance(squeeze, bool):
+        raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
+
+
+def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group:
+    from xarray.core.dataarray import DataArray
+
+    error_msg = (
+        "the group variable's length does not "
+        "match the length of this variable along its "
+        "dimensions"
+    )
+
+    newgroup: T_Group
+    if isinstance(group, DataArray):
+        try:
+            align(obj, group, join="exact", copy=False)
+        except ValueError:
+            raise ValueError(error_msg)
+
+        newgroup = group.copy(deep=False)
+        newgroup.name = group.name or "group"
+
+    elif isinstance(group, IndexVariable):
+        # This assumption is built in to _ensure_1d.
+        if group.ndim != 1:
+            raise ValueError(
+                "Grouping by multi-dimensional IndexVariables is not allowed. "
+                "Convert to and pass a DataArray instead."
+            )
+        (group_dim,) = group.dims
+        if len(group) != obj.sizes[group_dim]:
+            raise ValueError(error_msg)
+
     else:
-        group_indices = np.arange(group.size)
-    codes = np.arange(group.size)
-    unique_coord = group
-    return unique_coord, group_indices, codes
+        if not hashable(group):
+            raise TypeError(
+                "`group` must be an xarray.DataArray or the "
+                "name of an xarray variable or dimension. "
+                f"Received {group!r} instead."
+            )
+        group = obj[group]
+        if group.name not in obj._indexes and group.name in obj.dims:
+            # DummyGroups should not appear on groupby results
+            newgroup = _DummyGroup(obj, group.name, group.coords)
+        else:
+            newgroup = group
+
+    if newgroup.size == 0:
+        raise ValueError(f"{newgroup.name} must not be empty")
+
+    return newgroup


 class GroupBy(Generic[T_Xarray]):
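`ResolvedTimeResampleGrouper.first_items` keeps the trick of grouping a positional `pd.Series` so that both the bin start offsets and the per-element codes fall out of one pandas groupby. The pandas-only core of it, as a runnable sketch:

    import numpy as np
    import pandas as pd

    index = pd.date_range("2000-01-01", periods=6, freq="12H")
    s = pd.Series(np.arange(index.size), index)

    grouped = s.groupby(pd.Grouper(freq="1D"))
    first_items = grouped.first()   # positional offset of each day's first element
    counts = grouped.count()        # elements per day

    codes = np.repeat(np.arange(len(first_items)), counts)
    print(first_items.values)  # [0 2 4]
    print(codes)               # [0 0 1 1 2 2]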
@@ -406,6 +667,7 @@ class GroupBy(Generic[T_Xarray]):
         "_group_dim",
         "_group_indices",
         "_groups",
+        "groupers",
         "_obj",
         "_restore_coord_dims",
         "_stacked_dim",
@@ -420,16 +682,26 @@ class GroupBy(Generic[T_Xarray]):
         "_codes",
     )
     _obj: T_Xarray
+    groupers: tuple[ResolvedGrouper]
+    _squeeze: bool
+    _restore_coord_dims: bool
+
+    _original_obj: T_Xarray
+    _original_group: T_Group
+    _group_indices: T_GroupIndices
+    _codes: DataArray
+    _group_dim: Hashable
+
+    _groups: dict[GroupKey, GroupIndex] | None
+    _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
+    _sizes: Frozen[Hashable, int] | None

     def __init__(
         self,
         obj: T_Xarray,
-        group: Hashable | DataArray | IndexVariable,
+        groupers: tuple[ResolvedGrouper],
         squeeze: bool = False,
-        grouper: pd.Grouper | None = None,
-        bins: ArrayLike | None = None,
         restore_coord_dims: bool = True,
-        cut_kwargs: Mapping[Any, Any] | None = None,
     ) -> None:
         """Create a GroupBy object

@@ -437,100 +709,36 @@ def __init__(
         ----------
         obj : Dataset or DataArray
             Object to group.
-        group : Hashable, DataArray or Index
-            Array with the group values or name of the variable.
-        squeeze : bool, default: False
-            If "group" is a coordinate of object, `squeeze` controls whether
-            the subarrays have a dimension of length 1 along that coordinate or
-            if the dimension is squeezed out.
-        grouper : pandas.Grouper, optional
-            Used for grouping values along the `group` array.
-        bins : array-like, optional
-            If `bins` is specified, the groups will be discretized into the
-            specified bins by `pandas.cut`.
+        groupers : tuple of ResolvedGrouper
+            Resolved grouper objects bound to `obj`.
         restore_coord_dims : bool, default: True
             If True, also restore the dimension order of multi-dimensional
             coordinates.
-        cut_kwargs : dict-like, optional
-            Extra keyword arguments to pass to `pandas.cut`
-
         """
-        from xarray.core.dataarray import DataArray
-
-        if grouper is not None and bins is not None:
-            raise TypeError("can't specify both `grouper` and `bins`")
-
-        if not isinstance(group, (DataArray, IndexVariable)):
-            if not hashable(group):
-                raise TypeError(
-                    "`group` must be an xarray.DataArray or the "
-                    "name of an xarray variable or dimension. "
-                    f"Received {group!r} instead."
-                )
-            group = obj[group]
-            if len(group) == 0:
-                raise ValueError(f"{group.name} must not be empty")
+        self.groupers = groupers

-            if group.name not in obj.coords and group.name in obj.dims:
-                # DummyGroups should not appear on groupby results
-                group = _DummyGroup(obj, group.name, group.coords)
+        self._original_obj = obj

-        if getattr(group, "name", None) is None:
-            group.name = "group"
+        for grouper_ in self.groupers:
+            grouper_.factorize(squeeze)

-        self._original_obj: T_Xarray = obj
-        self._original_group = group
-        self._bins = bins
-
-        group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
-        (group_dim,) = group.dims
-
-        expected_size = obj.sizes[group_dim]
-        if group.size != expected_size:
-            raise ValueError(
-                "the group variable's length does not "
-                "match the length of this variable along its "
-                "dimension"
-            )
-
-        self._codes: DataArray
-        if grouper is not None:
-            unique_coord, group_indices, codes, full_index = _factorize_grouper(
-                group, grouper
-            )
-            self._codes = group.copy(data=codes)
-        elif bins is not None:
-            unique_coord, group_indices, codes, full_index, group = _factorize_bins(
-                group, bins, cut_kwargs
-            )
-            self._codes = group.copy(data=codes)
-        elif group.dims == (group.name,) and _unique_and_monotonic(group):
-            unique_coord, group_indices, codes = _factorize_dummy(group, squeeze)
-            full_index = None
-            self._codes = obj[group.name].copy(data=codes)
-        else:
-            unique_coord, group_indices, codes = _factorize_rest(group)
-            full_index = None
-            self._codes = group.copy(data=codes)
+        (grouper,) = self.groupers
+        self._original_group = grouper.group

         # specification for the groupby operation
-        self._obj: T_Xarray = obj
-        self._group = group
-        self._group_dim = group_dim
-        self._group_indices = group_indices
-        self._unique_coord = unique_coord
-        self._stacked_dim = stacked_dim
-        self._inserted_dims = inserted_dims
-        self._full_index = full_index
+        self._obj = grouper.stacked_obj
         self._restore_coord_dims = restore_coord_dims
-        self._bins = bins
         self._squeeze = squeeze

-        self._codes = self._maybe_unstack(self._codes)
+        # These should generalize to multiple groupers
+        self._group_indices = grouper.group_indices
+        self._codes = self._maybe_unstack(grouper.codes)
+
+        (self._group_dim,) = grouper.group1d.dims

         # cached attributes
-        self._groups: dict[GroupKey, slice | int | list[int]] | None = None
-        self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None
-        self._sizes: Frozen[Hashable, int] | None = None
+        self._groups = None
+        self._dims = None
+        self._sizes = None

     @property
     def sizes(self) -> Frozen[Hashable, int]:
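After this change, the state that used to live in `GroupBy._group`, `_unique_coord`, and `_full_index` hangs off the resolved grouper instead. Assuming a build with this patch applied, the relationship is observable from the (still semi-private) `groupers` tuple:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"letters": ("x", list("abab"))})
    gb = ds.groupby("letters")

    (grouper,) = gb.groupers            # exactly one ResolvedGrouper for now
    print(grouper.name)                 # letters
    print(grouper.unique_coord.values)  # ['a' 'b']
    print(grouper.codes.values)         # [0 1 0 1]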
""" # provided to mimic pandas.groupby if self._groups is None: - self._groups = dict(zip(self._unique_coord.values, self._group_indices)) + (grouper,) = self.groupers + self._groups = dict(zip(grouper.unique_coord.values, self._group_indices)) return self._groups def __getitem__(self, key: GroupKey) -> T_Xarray: @@ -589,17 +798,20 @@ def __getitem__(self, key: GroupKey) -> T_Xarray: return self._obj.isel({self._group_dim: self.groups[key]}) def __len__(self) -> int: - return self._unique_coord.size + (grouper,) = self.groupers + return grouper.size def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: - return zip(self._unique_coord.values, self._iter_grouped()) + (grouper,) = self.groupers + return zip(grouper.unique_coord.data, self._iter_grouped()) def __repr__(self) -> str: + (grouper,) = self.groupers return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( self.__class__.__name__, - self._unique_coord.name, - self._unique_coord.size, - ", ".join(format_array_flat(self._unique_coord, 30).split()), + grouper.name, + grouper.full_index.size, + ", ".join(format_array_flat(grouper.full_index, 30).split()), ) def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -608,11 +820,12 @@ def _iter_grouped(self) -> Iterator[T_Xarray]: yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): + (grouper,) = self.groupers if self._group_dim in applied_example.dims: - coord = self._group + coord = grouper.group1d positions = self._group_indices else: - coord = self._unique_coord + coord = grouper.unique_coord positions = None (dim,) = coord.dims if isinstance(coord, _DummyGroup): @@ -626,19 +839,19 @@ def _binary_op(self, other, f, reflexive=False): g = f if not reflexive else lambda x, y: f(y, x) + (grouper,) = self.groupers obj = self._original_obj - group = self._original_group + group = grouper.group codes = self._codes dims = group.dims if isinstance(group, _DummyGroup): - group = obj[group.name] - coord = group + group = coord = group.to_dataarray() else: - coord = self._unique_coord + coord = grouper.unique_coord if not isinstance(coord, DataArray): - coord = DataArray(self._unique_coord) - name = self._group.name + coord = DataArray(grouper.unique_coord) + name = grouper.name if not isinstance(other, (Dataset, DataArray)): raise TypeError( @@ -667,11 +880,12 @@ def _binary_op(self, other, f, reflexive=False): mask = codes == -1 if mask.any(): obj = obj.where(~mask, drop=True) + group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) # codes are defined for coord, so we align `other` with `coord` # before indexing - other, _ = align(other, coord, join="right") + other, _ = align(other, coord, join="right", copy=False) expanded = other.isel({name: codes}) result = g(obj, expanded) @@ -691,20 +905,27 @@ def _binary_op(self, other, f, reflexive=False): return result def _maybe_restore_empty_groups(self, combined): - """Our index contained empty groups (e.g., from a resampling). If we + """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. 
""" - if self._full_index is not None and self._group.name in combined.dims: - indexers = {self._group.name: self._full_index} + (grouper,) = self.groupers + if ( + isinstance(grouper, (ResolvedBinGrouper, ResolvedTimeResampleGrouper)) + and grouper.name in combined.dims + ): + indexers = {grouper.name: grouper.full_index} combined = combined.reindex(**indexers) return combined def _maybe_unstack(self, obj): """This gets called if we are applying on an array with a multidimensional group.""" - if self._stacked_dim is not None and self._stacked_dim in obj.dims: - obj = obj.unstack(self._stacked_dim) - for dim in self._inserted_dims: + (grouper,) = self.groupers + stacked_dim = grouper.stacked_dim + inserted_dims = grouper.inserted_dims + if stacked_dim is not None and stacked_dim in obj.dims: + obj = obj.unstack(stacked_dim) + for dim in inserted_dims: if dim in obj.coords: del obj.coords[dim] obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) @@ -722,7 +943,8 @@ def _flox_reduce( from xarray.core.dataset import Dataset obj = self._original_obj - group = self._original_group + (grouper,) = self.groupers + isbin = isinstance(grouper, ResolvedBinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -745,22 +967,20 @@ def _flox_reduce( # weird backcompat # reducing along a unique indexed dimension with squeeze=True # should raise an error - if ( - dim is None or dim == self._group.name - ) and self._group.name in obj.xindexes: - index = obj.indexes[self._group.name] + if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: + index = obj.indexes[grouper.name] if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") unindexed_dims: tuple[Hashable, ...] = tuple() - if isinstance(group, _DummyGroup) and self._bins is None: - unindexed_dims = (group.name,) + if isinstance(grouper.group, _DummyGroup) and not isbin: + unindexed_dims = (grouper.name,) parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): parsed_dim = (dim,) elif dim is None: - parsed_dim = group.dims + parsed_dim = grouper.group.dims elif dim is ...: parsed_dim = tuple(obj.dims) else: @@ -768,12 +988,12 @@ def _flox_reduce( # Do this so we raise the same error message whether flox is present or not. # Better to control it here than in flox. - if any(d not in group.dims and d not in obj.dims for d in parsed_dim): + if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim): raise ValueError(f"cannot reduce over dimensions {dim}.") if kwargs["func"] not in ["all", "any", "count"]: kwargs.setdefault("fill_value", np.nan) - if self._bins is not None and kwargs["func"] == "count": + if isbin and kwargs["func"] == "count": # This is an annoying hack. Xarray returns np.nan # when there are no observations in a bin, instead of 0. # We can fake that here by forcing min_count=1. 
@@ -782,7 +1002,7 @@ def _flox_reduce(
                 kwargs.setdefault("fill_value", np.nan)
                 kwargs.setdefault("min_count", 1)

-        output_index = self._get_output_index()
+        output_index = grouper.full_index
         result = xarray_reduce(
             obj.drop_vars(non_numeric.keys()),
             self._codes,
@@ -796,35 +1016,29 @@

         # we did end up reducing over dimension(s) that are
         # in the grouped variable
-        if set(self._codes.dims).issubset(set(parsed_dim)):
-            result[self._unique_coord.name] = output_index
+        group_dims = grouper.group.dims
+        if set(group_dims).issubset(set(parsed_dim)):
+            result[grouper.name] = output_index

         result = result.drop_vars(unindexed_dims)

         # broadcast and restore non-numeric data variables (backcompat)
         for name, var in non_numeric.items():
             if all(d not in var.dims for d in parsed_dim):
                 result[name] = var.variable.set_dims(
-                    (group.name,) + var.dims, (result.sizes[group.name],) + var.shape
+                    (grouper.name,) + var.dims,
+                    (result.sizes[grouper.name],) + var.shape,
                 )

-        if self._bins is not None:
+        if isbin:
             # Fix dimension order when binning a dimension coordinate
             # Needed as long as we do a separate code path for pint;
             # For some reason Datasets and DataArrays behave differently!
-            if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims:
-                result = result.transpose(self._group.name, ...)
+            (group_dim,) = grouper.dims
+            if isinstance(self._obj, Dataset) and group_dim in self._obj.dims:
+                result = result.transpose(grouper.name, ...)

         return result

-    def _get_output_index(self) -> pd.Index:
-        """Return pandas.Index object for the output array."""
-        if self._full_index is not None:
-            # binning and resample
-            return self._full_index.rename(self._unique_coord.name)
-        if isinstance(self._unique_coord, _DummyGroup):
-            return IndexVariable(self._group.name, self._unique_coord.values)
-        return self._unique_coord
-
     def fillna(self, value: Any) -> T_Xarray:
         """Fill missing values in this object by group.

@@ -978,7 +1192,8 @@ def quantile(
            The American Statistician, 50(4), pp. 361-365, 1996
        """
         if dim is None:
-            dim = (self._group_dim,)
+            (grouper,) = self.groupers
+            dim = grouper.group1d.dims

         return self.map(
             self._obj.__class__.quantile,
@@ -1081,13 +1296,18 @@ def _concat_shortcut(self, applied, dim, positions=None):
         # TODO: benbovy - explicit indexes: this fast implementation doesn't
         # create an explicit index for the stacked dim coordinate
         stacked = Variable.concat(applied, dim, shortcut=True)
-        reordered = _maybe_reorder(stacked, dim, positions, N=self._group.size)
+
+        (grouper,) = self.groupers
+        reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size)
         return self._obj._replace_maybe_drop_dims(reordered)

     def _restore_dim_order(self, stacked: DataArray) -> DataArray:
+        (grouper,) = self.groupers
+        group = grouper.group1d
+
         def lookup_order(dimension):
-            if dimension == self._group.name:
-                (dimension,) = self._group.dims
+            if dimension == group.name:
+                (dimension,) = group.dims
             if dimension in self._obj.dims:
                 axis = self._obj.get_axis_num(dimension)
             else:
@@ -1172,7 +1392,8 @@ def _combine(self, applied, shortcut=False):
             combined = self._concat_shortcut(applied, dim, positions)
         else:
             combined = concat(applied, dim)
-            combined = _maybe_reorder(combined, dim, positions, N=self._group.size)
+            (grouper,) = self.groupers
+            combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size)

         if isinstance(combined, type(self._obj)):
             # only restore dimension order for arrays
@@ -1328,7 +1549,8 @@ def _combine(self, applied):
         applied_example, applied = peek_at(applied)
         coord, dim, positions = self._infer_concat_args(applied_example)
         combined = concat(applied, dim)
-        combined = _maybe_reorder(combined, dim, positions, N=self._group.size)
+        (grouper,) = self.groupers
+        combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size)
         # assign coord when the applied function does not return that coord
         if coord is not None and dim not in applied_example.dims:
             index, index_vars = create_default_index_implicit(coord)
@@ -1415,56 +1637,3 @@ class DatasetGroupBy(  # type: ignore[misc]
     ImplementsDatasetReduce,
 ):
     __slots__ = ()
-
-
-class TimeResampleGrouper:
-    def __init__(
-        self,
-        freq: str,
-        closed: SideOptions | None,
-        label: SideOptions | None,
-        origin: str | DatetimeLike,
-        offset: pd.Timedelta | datetime.timedelta | str | None,
-        loffset: datetime.timedelta | str | None,
-    ):
-        self.freq = freq
-        self.closed = closed
-        self.label = label
-        self.origin = origin
-        self.offset = offset
-        self.loffset = loffset
-
-    def first_items(self, index):
-        from xarray import CFTimeIndex
-        from xarray.core.resample_cftime import CFTimeGrouper
-
-        if isinstance(index, CFTimeIndex):
-            grouper = CFTimeGrouper(
-                freq=self.freq,
-                closed=self.closed,
-                label=self.label,
-                origin=self.origin,
-                offset=self.offset,
-                loffset=self.loffset,
-            )
-            return grouper.first_items(index)
-        else:
-            s = pd.Series(np.arange(index.size), index, copy=False)
-            grouper = pd.Grouper(
-                freq=self.freq,
-                closed=self.closed,
-                label=self.label,
-                origin=self.origin,
-                offset=self.offset,
-            )
-
-            grouped = s.groupby(grouper)
-            first_items = grouped.first()
-            counts = grouped.count()
-            # This way we generate codes for the final output index: full_index.
-            # So for _flox_reduce we avoid one reindex and copy by avoiding
-            # _maybe_restore_empty_groups
-            codes = np.repeat(np.arange(len(first_items)), counts)
-            if self.loffset is not None:
-                _apply_loffset(self.loffset, first_items)
-            return first_items, codes
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index ad9b8379322..d78676b188e 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -84,8 +84,9 @@ def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray:
         padded : DataArray or Dataset
         """
         obj = self._drop_coords()
+        (grouper,) = self.groupers
         return obj.reindex(
-            {self._dim: self._full_index}, method="pad", tolerance=tolerance
+            {self._dim: grouper.full_index}, method="pad", tolerance=tolerance
         )

     ffill = pad
@@ -108,8 +109,9 @@ def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray
         backfilled : DataArray or Dataset
         """
         obj = self._drop_coords()
+        (grouper,) = self.groupers
         return obj.reindex(
-            {self._dim: self._full_index}, method="backfill", tolerance=tolerance
+            {self._dim: grouper.full_index}, method="backfill", tolerance=tolerance
         )

     bfill = backfill
@@ -133,8 +135,9 @@ def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray:
         upsampled : DataArray or Dataset
         """
         obj = self._drop_coords()
+        (grouper,) = self.groupers
         return obj.reindex(
-            {self._dim: self._full_index}, method="nearest", tolerance=tolerance
+            {self._dim: grouper.full_index}, method="nearest", tolerance=tolerance
         )

     def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray:
@@ -170,8 +173,9 @@ def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray:
     def _interpolate(self, kind="linear") -> T_Xarray:
         """Apply scipy.interpolate.interp1d along resampling dimension."""
         obj = self._drop_coords()
+        (grouper,) = self.groupers
         return obj.interp(
-            coords={self._dim: self._full_index},
+            coords={self._dim: grouper.full_index},
             assume_sorted=True,
             method=kind,
             kwargs={"bounds_error": False},
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 4258f6fe6d5..a8530d85235 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -51,8 +51,9 @@ def test_consolidate_slices() -> None:
     slices = [slice(2, 3), slice(5, 6)]
     assert _consolidate_slices(slices) == slices

+    # ignore type because we're checking for an error anyway
     with pytest.raises(ValueError):
-        _consolidate_slices([slice(3), 4])
+        _consolidate_slices([slice(3), 4])  # type: ignore[list-item]


 def test_groupby_dims_property(dataset) -> None:
@@ -538,7 +539,7 @@ def test_groupby_drops_nans() -> None:
         .reset_index("id", drop=True)
         .assign(id=stacked.id.values)
         .dropna("id")
-        .transpose(*actual2.dims)
+        .transpose(*actual2.variable.dims)
     )
     assert_identical(actual2, expected2)

@@ -1801,7 +1802,7 @@ def test_upsample(self):
         # Nearest
         rs = array.resample(time="3H")
         actual = rs.nearest()
-        new_times = rs._full_index
+        new_times = rs.groupers[0].full_index
         expected = DataArray(array.reindex(time=new_times, method="nearest"))
         assert_identical(expected, actual)
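End to end, the only user-visible contact with the new machinery is the `groupers` tuple the updated test exercises; resampling itself behaves as before, with the upsample target index now read from the grouper. A runnable check against the public API, assuming an xarray build with this refactor applied:

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2000-01-01", periods=3, freq="6H")
    array = xr.DataArray(np.arange(3.0), coords={"time": times}, dims="time")

    rs = array.resample(time="3H")
    upsampled = rs.nearest()
    print(upsampled.sizes["time"])  # 5: the grouper's full_index includes empty bins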