diff --git a/xray/core/groupby.py b/xray/core/groupby.py index d650f313059..26fae1a4f54 100644 --- a/xray/core/groupby.py +++ b/xray/core/groupby.py @@ -242,11 +242,17 @@ def apply(self, func, shortcut=False, **kwargs): applied : DataArray The result of splitting, applying and combining this array. """ + def maybe_wrap_array(arr, f): + # in case func lost array's metadata + if isinstance(f, np.ndarray) and f.shape == arr.shape: + return arr.__array_wrap__(f) + else: + return f if shortcut: grouped = self._iter_grouped_shortcut() else: grouped = self._iter_grouped() - applied = (func(arr, **kwargs) for arr in grouped) + applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped) return self._concat(applied, shortcut=shortcut) def _concat(self, applied, shortcut=False): diff --git a/xray/test/test_dataarray.py b/xray/test/test_dataarray.py index ce7932c8fbe..6c8b4413359 100644 --- a/xray/test/test_dataarray.py +++ b/xray/test/test_dataarray.py @@ -862,6 +862,26 @@ def center(x): expected_centered = expected_ds['foo'] self.assertDataArrayAllClose(expected_centered, grouped.apply(center)) + def test_groupby_apply_ndarray(self): + # regression test for #326 + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + actual = grouped.apply(np.asarray) + self.assertDataArrayEqual(array, actual) + + def test_groupby_apply_changes_metadata(self): + def change_metadata(x): + x.coords['x'] = x.coords['x'] * 2 + x.attrs['fruit'] = 'lemon' + return x + + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + actual = grouped.apply(change_metadata) + expected = array.copy() + expected = change_metadata(expected) + self.assertDataArrayEqual(expected, actual) + def test_groupby_math(self): array = self.make_groupby_example_array() for squeeze in [True, False]: