diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a82d0a9fa2a..b7e632bdfb7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,8 @@ New Features - Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`). By `Michael Niklas <https://github.com/headtr1ck>`_. +- Support dask arrays in ``first`` and ``last`` reductions. + By `Deepak Cherian <https://github.com/dcherian>`_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index d2d3e4a6d1c..24c5f698a27 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,5 +1,9 @@ from __future__ import annotations +from functools import partial + +from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] + from xarray.core import dtypes, nputils @@ -92,3 +96,36 @@ def _fill_with_last_one(a, b): axis=axis, dtype=array.dtype, ) + + +def _first_last_wrapper(array, *, axis, op, keepdims): + return op(array, axis, keepdims=keepdims) + + +def _first_or_last(darray, axis, op): + import dask.array + + # This will raise the same error message seen for numpy + axis = normalize_axis_index(axis, darray.ndim) + + wrapped_op = partial(_first_last_wrapper, op=op) + return dask.array.reduction( + darray, + chunk=wrapped_op, + aggregate=wrapped_op, + axis=axis, + dtype=darray.dtype, + keepdims=False, # match numpy version + ) + + +def nanfirst(darray, axis): + from xarray.core.duck_array_ops import nanfirst + + return _first_or_last(darray, axis, op=nanfirst) + + +def nanlast(darray, axis): + from xarray.core.duck_array_ops import nanlast + + return _first_or_last(darray, axis, op=nanlast) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 96baf7f96cd..84e66803fe8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -9,7 +9,6 @@ import datetime import inspect import warnings -from functools import partial from importlib import import_module import numpy as np @@ 
-637,18 +636,14 @@ def cumsum(array, axis=None, **kwargs): return _nd_cum_func(cumsum_1d, array, axis, **kwargs) -_fail_on_dask_array_input_skipna = partial( - fail_on_dask_array_input, - msg="%r with skipna=True is not yet implemented on dask arrays", -) - - def first(values, axis, skipna=None): """Return the first non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN - _fail_on_dask_array_input_skipna(values) - return nanfirst(values, axis) + if is_duck_dask_array(values): + return dask_array_ops.nanfirst(values, axis) + else: + return nanfirst(values, axis) return take(values, 0, axis=axis) @@ -656,8 +651,10 @@ def last(values, axis, skipna=None): """Return the last non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN - _fail_on_dask_array_input_skipna(values) - return nanlast(values, axis) + if is_duck_dask_array(values): + return dask_array_ops.nanlast(values, axis) + else: + return nanlast(values, axis) return take(values, -1, axis=axis) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 80c988ebd4f..2bc413dc21f 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -24,17 +24,29 @@ def _select_along_axis(values, idx, axis): return values[sl] -def nanfirst(values, axis): +def nanfirst(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis axis = normalize_axis_index(axis, values.ndim) idx_first = np.argmax(~pd.isnull(values), axis=axis) - return _select_along_axis(values, idx_first, axis) + result = _select_along_axis(values, idx_first, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result -def nanlast(values, axis): +def nanlast(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis axis = normalize_axis_index(axis, values.ndim) rev = 
(slice(None),) * axis + (slice(None, None, -1),) idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis) - return _select_along_axis(values, idx_last, axis) + result = _select_along_axis(values, idx_last, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result def inverse_permutation(indices): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 577debbce21..52a41035faf 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -549,17 +549,22 @@ def test_rolling(self): actual = v.rolling(x=2).mean() self.assertLazyAndAllClose(expected, actual) - def test_groupby_first(self): + @pytest.mark.parametrize("func", ["first", "last"]) + def test_groupby_first_last(self, func): + method = operator.methodcaller(func) u = self.eager_array v = self.lazy_array for coords in [u.coords, v.coords]: coords["ab"] = ("x", ["a", "a", "b", "b"]) - with pytest.raises(NotImplementedError, match=r"dask"): - v.groupby("ab").first() - expected = u.groupby("ab").first() + expected = method(u.groupby("ab")) + + with raise_if_dask_computes(): + actual = method(v.groupby("ab")) + self.assertLazyAndAllClose(expected, actual) + with raise_if_dask_computes(): - actual = v.groupby("ab").first(skipna=False) + actual = method(v.groupby("ab")) self.assertLazyAndAllClose(expected, actual) def test_reindex(self): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index c873c7b76d3..0d6efa2a8d3 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -48,7 +48,11 @@ class TestOps: def setUp(self): self.x = array( [ - [[nan, nan, 2.0, nan], [nan, 5.0, 6.0, nan], [8.0, 9.0, 10.0, nan]], + [ + [nan, nan, 2.0, nan], + [nan, 5.0, 6.0, nan], + [8.0, 9.0, 10.0, nan], + ], [ [nan, 13.0, 14.0, 15.0], [nan, 17.0, 18.0, nan], @@ -128,6 +132,29 @@ def test_all_nan_arrays(self): assert np.isnan(mean([np.nan, np.nan])) +@requires_dask +class TestDaskOps(TestOps): + 
@pytest.fixture(autouse=True) + def setUp(self): + import dask.array + + self.x = dask.array.from_array( + [ + [ + [nan, nan, 2.0, nan], + [nan, 5.0, 6.0, nan], + [8.0, 9.0, 10.0, nan], + ], + [ + [nan, 13.0, 14.0, 15.0], + [nan, 17.0, 18.0, nan], + [nan, 21.0, nan, nan], + ], + ], + chunks=(2, 1, 2), + ) + + def test_cumsum_1d(): inputs = np.array([0, 1, 2, 3]) expected = np.array([0, 1, 3, 6])