Efficient rolling 'trick' #1978

Closed
@max-sixty

Based on http://www.rigtorp.se/2011/01/01/rolling-statistics-numpy.html, we wrote a function that 'tricks' numpy into presenting an array that looks rolling, but without copying each window, so it avoids the O(n × window) memory that a materialized version would need

Would people be interested in this going into xarray?

It seems to work really well on a few use-cases, but I imagine it's enough trickery that we might not want to support it in xarray.
And, to be clear, it's strictly worse than a dedicated rolling algo where we have one. But where we don't, you get a rolling apply without the Python loops.
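To illustrate what "a rolling apply without the Python loops" means, here is a minimal sketch that uses numpy only. It swaps in `np.lib.stride_tricks.sliding_window_view` (available in NumPy 1.20 and later) as a stand-in for the helper proposed in this issue, and the sample values and the choice of a median are assumptions made purely for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

# Illustrative data; any 1-D float array would do.
x = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
window = 3

# Python-loop version: one slice and one reduction per window.
looped = np.array(
    [np.median(x[i:i + window]) for i in range(len(x) - window + 1)]
)

# View-based version: a read-only (6, 3) view onto x, reduced in one call.
windows = sliding_window_view(x, window)
vectorized = np.median(windows, axis=-1)

print(np.array_equal(looped, vectorized))
```

The windowed array itself is only a view onto `x`. The reduction may still allocate temporaries internally, so the saving is in constructing the windows rather than in every step of the computation.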

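import numpy as np


def rolling_window_numpy(a, window):
    """ Make an array appear to be rolling, but using only a view
    http://www.rigtorp.se/2011/01/01/rolling-statistics-numpy.html
    """
    # the last axis shrinks to the number of complete windows, and a new
    # trailing axis of length `window` is added
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    # the new axis reuses the stride of the last axis, so no data is copied
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)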


import xarray as xr


def rolling_window(da, span, dim, new_dim='dim_0'):
    """ Adds a rolling dimension to a DataArray using only a view """
    original_dims = da.dims

    result = xr.apply_ufunc(
        rolling_window_numpy,
        da,
        # apply_ufunc moves `dim` to the last axis before calling the function
        input_core_dims=[[dim]],
        output_core_dims=[[dim, new_dim]],
        # `dim` shrinks by span - 1; recent xarray only allows that for excluded dims
        exclude_dims={dim},
        kwargs=dict(window=span))

    return result.transpose(*(original_dims + (new_dim,)))

# tests

import numpy as np
import pandas as pd
import pytest
import xarray as xr


@pytest.fixture
def da(dims):
    return xr.DataArray(
        np.random.rand(5, 10, 15), dims=(list('abc'))).transpose(*dims)


@pytest.fixture(params=[
    list('abc'),
    list('bac'),
    list('cab'),
])
def dims(request):
    return request.param


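# unrelated to rolling_window: `iterate_imputation` and the `sample_data`
# fixture are not defined anywhere in this issue
def test_iterate_imputation_fills_missing(sample_data):
    sample_data.iloc[2, 2] = np.nan  # `pd.np` was removed in pandas 2.0
    result = iterate_imputation(sample_data)
    assert result.shape == sample_data.shape
    assert result.notnull().values.all()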

def test_rolling_window(da, dims):

    result = rolling_window(da, 3, dim='c', new_dim='x')

    assert result.transpose(*list('abcx')).shape == (5, 10, 13, 3)

    # should be a view, so the largest stride is still that of the original
    # (5, 10, 15) float64 array: 10 * 15 elements * 8 bytes
    assert np.max(result.values.strides) == 10 * 15 * 8


def test_rolling_window_values():

    da = xr.DataArray(np.arange(12).reshape(2, 6), dims=('item', 'date'))

    rolling = rolling_window(da, 3, dim='date', new_dim='rolling_date')

    expected = sum([11, 10, 9])
    result = rolling.sum('rolling_date').isel(item=1, date=-1)
    assert result == expected
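The stride assertion in `test_rolling_window` can be checked by hand. The sketch below uses numpy only and rebuilds the same shape and strides arithmetic as `rolling_window_numpy` on a C-contiguous `(5, 10, 15)` float64 array. The `writeable=False` flag and the comparison against an explicit copy are additions for illustration, not part of the proposal.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.zeros((5, 10, 15), dtype=np.float64)  # C-contiguous, 8-byte items
window = 3

# Same shape/strides arithmetic as rolling_window_numpy in this issue.
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
# writeable=False guards against writing through overlapping windows.
view = as_strided(a, shape=shape, strides=strides, writeable=False)

print(a.strides)                  # (1200, 120, 8)
print(view.shape)                 # (5, 10, 13, 3)
print(view.strides)               # (1200, 120, 8, 8)
print(np.shares_memory(a, view))  # True

# Materializing the windows needs a larger buffer with larger strides.
copy = np.ascontiguousarray(view)
print(a.nbytes, copy.nbytes)      # 6000 15600
print(copy.strides)               # (3120, 312, 24, 8)
```

The largest stride of the view stays at `10 * 15 * 8` bytes because it still walks the original buffer, while a copy of the same shape needs `10 * 13 * 3 * 8` bytes per step along the first axis.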
