diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 5fa78ae76de..7119332405b 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -23,6 +23,7 @@
 
 from . import dtypes, duck_array_ops, nputils, ops
 from ._reductions import DataArrayGroupByReductions, DatasetGroupByReductions
+from .alignment import align
 from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from .concat import concat
 from .formatting import format_array_flat
@@ -309,7 +310,7 @@ class GroupBy(Generic[T_Xarray]):
         "_squeeze",
         # Save unstacked object for flox
         "_original_obj",
-        "_unstacked_group",
+        "_original_group",
         "_bins",
     )
     _obj: T_Xarray
@@ -374,7 +375,7 @@ def __init__(
             group.name = "group"
 
         self._original_obj: T_Xarray = obj
-        self._unstacked_group = group
+        self._original_group = group
         self._bins = bins
 
         group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
@@ -571,11 +572,22 @@ def _binary_op(self, other, f, reflexive=False):
 
         g = f if not reflexive else lambda x, y: f(y, x)
 
-        obj = self._obj
-        group = self._group
-        dim = self._group_dim
+        if self._bins is None:
+            obj = self._original_obj
+            group = self._original_group
+            dims = group.dims
+        else:
+            obj = self._maybe_unstack(self._obj)
+            group = self._maybe_unstack(self._group)
+            dims = (self._group_dim,)
+
         if isinstance(group, _DummyGroup):
-            group = obj[dim]
+            group = obj[group.name]
+            coord = group
+        else:
+            coord = self._unique_coord
+            if not isinstance(coord, DataArray):
+                coord = DataArray(self._unique_coord)
         name = group.name
 
         if not isinstance(other, (Dataset, DataArray)):
@@ -592,37 +604,19 @@ def _binary_op(self, other, f, reflexive=False):
                 "is not a dimension on the other argument"
             )
 
-        try:
-            expanded = other.sel({name: group})
-        except KeyError:
-            # some labels are absent i.e. other is not aligned
-            # so we align by reindexing and then rename dimensions.
-
-            # Broadcast out scalars for backwards compatibility
-            # TODO: get rid of this when fixing GH2145
-            for var in other.coords:
-                if other[var].ndim == 0:
-                    other[var] = (
-                        other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
-                    )
-            expanded = (
-                other.reindex({name: group.data})
-                .rename({name: dim})
-                .assign_coords({dim: obj[dim]})
-            )
+        # Broadcast out scalars for backwards compatibility
+        # TODO: get rid of this when fixing GH2145
+        for var in other.coords:
+            if other[var].ndim == 0:
+                other[var] = (
+                    other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
+                )
 
-        if self._bins is not None and name == dim and dim not in obj.xindexes:
-            # When binning by unindexed coordinate we need to reindex obj.
-            # _full_index is IntervalIndex, so idx will be -1 where
-            # a value does not belong to any bin. Using IntervalIndex
-            # accounts for any non-default cut_kwargs passed to the constructor
-            idx = pd.cut(group, bins=self._full_index).codes
-            obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
+        other, _ = align(other, coord, join="outer")
+        expanded = other.sel({name: group})
 
         result = g(obj, expanded)
 
-        result = self._maybe_unstack(result)
-        group = self._maybe_unstack(group)
         if group.ndim > 1:
             # backcompat:
             # TODO: get rid of this when fixing GH2145
@@ -632,8 +626,9 @@
 
         if isinstance(result, Dataset) and isinstance(obj, Dataset):
             for var in set(result):
-                if dim not in obj[var].dims:
-                    result[var] = result[var].transpose(dim, ...)
+                for d in dims:
+                    if d not in obj[var].dims:
+                        result[var] = result[var].transpose(d, ...)
         return result
 
     def _maybe_restore_empty_groups(self, combined):
@@ -695,10 +690,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs):
         # group is only passed by resample
         group = kwargs.pop("group", None)
         if group is None:
-            if isinstance(self._unstacked_group, _DummyGroup):
-                group = self._unstacked_group.name
+            if isinstance(self._original_group, _DummyGroup):
+                group = self._original_group.name
             else:
-                group = self._original_group
+                group = self._original_group
 
         unindexed_dims = tuple()
         if isinstance(group, str):
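For reference, a minimal sketch (not part of the patch; the toy data and variable names are made up for illustration) of the groupby arithmetic exercised by the rewritten _binary_op, where "other" is now outer-aligned against the unique group labels before the sel-based expansion:

import numpy as np
import xarray as xr

# Toy data: group over a non-dimension coordinate "labels".
da = xr.DataArray(
    np.arange(6.0),
    dims="time",
    coords={"labels": ("time", ["a", "b", "a", "c", "b", "a"])},
)

# Reduced result indexed by the group labels ("a", "b", "c").
means = da.groupby("labels").mean()

# Groupby arithmetic dispatches to GroupBy._binary_op: "other" (means) is
# aligned with join="outer" against the unique group labels, then expanded
# via other.sel({name: group}).
anomaly = da.groupby("labels") - means

# An "other" carrying only a subset of the labels now takes the same
# align-then-sel path instead of the old reindex/rename fallback.
anomaly_partial = da.groupby("labels") - means.sel(labels=["a", "b"])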