Skip to content

Allow no padding for rolling windows #5603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
039a9bf
Allow padding to be turned off for rolling windows
kmsquire Jul 9, 2021
90228b0
Fix iteration on rolling windows to handle no padding
kmsquire Jul 14, 2021
c43db78
Add tests for rolling(..., pad=False)
kmsquire Jul 15, 2021
30a0b47
Update xarray/core/rolling.py
kmsquire Jul 19, 2021
0314d2f
Update xarray/core/rolling.py
kmsquire Jul 19, 2021
18562b8
Rename _get_rolling_dim_coords -> _get_output_coords
kmsquire Jul 19, 2021
90ae11d
Add comment better explaining _get_output_coords
kmsquire Jul 19, 2021
33ccbf6
Simplify _get_output_coords
kmsquire Jul 19, 2021
5df7b12
Remove utils.get_slice_offsets (unused)
kmsquire Jul 19, 2021
174750a
Simplify get_pads
kmsquire Jul 20, 2021
564833a
Fix rolling object construction in _count
kmsquire Jul 20, 2021
68b112a
Skip padding in variable.rolling if not required
kmsquire Jul 20, 2021
ea332a6
Fix rolling on DataArrays with nonunique dimension coordinates
kmsquire Jul 20, 2021
ef0ac63
Rename expand_args_to_num_dims -> expand_args_to_dims
kmsquire Jul 20, 2021
cdadc60
Update expand_args_to_dims
kmsquire Jul 20, 2021
a2b235c
Update get_pads() types, comments
kmsquire Jul 20, 2021
b98ab3a
Use "length" instead of "count" to refer to the length of a coordinate
kmsquire Jul 20, 2021
b36a049
Update type of _get_output_dim_selector
kmsquire Jul 20, 2021
61428ed
Merge remote-tracking branch 'upstream/main' into feature/rolling-pad
dcherian Aug 11, 2021
923a598
fixes.
dcherian Aug 11, 2021
fea9baf
Add whats-new
dcherian Aug 11, 2021
2084c08
minor test cleanup
dcherian Aug 11, 2021
c29d1fb
consolidate bottleneck dataarray test
dcherian Aug 11, 2021
b509ebf
consolidate bottleneck dataset tests
dcherian Aug 11, 2021
5a2eadc
more test cleanup
dcherian Aug 11, 2021
28a4ad2
Merge remote-tracking branch 'upstream/main' into feature/rolling-pad
dcherian Mar 27, 2023
b963b29
merge tests
dcherian Mar 28, 2023
d02d2fd
Merge remote-tracking branch 'upstream/main' into feature/rolling-pad
dcherian Mar 29, 2023
072e87c
Merge remote-tracking branch 'upstream/main' into feature/rolling-pad
dcherian Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ def rolling(
dim: Mapping[Hashable, int] = None,
min_periods: int = None,
center: Union[bool, Mapping[Hashable, bool]] = False,
pad: Union[bool, Mapping[Hashable, bool]] = True,
keep_attrs: bool = None,
**window_kwargs: int,
):
Expand All @@ -838,6 +839,9 @@ def rolling(
setting min_periods equal to the size of the window.
center : bool or mapping, default: False
Set the labels at the center of the window.
pad : bool or mapping, default: True
Pad the sides of the window with ``NaN``. For different
padding, see ``DataArray.pad`` or ``Dataset.pad``.
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Expand Down Expand Up @@ -886,11 +890,18 @@ def rolling(
--------
core.rolling.DataArrayRolling
core.rolling.DatasetRolling
DataArray.pad
Dataset.pad
"""

dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return self._rolling_cls(
self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs
self,
dim,
min_periods=min_periods,
center=center,
pad=pad,
keep_attrs=keep_attrs,
)

def rolling_exp(
Expand Down
166 changes: 143 additions & 23 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ class Rolling:
xarray.DataArray.rolling
"""

__slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs")
_attributes = ("window", "min_periods", "center", "dim", "keep_attrs")
__slots__ = ("obj", "window", "min_periods", "center", "pad", "dim", "keep_attrs")
_attributes = ("window", "min_periods", "center", "pad", "dim", "keep_attrs")

def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
def __init__(
self, obj, windows, min_periods=None, center=False, pad=True, keep_attrs=None
):
"""
Moving window object.

Expand All @@ -66,8 +68,11 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : bool, default: False
center : bool or mapping of hashable to bool, default: False
Set the labels at the center of the window.
pad : bool or mapping of hashable to bool, default: True
Pad the sides of the rolling window with ``NaN``. For different
padding, see ``DataArray.pad`` or ``Dataset.pad``.

Returns
-------
Expand All @@ -81,6 +86,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
self.window.append(w)

self.center = self._mapping_to_list(center, default=False)
self.pad = self._mapping_to_list(pad, default=True)
self.obj = obj

# attributes
Expand All @@ -102,8 +108,10 @@ def __repr__(self):
"""provide a nice str repr of our rolling object"""

attrs = [
"{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "")
for k, w, c in zip(self.dim, self.window, self.center)
"{k}->{v}{c}{p}".format(
k=k, v=w, c="(center)" if c else "", p="(no pad)" if not p else ""
)
for k, w, c, p in zip(self.dim, self.window, self.center, self.pad)
]
return "{klass} [{attrs}]".format(
klass=self.__class__.__name__, attrs=",".join(attrs)
Expand Down Expand Up @@ -200,11 +208,59 @@ def _get_keep_attrs(self, keep_attrs):

return keep_attrs

def _get_output_coords(self, all_coords=False) -> Dict[str, Any]:
    """Return the coordinates that label the result of this rolling operation.

    When a rolling dimension is not padded, the output can be shorter than
    the input along that dimension, so the matching coordinates have to be
    shortened and labelled accordingly.

    Parameters
    ----------
    all_coords : bool, default: False
        If False, only the coordinates named after the rolling dimension(s)
        are returned; this form is convenient as the argument of a
        ``da.sel()`` call. If True, every coordinate of the wrapped object
        is returned; this form is convenient when building a new DataArray
        or Dataset whose data already has the final size along each
        dimension.

    Returns
    -------
    dict
        Mapping from coordinate name to the (possibly shortened) coordinate.
    """
    # TODO: do we need to include dims without coordinates in the returned names
    #       when all_coords is True? (The code here does not include them.)
    names = list(self.obj.coords) if all_coords else self.dim

    # Offsets are needed exactly for the dimensions that are *not* padded, and
    # working them out is very similar to working out pad sizes. So the pad
    # flags are inverted and handed to `get_pads()`.
    trim_flags = [not padded for padded in self.pad]
    trims = utils.get_pads(self.dim, self.window, self.center, trim_flags)

    selector: Dict[str, slice] = {}
    for name, (lead, trail) in trims.items():
        # A zero offset maps to None (an open slice end); in particular a
        # zero end offset must not become -0 == 0, which would select nothing.
        selector[name] = slice(lead or None, -trail if trail else None)

    return {
        name: self.obj.coords[name].isel(selector, missing_dims="ignore")
        for name in names
    }


class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)

def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
def __init__(
self, obj, windows, min_periods=None, center=False, pad=True, keep_attrs=None
):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
Expand All @@ -221,8 +277,11 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : bool, default: False
center : bool or mapping of hashable to bool, default: False
Set the labels at the center of the window.
pad : bool or mapping of hashable to bool, default: True
Pad the sides of the rolling window with ``NaN``. For different
padding, see ``DataArray.pad``.

Returns
-------
Expand All @@ -236,7 +295,12 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
xarray.Dataset.groupby
"""
super().__init__(
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
obj,
windows,
min_periods=min_periods,
center=center,
pad=pad,
keep_attrs=keep_attrs,
)

# TODO legacy attribute
Expand All @@ -245,11 +309,47 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
def __iter__(self):
if len(self.dim) > 1:
raise ValueError("__iter__ is only supported for 1d-rolling")
stops = np.arange(1, len(self.window_labels) + 1)
starts = stops - int(self.window[0])
starts[: int(self.window[0])] = 0
for (label, start, stop) in zip(self.window_labels, starts, stops):
window = self.obj.isel(**{self.dim[0]: slice(start, stop)})
dim = self.dim[0]
center = self.center[0]
pad = self.pad[0]
window = self.window[0]
center_offset = window // 2 if center else 0

pads = utils.get_pads(self.dim, self.window, self.center, self.pad)
start_pad, end_pad = pads[dim]

# Select the proper subset of labels, based on whether or not to center and/or pad
first_label_idx = 0 if pad else center_offset if center else window - 1
last_label_idx = (
len(self.obj[dim])
if pad or not center
else len(self.obj[dim]) - center_offset
)

labels = (
self.obj[dim][slice(first_label_idx, last_label_idx)]
if self.obj[dim].coords
else np.arange(last_label_idx - first_label_idx)
)

padded_obj = self.obj.pad(pads, mode="constant", constant_values=dtypes.NA)

if pad and not center:
first_stop = 1
last_stop = len(self.obj[dim])
elif pad and center:
first_stop = end_pad + 1
last_stop = len(self.obj[dim]) + end_pad
elif not pad:
first_stop = window
last_stop = len(self.obj[dim])

# These are indices into the padded array, so we need to add start_pad
stops = np.arange(first_stop, last_stop + 1) + start_pad
starts = stops - window

for (label, start, stop) in zip(labels, starts, stops):
window = padded_obj.isel({self.dim[0]: slice(start, stop)})

counts = window.count(dim=self.dim[0])
window = window.where(counts >= self.min_periods)
Expand Down Expand Up @@ -357,15 +457,21 @@ def _construct(
stride = self._mapping_to_list(stride, default=1)

window = obj.variable.rolling_window(
self.dim, self.window, window_dim, self.center, fill_value=fill_value
self.dim,
self.window,
window_dim,
self.center,
self.pad,
fill_value=fill_value,
)

attrs = obj.attrs if keep_attrs else {}
coords = self._get_output_coords(all_coords=True)

result = DataArray(
window,
dims=obj.dims + tuple(window_dim),
coords=obj.coords,
coords=coords,
attrs=attrs,
name=obj.name,
)
Expand Down Expand Up @@ -466,6 +572,7 @@ def _counts(self, keep_attrs):
self.obj.notnull(keep_attrs=keep_attrs)
.rolling(
center={d: self.center[i] for i, d in enumerate(self.dim)},
pad={d: self.pad[i] for i, d in enumerate(self.dim)},
**{d: w for d, w in zip(self.dim, self.window)},
)
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
Expand Down Expand Up @@ -512,8 +619,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
values = values[valid]

attrs = self.obj.attrs if keep_attrs else {}
output_dim_coords = self._get_output_coords()

return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name)
return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name).sel(
output_dim_coords
)

def _numpy_or_bottleneck_reduce(
self,
Expand Down Expand Up @@ -561,7 +671,9 @@ def _numpy_or_bottleneck_reduce(
class DatasetRolling(Rolling):
__slots__ = ("rollings",)

def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
def __init__(
self, obj, windows, min_periods=None, center=False, pad=True, keep_attrs=None
):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
Expand All @@ -580,6 +692,9 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
setting min_periods equal to the size of the window.
center : bool or mapping of hashable to bool, default: False
Set the labels at the center of the window.
pad : bool or mapping of hashable to bool, default: True
Pad the sides of the window with ``NaN``. For different
padding, see ``Dataset.pad``.

Returns
-------
Expand All @@ -592,22 +707,23 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
xarray.Dataset.groupby
xarray.DataArray.groupby
"""
super().__init__(obj, windows, min_periods, center, keep_attrs)
super().__init__(obj, windows, min_periods, center, pad, keep_attrs)
if any(d not in self.obj.dims for d in self.dim):
raise KeyError(self.dim)
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on self.dim
dims, center = [], {}
dims, center, pad = [], {}, {}
for i, d in enumerate(self.dim):
if d in da.dims:
dims.append(d)
center[d] = self.center[i]
pad[d] = self.pad[i]

if dims:
w = {d: windows[d] for d in dims}
self.rollings[key] = DataArrayRolling(da, w, min_periods, center)
self.rollings[key] = DataArrayRolling(da, w, min_periods, center, pad)

def _dataset_implementation(self, func, keep_attrs, **kwargs):
from .dataset import Dataset
Expand All @@ -625,7 +741,9 @@ def _dataset_implementation(self, func, keep_attrs, **kwargs):
reduced[key].attrs = {}

attrs = self.obj.attrs if keep_attrs else {}
return Dataset(reduced, coords=self.obj.coords, attrs=attrs)
coords = self._get_output_coords(all_coords=True)

return Dataset(reduced, coords=coords, attrs=attrs)

def reduce(self, func, keep_attrs=None, **kwargs):
"""Reduce the items in this group by applying `func` along some
Expand Down Expand Up @@ -747,7 +865,9 @@ def construct(

attrs = self.obj.attrs if keep_attrs else {}

return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
coords = self._get_output_coords(all_coords=True)

return Dataset(dataset, coords=coords, attrs=attrs).isel(
**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
)

Expand Down
Loading