diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 25c572edd54..36366f80004 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -82,7 +82,14 @@ def rolling_window(a, axis, window, center, fill_value): chunks = list(a.chunks) chunks[axis] = (pad_size, ) fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) - a = da.concatenate([fill_array, a], axis=axis) + # Add the chunk from fill_array at the end because this is where + # the array will be cropped. This way the size of all chunks + # along `axis` is preserved in the end. + # GH 2514 + rechunk_chunks = list(a.chunks) + rechunk_chunks[axis] = rechunk_chunks[axis] + (pad_size,) + a = da.concatenate([fill_array, a], axis=axis).rechunk( + {axis: rechunk_chunks[axis]}) boundary = {d: fill_value for d in range(a.ndim)} diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e49b6cdf517..21269964662 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3491,6 +3491,29 @@ def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): assert_allclose(actual, expected) +@pytest.mark.parametrize('name', ('count',)) +@pytest.mark.parametrize('center', (True, False, None)) +@pytest.mark.parametrize('min_periods', (1, None)) +@pytest.mark.parametrize('window', (7, 70)) +def test_rolling_wrapped_dask_chunksizes(da_dask, name, center, min_periods, + window): + # check if chunksizes are preserved (GH: 2514) + t = pd.date_range(start='2018-01-01', end='2018-02-01', freq='H') + bar = np.sin(np.arange(len(t))) + baz = np.cos(np.arange(len(t))) + + da_test = xr.DataArray(data=np.stack([bar, baz]), + coords={'time': t, + 'sensor': ['one', 'two']}, + dims=('sensor', 'time')) + + rolling_obj = da_test.chunk({'time': 100}).rolling(time=window, + min_periods=min_periods, + center=center) + actual = getattr(rolling_obj, name)() + assert actual.chunks == ((2,), (100, 100, 100, 100, 100, 100, 100, 45)) + + @pytest.mark.parametrize('center', (True, None)) def test_rolling_wrapped_dask_nochunk(center): # GH:2113 diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 47224e55473..32f96ccdc3b 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -247,6 +247,19 @@ def test_interpolate_limits(): assert_equal(actual, expected) +@requires_dask +def test_interpolate_limits_chunksize(): + # GH: 2514 + da = xr.DataArray(np.array([1, 2, np.nan, np.nan, np.nan, 6], + dtype=np.float64), dims='x').chunk({'x': 6}) + + actual = da.interpolate_na(dim='x', limit=None) + assert actual.chunks == (6,) + + actual = da.interpolate_na(dim='x', limit=2) + assert actual.chunks == (6,) + + @requires_scipy def test_interpolate_methods(): for method in ['linear', 'nearest', 'zero', 'slinear', 'quadratic',