diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5ee67efb1da..a0e5090d38a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,8 @@ New Features - :py:func:`open_dataset` and :py:func:`open_mfdataset` now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`). By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_. +- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`). + By `Deepak Cherian <https://github.com/dcherian>`_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 571dfbe70ed..8eba0fe7919 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -3,7 +3,7 @@ import numpy as np -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .pycompat import dask_array_type from .utils import not_implemented @@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): dataset_fill_value=np.nan, kwargs=kwargs, dask="allowed", + keep_attrs=_get_keep_attrs(default=True), ) # this has no runtime function - these are listed so IDEs know these diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 47d78d93ce4..7b62c2c705f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -42,6 +42,14 @@ _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) +def _first_of_type(args, kind): + """ Return either first object of type 'kind' or raise if not found. """ + for arg in args: + if isinstance(arg, kind): + return arg + raise ValueError("This should be unreachable.") + + class _UFuncSignature: """Core dimensions signature for a given function. 
@@ -252,8 +260,9 @@ def apply_dataarray_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - if keep_attrs and hasattr(args[0], "name"): - name = args[0].name + if keep_attrs: + first_obj = _first_of_type(args, DataArray) + name = first_obj.name else: name = result_name(args) result_coords = build_output_coords(args, signature, exclude_dims) @@ -270,6 +279,14 @@ def apply_dataarray_vfunc( (coords,) = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) + if keep_attrs: + if isinstance(out, tuple): + for da in out: + # This is adding attrs in place + da._copy_attrs_from(first_obj) + else: + out._copy_attrs_from(first_obj) + return out @@ -390,8 +407,6 @@ def apply_dataset_vfunc( """ from .dataset import Dataset - first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True - if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE: raise TypeError( "to apply an operation to datasets with different " @@ -399,6 +414,9 @@ def apply_dataset_vfunc( "dataset_fill_value argument." 
) + if keep_attrs: + first_obj = _first_of_type(args, Dataset) + if len(args) > 1: args = deep_align( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False @@ -417,9 +435,11 @@ def apply_dataset_vfunc( (coord_vars,) = list_of_coords out = _fast_dataset(result_vars, coord_vars) - if keep_attrs and isinstance(first_obj, Dataset): + if keep_attrs: if isinstance(out, tuple): - out = tuple(ds._copy_attrs_from(first_obj) for ds in out) + for ds in out: + # This is adding attrs in place + ds._copy_attrs_from(first_obj) else: out._copy_attrs_from(first_obj) return out @@ -595,6 +615,8 @@ def apply_variable_ufunc( """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data + first_obj = _first_of_type(args, Variable) + dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims ) @@ -734,8 +756,8 @@ def func(*arrays): ) ) - if keep_attrs and isinstance(args[0], Variable): - var.attrs.update(args[0].attrs) + if keep_attrs: + var.attrs.update(first_obj.attrs) output.append(var) if signature.num_outputs == 1: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 94b7f702920..1577f63cbd1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -55,7 +55,7 @@ from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import ( IndexVariable, @@ -2734,13 +2734,19 @@ def __rmatmul__(self, other): def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]: @functools.wraps(f) def func(self, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) with 
warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning ) with np.errstate(all="ignore"): - return self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + da = self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + if keep_attrs: + da.attrs = self.attrs + return da return func diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1777ee356af..a23465b3141 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4394,12 +4394,15 @@ def map( foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773 bar (x) float64 1.0 2.0 """ + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) variables = { k: maybe_wrap_array(v, func(v, *args, **kwargs)) for k, v in self.data_vars.items() } - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for k, v in variables.items(): + v._copy_attrs_from(self.data_vars[k]) attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) @@ -4930,15 +4933,20 @@ def from_dict(cls, d): return obj @staticmethod - def _unary_op(f, keep_attrs=False): + def _unary_op(f): @functools.wraps(f) def func(self, *args, **kwargs): variables = {} + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) for k, v in self._variables.items(): if k in self._coord_names: variables[k] = v else: variables[k] = f(v, *args, **kwargs) + if keep_attrs: + variables[k].attrs = v._attrs attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, attrs=attrs) @@ -5677,11 +5685,11 @@ def _integrate_one(self, coord, datetime_unit=None): @property def real(self): - return self._unary_op(lambda x: x.real, keep_attrs=True)(self) + return self.map(lambda x: x.real, keep_attrs=True) @property def imag(self): - return self._unary_op(lambda x: x.imag, 
keep_attrs=True)(self) + return self.map(lambda x: x.imag, keep_attrs=True) plot = utils.UncachedAccessor(_Dataset_PlotMethods) diff --git a/xarray/core/options.py b/xarray/core/options.py index 5a78aa10b90..a14473c9b97 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -71,7 +71,7 @@ def _get_keep_attrs(default): return global_choice else: raise ValueError( - "The global option keep_attrs must be one of" " True, False or 'default'." + "The global option keep_attrs must be one of True, False or 'default'." ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c55e61cb816..97d299c1db8 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2111,8 +2111,14 @@ def __array_wrap__(self, obj, context=None): def _unary_op(f): @functools.wraps(f) def func(self, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) with np.errstate(all="ignore"): - return self.__array_wrap__(f(self.data, *args, **kwargs)) + result = self.__array_wrap__(f(self.data, *args, **kwargs)) + if keep_attrs: + result.attrs = self.attrs + return result return func diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5e0fe13ea52..ba424170349 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -9,7 +9,15 @@ import pytest import xarray as xr -from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast +from xarray import ( + DataArray, + Dataset, + IndexVariable, + Variable, + align, + broadcast, + set_options, +) from xarray.coding.times import CFDatetimeCoder from xarray.convert import from_cdms2 from xarray.core import dtypes @@ -2486,6 +2494,21 @@ def test_assign_attrs(self): assert_identical(new_actual, expected) assert actual.attrs == {"a": 1, "b": 2} + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def 
test_propagate_attrs(self, func): + da = DataArray(self.va) + + # test defaults + assert func(da).attrs == da.attrs + + with set_options(keep_attrs=False): + assert func(da).attrs == {} + + with set_options(keep_attrs=True): + assert func(da).attrs == da.attrs + def test_fillna(self): a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 40e2bdfc6de..f16ce21087d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4473,6 +4473,28 @@ def test_fillna(self): assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def test_propagate_attrs(self, func): + + da = DataArray(range(5), name="a", attrs={"attr": "da"}) + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + + # test defaults + assert func(ds).attrs == ds.attrs + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=True): + assert func(ds).attrs == ds.attrs + assert func(ds).a.attrs == ds.a.attrs + def test_where(self): ds = Dataset({"a": ("x", range(5))}) expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index efebe09e2ec..ac821ff111d 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -329,7 +329,8 @@ def test_1d_math(self): assert_array_equal(y - v, 1 - v) # verify attributes are dropped v2 = self.cls(["x"], x, {"units": "meters"}) - assert_identical(base_v, +v2) + with set_options(keep_attrs=False): + assert_identical(base_v, +v2) # binary ops with all variables assert_array_equal(v + v, 2 * v) w = self.cls(["x"], y, {"foo": "bar"}) 
diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 2f582b89bf2..48fad296664 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -320,7 +320,6 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs): assert not result.attrs -@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595") @pytest.mark.parametrize("operation", ("sum", "mean")) def test_weighted_operations_keep_attr_da_in_ds(operation): # GH #3595