diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7195edef42b..b74b0fb84de 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,10 @@ New Features - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])`` (:issue:`2852`, :issue:`757`). By `Deepak Cherian `_. +- Optimize ffill, bfill with dask when limit is specified + (:pull:`9771`). + By `Joseph Nowak `_, and + `Patrick Hoefler `_. - Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`). By `Sam Levang `_ and `Tien Vo `_. - Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 7a20728ae2e..8bf9c68b727 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -75,7 +75,7 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return coeffs, residuals -def push(array, n, axis): +def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push """ @@ -83,33 +83,63 @@ def push(array, n, axis): import numpy as np from xarray.core.duck_array_ops import _push + from xarray.core.nputils import nanlast + + if n is not None and all(n <= size for size in array.chunks[axis]): + return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis) + + # TODO: Replace all this function + # once https://github.com/pydata/xarray/issues/9229 being implemented def _fill_with_last_one(a, b): - # cumreduction apply the push func over all the blocks first so, the only missing part is filling - # the missing values using the last data of the previous chunk - return np.where(~np.isnan(b), b, a) + # cumreduction apply the push func over all the blocks first so, + # the only missing part is filling the missing values using the + # last data of the previous chunk + return np.where(np.isnan(b), a, b) - if n is not None and 0 < n < array.shape[axis] - 
1: - arange = da.broadcast_to( - da.arange( - array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype - ).reshape( - tuple(size if i == axis else 1 for i, size in enumerate(array.shape)) - ), - array.shape, - array.chunks, - ) - valid_arange = da.where(da.notnull(array), arange, np.nan) - valid_limits = (arange - push(valid_arange, None, axis)) <= n - # omit the forward fill that violate the limit - return da.where(valid_limits, push(array, None, axis), np.nan) - - # The method parameter makes that the tests for python 3.7 fails. - return da.reductions.cumreduction( - func=_push, + def _dtype_push(a, axis, dtype=None): + # Not sure why the blelloch algorithm force to receive a dtype + return _push(a, axis=axis) + + pushed_array = da.reductions.cumreduction( + func=_dtype_push, binop=_fill_with_last_one, ident=np.nan, x=array, axis=axis, dtype=array.dtype, + method=method, + preop=nanlast, ) + + if n is not None and 0 < n < array.shape[axis] - 1: + + def _reset_cumsum(a, axis, dtype=None): + cumsum = np.cumsum(a, axis=axis) + reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) + return cumsum - reset_points + + def _last_reset_cumsum(a, axis, keepdims=None): + # Take the last cumulative sum taking into account the reset + # This is useful for blelloch method + return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) + + def _combine_reset_cumsum(a, b): + # It is going to sum the previous result until the first + # non nan value + bitmask = np.cumprod(b != 0, axis=axis) + return np.where(bitmask, b + a, b) + + valid_positions = da.reductions.cumreduction( + func=_reset_cumsum, + binop=_combine_reset_cumsum, + ident=0, + x=da.isnan(array, dtype=int), + axis=axis, + dtype=int, + method=method, + preop=_last_reset_cumsum, + ) + pushed_array = da.where(valid_positions <= n, pushed_array, np.nan) + + return pushed_array diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4d01e5bc345..77e62e4c71e 
100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -716,6 +716,7 @@ def first(values, axis, skipna=None): return chunked_nanfirst(values, axis) else: return nputils.nanfirst(values, axis) + return take(values, 0, axis=axis) @@ -729,6 +730,7 @@ def last(values, axis, skipna=None): return chunked_nanlast(values, axis) else: return nputils.nanlast(values, axis) + return take(values, -1, axis=axis) @@ -769,14 +771,14 @@ def _push(array, n: int | None = None, axis: int = -1): return bn.push(array, limit, axis) -def push(array, n, axis): +def push(array, n, axis, method="blelloch"): if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: raise RuntimeError( "ffill & bfill requires bottleneck or numbagg to be enabled." " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." ) if is_duck_dask_array(array): - return dask_array_ops.push(array, n, axis) + return dask_array_ops.push(array, n, axis, method=method) else: return _push(array, n, axis) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index f9d0141ead3..a2f5631ce1b 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1008,7 +1008,8 @@ def test_least_squares(use_dask, skipna): @requires_dask @requires_bottleneck -def test_push_dask(): +@pytest.mark.parametrize("method", ["sequential", "blelloch"]) +def test_push_dask(method): import bottleneck import dask.array @@ -1018,13 +1019,18 @@ def test_push_dask(): expected = bottleneck.push(array, axis=0, n=n) for c in range(1, 11): with raise_if_dask_computes(): - actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n) + actual = push( + dask.array.from_array(array, chunks=c), axis=0, n=n, method=method + ) np.testing.assert_equal(actual, expected) # some chunks of size-1 with NaN with raise_if_dask_computes(): actual = push( - dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n + 
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), + axis=0, + n=n, + method=method, ) np.testing.assert_equal(actual, expected)