Skip to content

Dataset.apply works if func returns like-shaped ndarrays #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 23, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions xray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .alignment import align, partial_align
from .coordinates import DatasetCoordinates, Indexes
from .common import ImplementsDatasetReduce, AttrAccessMixin
from .utils import Frozen, SortedKeysDict, ChainMap
from .utils import Frozen, SortedKeysDict, ChainMap, maybe_wrap_array
from .pycompat import iteritems, itervalues, basestring, OrderedDict


Expand Down Expand Up @@ -1452,8 +1452,9 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs):
Coordinates which are no longer used as the dimension of a
noncoordinate are dropped.
"""
variables = OrderedDict((k, func(v, *args, **kwargs))
for k, v in iteritems(self.data_vars))
variables = OrderedDict(
(k, maybe_wrap_array(v, func(v, *args, **kwargs)))
for k, v in iteritems(self.data_vars))
attrs = self.attrs if keep_attrs else None
return type(self)(variables, attrs=attrs)

Expand Down
8 changes: 1 addition & 7 deletions xray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .alignment import concat
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .pycompat import zip
from .utils import peek_at
from .utils import peek_at, maybe_wrap_array
from .variable import Variable, Coordinate


Expand Down Expand Up @@ -242,12 +242,6 @@ def apply(self, func, shortcut=False, **kwargs):
applied : DataArray
The result of splitting, applying and combining this array.
"""
def maybe_wrap_array(arr, f):
# in case func lost array's metadata
if isinstance(f, np.ndarray) and f.shape == arr.shape:
return arr.__array_wrap__(f)
else:
return f
if shortcut:
grouped = self._iter_grouped_shortcut()
else:
Expand Down
13 changes: 13 additions & 0 deletions xray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ def safe_cast_to_index(array):
return index


def maybe_wrap_array(original, new_array):
"""Wrap a transformed array with __array_wrap__ is it can be done safely.

This lets us treat arbitrary functions that take and return ndarray objects
like ufuncs, as long as they return an array with the same shape.
"""
# in case func lost array's metadata
if isinstance(new_array, np.ndarray) and new_array.shape == original.shape:
return original.__array_wrap__(new_array)
else:
return new_array


def equivalent(first, second):
"""Compare two objects for equivalence (identity or equality), using
array_equiv if either object is an ndarray
Expand Down
4 changes: 4 additions & 0 deletions xray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,10 @@ def scale(x, multiple=1):
self.assertDataArrayEqual(actual['var1'], 2 * data['var1'])
self.assertDataArrayIdentical(actual['numbers'], data['numbers'])

actual = data.apply(np.asarray)
expected = data.drop_vars('time') # time is not used on a data var
self.assertDatasetEqual(expected, actual)

def make_example_math_dataset(self):
variables = OrderedDict(
[('bar', ('x', np.arange(100, 400, 100))),
Expand Down