Skip to content

Commit 7d24702

Browse files
committed
Add tests for rolling(..., pad=False)
1 parent 90228b0 commit 7d24702

File tree

3 files changed

+152
-45
lines changed

3 files changed

+152
-45
lines changed

xarray/tests/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,24 @@ def create_test_data(seed=None, add_attrs=True):
231231
return obj
232232

233233

234+
def get_expected_rolling_indices(count, window, center, pad, stride=1):
    """Return the index positions a rolling result is expected to keep.

    Parameters
    ----------
    count : int
        Length of the dimension being rolled over.
    window : int
        Size of the rolling window.
    center : bool or None
        Truthy when the window is centered; only consulted when ``pad`` is
        falsy.
    pad : bool
        Truthy when padding is used, in which case every position in
        ``range(count)`` is kept.
    stride : int, default: 1
        Step between consecutive positions that are kept.

    Returns
    -------
    numpy.ndarray
        ``np.arange(start_index, end_index, stride)`` for the bounds
        computed below.
    """
    # Without padding, (window - 1) items are lost from the index: all of
    # them from the beginning (no centering), or split between the beginning
    # and the end (centering).
    if pad:
        start_index = 0
        end_index = count
    elif center:
        # Items dropped at the start: window=10 -> 5, window=9 -> 4.
        start_index = window // 2
        # Items dropped at the end: window=10 -> 4, window=9 -> 4.
        end_index = count - (window - 1) // 2
    else:
        start_index = window - 1
        end_index = count

    return np.arange(start_index, end_index, stride)
250+
251+
234252
_CFTIME_CALENDARS = [
235253
"365_day",
236254
"360_day",

xarray/tests/test_dataarray.py

Lines changed: 99 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
assert_array_equal,
3434
assert_equal,
3535
assert_identical,
36+
get_expected_rolling_indices,
3637
has_dask,
3738
raise_if_dask_computes,
3839
requires_bottleneck,
@@ -6505,7 +6506,7 @@ def test_isin(da):
65056506

65066507
@pytest.mark.parametrize("da", (1, 2), indirect=True)
65076508
@pytest.mark.parametrize("center", (True, False, None))
6508-
@pytest.mark.parametrize("pad", (True, False, None))
6509+
@pytest.mark.parametrize("pad", (True, False))
65096510
@pytest.mark.parametrize("min_periods", (1, 6, None))
65106511
@pytest.mark.parametrize("window", (6, 7))
65116512
def test_rolling_iter(da, center, pad, min_periods, window):
@@ -6610,40 +6611,42 @@ def test_rolling_wrapped_bottleneck(da, name, min_periods):
66106611

66116612
@pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median"))
66126613
@pytest.mark.parametrize("center", (True, False, None))
6613-
@pytest.mark.parametrize("pad", (True, False, None))
6614+
@pytest.mark.parametrize("pad", (True, False))
66146615
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
66156616
def test_rolling_wrapped_bottleneck_center_pad(da, name, center, pad):
66166617
pytest.importorskip("bottleneck", minversion="1.1")
66176618

6619+
window = 7
6620+
count = len(da["time"])
66186621
rolling_obj = da.rolling(time=7, center=center, pad=pad)
66196622
actual = getattr(rolling_obj, name)()["time"]
66206623

6621-
if pad:
6622-
expected = da["time"]
6623-
else:
6624-
if center:
6625-
expected = da["time"][slice(3, -3)]
6626-
else:
6627-
expected = da["time"][slice(6, None)]
6624+
expected_index = get_expected_rolling_indices(count, window, center, pad)
6625+
expected = da["time"][expected_index]
66286626

66296627
assert_equal(actual, expected)
66306628

66316629

66326630
@requires_dask
66336631
@pytest.mark.parametrize("name", ("mean", "count"))
66346632
@pytest.mark.parametrize("center", (True, False, None))
6633+
@pytest.mark.parametrize("pad", (True, False))
66356634
@pytest.mark.parametrize("min_periods", (1, None))
66366635
@pytest.mark.parametrize("window", (7, 8))
66376636
@pytest.mark.parametrize("backend", ["dask"], indirect=True)
6638-
def test_rolling_wrapped_dask(da, name, center, min_periods, window):
6637+
def test_rolling_wrapped_dask(da, name, center, pad, min_periods, window):
66396638
# dask version
6640-
rolling_obj = da.rolling(time=window, min_periods=min_periods, center=center)
6639+
rolling_obj = da.rolling(
6640+
time=window, min_periods=min_periods, center=center, pad=pad
6641+
)
66416642
actual = getattr(rolling_obj, name)().load()
66426643
if name != "count":
66436644
with pytest.warns(DeprecationWarning, match="Reductions are applied"):
66446645
getattr(rolling_obj, name)(dim="time")
66456646
# numpy version
6646-
rolling_obj = da.load().rolling(time=window, min_periods=min_periods, center=center)
6647+
rolling_obj = da.load().rolling(
6648+
time=window, min_periods=min_periods, center=center, pad=pad
6649+
)
66476650
expected = getattr(rolling_obj, name)()
66486651

66496652
# using all-close because rolling over ghost cells introduces some
@@ -6652,39 +6655,52 @@ def test_rolling_wrapped_dask(da, name, center, min_periods, window):
66526655

66536656
# with zero chunked array GH:2113
66546657
rolling_obj = da.chunk().rolling(
6655-
time=window, min_periods=min_periods, center=center
6658+
time=window,
6659+
min_periods=min_periods,
6660+
center=center,
6661+
pad=pad,
66566662
)
66576663
actual = getattr(rolling_obj, name)().load()
66586664
assert_allclose(actual, expected)
66596665

66606666

66616667
@pytest.mark.parametrize("center", (True, None))
6662-
def test_rolling_wrapped_dask_nochunk(center):
6668+
@pytest.mark.parametrize("pad", (True, False))
6669+
def test_rolling_wrapped_dask_nochunk(center, pad):
66636670
# GH:2113
66646671
pytest.importorskip("dask.array")
66656672

66666673
da_day_clim = xr.DataArray(
66676674
np.arange(1, 367), coords=[np.arange(1, 367)], dims="dayofyear"
66686675
)
6669-
expected = da_day_clim.rolling(dayofyear=31, center=center).mean()
6670-
actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean()
6676+
expected = da_day_clim.rolling(dayofyear=31, center=center, pad=pad).mean()
6677+
actual = da_day_clim.chunk().rolling(dayofyear=31, center=center, pad=pad).mean()
66716678
assert_allclose(actual, expected)
66726679

66736680

66746681
@pytest.mark.parametrize("center", (True, False))
6682+
@pytest.mark.parametrize("pad", (False,))
66756683
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
6676-
@pytest.mark.parametrize("window", (1, 2, 3, 4))
6677-
def test_rolling_pandas_compat(center, window, min_periods):
6684+
@pytest.mark.parametrize("window", (2, 3, 4))
6685+
def test_rolling_pandas_compat(center, pad, window, min_periods):
66786686
s = pd.Series(np.arange(10))
66796687
da = DataArray.from_series(s)
66806688

66816689
if min_periods is not None and window < min_periods:
66826690
min_periods = window
66836691

6684-
s_rolling = s.rolling(window, center=center, min_periods=min_periods).mean()
6685-
da_rolling = da.rolling(index=window, center=center, min_periods=min_periods).mean()
6692+
expected_index = get_expected_rolling_indices(10, window, center, pad)
6693+
6694+
s_rolling = (
6695+
s.rolling(window, center=center, min_periods=min_periods)
6696+
.mean()
6697+
.iloc[expected_index]
6698+
)
6699+
da_rolling = da.rolling(
6700+
index=window, center=center, pad=pad, min_periods=min_periods
6701+
).mean()
66866702
da_rolling_np = da.rolling(
6687-
index=window, center=center, min_periods=min_periods
6703+
index=window, center=center, pad=pad, min_periods=min_periods
66886704
).reduce(np.nanmean)
66896705

66906706
np.testing.assert_allclose(s_rolling.values, da_rolling.values)
@@ -6694,10 +6710,12 @@ def test_rolling_pandas_compat(center, window, min_periods):
66946710

66956711

66966712
@pytest.mark.parametrize("center", (True, False))
6697-
@pytest.mark.parametrize("window", (1, 2, 3, 4))
6713+
@pytest.mark.parametrize("window", (2, 3, 4))
66986714
def test_rolling_construct(center, window):
6699-
s = pd.Series(np.arange(10))
6715+
count = 10
6716+
s = pd.Series(np.arange(count))
67006717
da = DataArray.from_series(s)
6718+
da = da.assign_coords(time=("index", np.arange(1, count + 1)))
67016719

67026720
s_rolling = s.rolling(window, center=center, min_periods=1).mean()
67036721
da_rolling = da.rolling(index=window, center=center, min_periods=1)
@@ -6718,21 +6736,41 @@ def test_rolling_construct(center, window):
67186736
assert da_rolling_mean.isnull().sum() == 0
67196737
assert (da_rolling_mean == 0.0).sum() >= 0
67206738

6739+
# with no padding
6740+
da_rolling = da.rolling(index=window, center=center, min_periods=1, pad=False)
6741+
da_rolling_mean = da_rolling.construct("window", stride=2).mean("window")
6742+
6743+
expected_index = get_expected_rolling_indices(
6744+
count, window, center, pad=False, stride=2
6745+
)
6746+
6747+
assert da_rolling_mean.sizes["index"] == len(expected_index)
6748+
assert (da_rolling_mean.index.values == expected_index).all()
6749+
assert (da_rolling_mean.time.values == expected_index + 1).all()
6750+
6751+
np.testing.assert_allclose(s_rolling.values[expected_index], da_rolling_mean.values)
6752+
np.testing.assert_allclose(
6753+
s_rolling.index[expected_index], da_rolling_mean["index"]
6754+
)
6755+
67216756

67226757
@pytest.mark.parametrize("da", (1, 2), indirect=True)
67236758
@pytest.mark.parametrize("center", (True, False))
6759+
@pytest.mark.parametrize("pad", (True, False))
67246760
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
67256761
@pytest.mark.parametrize("window", (1, 2, 3, 4))
67266762
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
6727-
def test_rolling_reduce(da, center, min_periods, window, name):
6763+
def test_rolling_reduce(da, center, pad, min_periods, window, name):
67286764
if min_periods is not None and window < min_periods:
67296765
min_periods = window
67306766

67316767
if da.isnull().sum() > 1 and window == 1:
67326768
# this causes all nan slices
67336769
window = 2
67346770

6735-
rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods)
6771+
rolling_obj = da.rolling(
6772+
time=window, center=center, pad=pad, min_periods=min_periods
6773+
)
67366774

67376775
# add nan prefix to numpy methods to get similar behavior as bottleneck
67386776
actual = rolling_obj.reduce(getattr(np, "nan%s" % name))
@@ -6742,18 +6780,21 @@ def test_rolling_reduce(da, center, min_periods, window, name):
67426780

67436781

67446782
@pytest.mark.parametrize("center", (True, False))
6783+
@pytest.mark.parametrize("pad", (True, False))
67456784
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
67466785
@pytest.mark.parametrize("window", (1, 2, 3, 4))
67476786
@pytest.mark.parametrize("name", ("sum", "max"))
6748-
def test_rolling_reduce_nonnumeric(center, min_periods, window, name):
6787+
def test_rolling_reduce_nonnumeric(center, pad, min_periods, window, name):
67496788
da = DataArray(
67506789
[0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time"
67516790
).isnull()
67526791

67536792
if min_periods is not None and window < min_periods:
67546793
min_periods = window
67556794

6756-
rolling_obj = da.rolling(time=window, center=center, min_periods=min_periods)
6795+
rolling_obj = da.rolling(
6796+
time=window, center=center, pad=pad, min_periods=min_periods
6797+
)
67576798

67586799
# add nan prefix to numpy methods to get similar behavior as bottleneck
67596800
actual = rolling_obj.reduce(getattr(np, "nan%s" % name))
@@ -6767,11 +6808,15 @@ def test_rolling_count_correct():
67676808

67686809
kwargs = [
67696810
{"time": 11, "min_periods": 1},
6811+
{"time": 11, "min_periods": 1, "pad": False},
67706812
{"time": 11, "min_periods": None},
6813+
{"time": 11, "min_periods": None, "pad": False},
67716814
{"time": 7, "min_periods": 2},
6815+
{"time": 7, "min_periods": 2, "pad": False},
67726816
]
67736817
expecteds = [
67746818
DataArray([1, 1, 2, 3, 3, 4, 5, 6, 6, 7, 8], dims="time"),
6819+
DataArray([8], dims="time"),
67756820
DataArray(
67766821
[
67776822
np.nan,
@@ -6788,7 +6833,9 @@ def test_rolling_count_correct():
67886833
],
67896834
dims="time",
67906835
),
6836+
DataArray([np.nan], dims="time"),
67916837
DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"),
6838+
DataArray([5, 5, 5, 5, 5], dims="time"),
67926839
]
67936840

67946841
for kwarg, expected in zip(kwargs, expecteds):
@@ -6800,17 +6847,20 @@ def test_rolling_count_correct():
68006847

68016848

68026849
@pytest.mark.parametrize("da", (1,), indirect=True)
6803-
@pytest.mark.parametrize("center", (True, False))
6850+
@pytest.mark.parametrize("center", (True, False, {"time": True, "x": False}))
6851+
@pytest.mark.parametrize("pad", (True, False, {"time": True, "x": False}))
68046852
@pytest.mark.parametrize("min_periods", (None, 1))
68056853
@pytest.mark.parametrize("name", ("sum", "mean", "max"))
6806-
def test_ndrolling_reduce(da, center, min_periods, name):
6807-
rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods)
6854+
def test_ndrolling_reduce(da, center, pad, min_periods, name):
6855+
rolling_obj = da.rolling(
6856+
time=3, x=2, center=center, pad=pad, min_periods=min_periods
6857+
)
68086858

68096859
actual = getattr(rolling_obj, name)()
68106860
expected = getattr(
68116861
getattr(
6812-
da.rolling(time=3, center=center, min_periods=min_periods), name
6813-
)().rolling(x=2, center=center, min_periods=min_periods),
6862+
da.rolling(time=3, center=center, pad=pad, min_periods=min_periods), name
6863+
)().rolling(x=2, center=center, pad=pad, min_periods=min_periods),
68146864
name,
68156865
)()
68166866

@@ -6828,23 +6878,33 @@ def test_ndrolling_reduce(da, center, min_periods, name):
68286878
assert_allclose(actual, expected.where(count >= min_periods))
68296879

68306880

6831-
@pytest.mark.parametrize("center", (True, False, (True, False)))
6881+
@pytest.mark.parametrize("center", (True, False, {"x": True, "z": False}))
6882+
@pytest.mark.parametrize(
6883+
"pad",
6884+
(
6885+
True,
6886+
False,
6887+
{"x": True, "z": False},
6888+
),
6889+
)
68326890
@pytest.mark.parametrize("fill_value", (np.nan, 0.0))
6833-
def test_ndrolling_construct(center, fill_value):
6891+
def test_ndrolling_construct(center, pad, fill_value):
68346892
da = DataArray(
68356893
np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float),
68366894
dims=["x", "y", "z"],
68376895
coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)},
68386896
)
6839-
actual = da.rolling(x=3, z=2, center=center).construct(
6897+
actual = da.rolling(x=3, z=2, center=center, pad=pad).construct(
68406898
x="x1", z="z1", fill_value=fill_value
68416899
)
6842-
if not isinstance(center, tuple):
6843-
center = (center, center)
6900+
if not isinstance(center, dict):
6901+
center = {"x": center, "z": center}
6902+
if not isinstance(pad, dict):
6903+
pad = {"x": pad, "z": pad}
68446904
expected = (
6845-
da.rolling(x=3, center=center[0])
6905+
da.rolling(x=3, center=center["x"], pad=pad["x"])
68466906
.construct(x="x1", fill_value=fill_value)
6847-
.rolling(z=2, center=center[1])
6907+
.rolling(z=2, center=center["z"], pad=pad["z"])
68486908
.construct(z="z1", fill_value=fill_value)
68496909
)
68506910
assert_allclose(actual, expected)

0 commit comments

Comments
 (0)