Skip to content

Commit e828350

Browse files
committed
Merge pull request #329 from shoyer/apply-array-wrap
Dataset.apply works if func returns like-shaped ndarrays
2 parents 3309f62 + 602a210 commit e828350

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

xray/core/dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .alignment import align, partial_align
2121
from .coordinates import DatasetCoordinates, Indexes
2222
from .common import ImplementsDatasetReduce, AttrAccessMixin
23-
from .utils import Frozen, SortedKeysDict, ChainMap
23+
from .utils import Frozen, SortedKeysDict, ChainMap, maybe_wrap_array
2424
from .pycompat import iteritems, itervalues, basestring, OrderedDict
2525

2626

@@ -1452,8 +1452,9 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs):
14521452
Coordinates which are no longer used as the dimension of a
14531453
noncoordinate are dropped.
14541454
"""
1455-
variables = OrderedDict((k, func(v, *args, **kwargs))
1456-
for k, v in iteritems(self.data_vars))
1455+
variables = OrderedDict(
1456+
(k, maybe_wrap_array(v, func(v, *args, **kwargs)))
1457+
for k, v in iteritems(self.data_vars))
14571458
attrs = self.attrs if keep_attrs else None
14581459
return type(self)(variables, attrs=attrs)
14591460

xray/core/groupby.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .alignment import concat
77
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
88
from .pycompat import zip
9-
from .utils import peek_at
9+
from .utils import peek_at, maybe_wrap_array
1010
from .variable import Variable, Coordinate
1111

1212

@@ -242,12 +242,6 @@ def apply(self, func, shortcut=False, **kwargs):
242242
applied : DataArray
243243
The result of splitting, applying and combining this array.
244244
"""
245-
def maybe_wrap_array(arr, f):
246-
# in case func lost array's metadata
247-
if isinstance(f, np.ndarray) and f.shape == arr.shape:
248-
return arr.__array_wrap__(f)
249-
else:
250-
return f
251245
if shortcut:
252246
grouped = self._iter_grouped_shortcut()
253247
else:

xray/core/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ def safe_cast_to_index(array):
9696
return index
9797

9898

99+
def maybe_wrap_array(original, new_array):
100+
"""Wrap a transformed array with __array_wrap__ is it can be done safely.
101+
102+
This lets us treat arbitrary functions that take and return ndarray objects
103+
like ufuncs, as long as they return an array with the same shape.
104+
"""
105+
# in case func lost array's metadata
106+
if isinstance(new_array, np.ndarray) and new_array.shape == original.shape:
107+
return original.__array_wrap__(new_array)
108+
else:
109+
return new_array
110+
111+
99112
def equivalent(first, second):
100113
"""Compare two objects for equivalence (identity or equality), using
101114
array_equiv if either object is an ndarray

xray/test/test_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,10 @@ def scale(x, multiple=1):
15141514
self.assertDataArrayEqual(actual['var1'], 2 * data['var1'])
15151515
self.assertDataArrayIdentical(actual['numbers'], data['numbers'])
15161516

1517+
actual = data.apply(np.asarray)
1518+
expected = data.drop_vars('time') # time is not used on a data var
1519+
self.assertDatasetEqual(expected, actual)
1520+
15171521
def make_example_math_dataset(self):
15181522
variables = OrderedDict(
15191523
[('bar', ('x', np.arange(100, 400, 100))),

0 commit comments

Comments
 (0)