Closed
Description
Based on http://www.rigtorp.se/2011/01/01/rolling-statistics-numpy.html, we wrote a function that 'tricks' numpy into presenting an array that looks rolling, but without the O(n^2) memory requirements of materializing every window.
Would people be interested in this going into xarray?
It seems to work really well on a few use-cases, but I imagine it's enough trickery that we might not want to support it in xarray.
And, to be clear, it's strictly worse wherever we already have dedicated rolling algorithms. But where we don't, you get a rolling apply
without the Python loops.
def rolling_window_numpy(a, window):
    """Make an array appear to be rolling, but using only a view.

    The last axis of ``a`` (length ``n``) is replaced by two axes of shape
    ``(n - window + 1, window)``: one entry per window position, followed by
    the elements inside that window. No data is copied; the result is a
    strided view onto ``a``'s buffer.

    http://www.rigtorp.se/2011/01/01/rolling-statistics-numpy.html

    Parameters
    ----------
    a : numpy.ndarray
        Array to roll over; the window slides along its last axis.
    window : int
        Number of consecutive elements in each window. Must be at least 1.

    Returns
    -------
    numpy.ndarray
        View of shape ``a.shape[:-1] + (n - window + 1, window)``.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    """
    if window < 1:
        # window == 0 would otherwise "succeed" with a meaningless
        # (n + 1, 0) result, and negative values fail inside numpy with an
        # unrelated "negative dimensions" message.
        raise ValueError(
            'window must be a positive integer, got {!r}'.format(window))

    last_axis_len = a.shape[-1]
    last_axis_stride = a.strides[-1]

    # Stepping to the next window and stepping to the next element inside a
    # window both advance by one element of the original last axis, so the
    # appended axis reuses the last axis' stride.
    shape = a.shape[:-1] + (last_axis_len - window + 1, window)
    strides = a.strides + (last_axis_stride,)

    # NOTE: the returned view is writable and its windows overlap in memory,
    # so writing through it modifies several windows (and ``a``) at once.
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
def rolling_window(da, span, dim=None, new_dim='dim_0'):
    """Add a rolling dimension to a DataArray using only a view.

    Parameters
    ----------
    da : DataArray
        Array to roll over.
    span : int
        Window length, forwarded to ``rolling_window_numpy`` as ``window``.
    dim : hashable, optional
        Name of the dimension to roll along. Defaults to the last dimension
        of ``da``.
    new_dim : str, optional
        Name of the dimension that indexes the elements inside each window.

    Returns
    -------
    DataArray
        ``da`` with its original dimension order preserved and ``new_dim``
        appended as the trailing dimension.

    Raises
    ------
    ValueError
        If ``dim`` is given but is not a dimension of ``da``.
    """
    original_dims = da.dims

    if dim is None:
        # The numpy helper always rolls along the trailing axis, so the last
        # dimension is the natural default. Previously the ``None`` default
        # was appended to the transpose order and the call always failed.
        dim = original_dims[-1]
    elif dim not in original_dims:
        raise ValueError(
            'dim {!r} is not a dimension of the array; expected one of {!r}'
            .format(dim, original_dims))

    # Move ``dim`` to the end so it lines up with the axis the helper rolls.
    leading_dims = tuple(d for d in original_dims if d != dim)
    reordered = da.transpose(*(leading_dims + (dim,)))

    # NOTE(review): the helper shortens ``dim`` (n -> n - span + 1) while it
    # is treated as a broadcast dimension here; confirm that the installed
    # xarray version accepts a size change on a non-core dimension.
    result = apply_ufunc(
        rolling_window_numpy,
        reordered,
        output_core_dims=((new_dim,),),
        kwargs={'window': span})

    # Restore the caller's dimension order, with the window axis last.
    return result.transpose(*(original_dims + (new_dim,)))
# tests
import numpy as np
import pandas as pd
import pytest
import xarray as xr
@pytest.fixture
def da(dims):
    """Random 5 x 10 x 15 DataArray over dims a, b, c, reordered to ``dims``."""
    values = np.random.rand(5, 10, 15)
    array = xr.DataArray(values, dims=['a', 'b', 'c'])
    return array.transpose(*dims)
@pytest.fixture(params=[list(order) for order in ('abc', 'bac', 'cab')])
def dims(request):
    """Dimension ordering for the current parametrized run."""
    order = request.param
    return order
def test_iterate_imputation_fills_missing(sample_data):
    """Imputation keeps the input's shape and leaves no missing values.

    A NaN is planted at positional row 2, column 2 of ``sample_data`` before
    calling ``iterate_imputation``; the result must have the same shape as
    the input and contain no nulls.
    """
    # ``pd.np`` was a deprecated alias that newer pandas releases no longer
    # provide; use numpy's NaN directly.
    sample_data.iloc[2, 2] = np.nan
    result = iterate_imputation(sample_data)
    assert result.shape == sample_data.shape
    assert result.notnull().values.all()
def test_rolling_window(da, dims):
    """Rolling over 'c' with span 3 yields the expected shape without copying."""
    windowed = rolling_window(da, 3, dim='c', new_dim='x')

    canonical = windowed.transpose('a', 'b', 'c', 'x')
    assert canonical.shape == (5, 10, 13, 3)

    # A view reuses the existing buffer, so its largest stride must not
    # exceed 10 * 15 * 8 bytes.
    largest_stride = np.max(windowed.values.strides)
    assert largest_stride == 10 * 15 * 8
def test_rolling_window_values():
    """The last window of the second item sums the final three values."""
    source = xr.DataArray(np.arange(12).reshape(2, 6), dims=('item', 'date'))
    windows = rolling_window(source, 3, dim='date', new_dim='rolling_date')

    # Collapse each window to its sum, then pick the final window of item 1.
    window_sums = windows.sum('rolling_date')
    last_sum = window_sums.isel(item=1, date=-1)

    assert last_sum == sum([11, 10, 9])
Metadata
Metadata
Assignees
Labels
No labels