diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 01c2ce4d..739fad8e 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -15,7 +15,9 @@ Optional, Set, Tuple, + TypeVar, Union, + cast, ) import xarray as xr @@ -145,6 +147,9 @@ # Type for Mapper functions Mapper = Callable[[Union[DataArray, Dataset], str], List[str]] +# Type for decorators +F = TypeVar("F", bound=Callable[..., Any]) + def apply_mapper( mappers: Union[Mapper, Tuple[Mapper, ...]], @@ -168,7 +173,7 @@ def _apply_single_mapper(mapper): try: results = mapper(obj, key) except KeyError as e: - if error: + if error or "I expected only one." in repr(e): raise e else: results = [] @@ -203,19 +208,10 @@ def _apply_single_mapper(mapper): return results -def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str]: - """ Helper method for when we really want only one result per key. """ - results = _get_axis_coord(var, key) - if len(results) > 1: - raise KeyError( - f"Multiple results for {key!r} found: {results!r}. I expected only one." - ) - elif len(results) == 0: - raise KeyError(f"No results found for {key!r}.") - return results - - def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]: + """ + Time variable accessor e.g. 'T.month' + """ """ Helper method for when our key name is of the nature "T.month" and we want to isolate the "T" for coordinate mapping @@ -235,12 +231,11 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List ----- Returns an empty list if there is no frequency extension specified. """ + if "." in key: key, ext = key.split(".", 1) - results = apply_mapper( - (_get_axis_coord, _get_with_standard_name), var, key, error=False - ) + results = apply_mapper((_get_all,), var, key, error=False) if len(results) > 1: raise KeyError(f"Multiple results received for {key}.") return [v + "." + ext for v in results] @@ -313,16 +308,6 @@ def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) -def _get_measure_variable( - da: DataArray, key: str, error: bool = True, default: str = None -) -> List[DataArray]: - """ tiny wrapper since xarray does not support providing str for weights.""" - varnames = apply_mapper(_get_measure, da, key, error, default) - if len(varnames) > 1: - raise KeyError(f"Multiple measures found for key {key!r}: {varnames!r}.") - return [da[varnames[0]]] - - def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ Translate from cell measures to appropriate variable name. @@ -361,6 +346,9 @@ def _get_with_standard_name( obj: Union[DataArray, Dataset], name: Union[str, List[str]] ) -> List[str]: """ returns a list of variable names with standard name == name. """ + if name is None: + return [] + varnames = [] if isinstance(obj, DataArray): obj = obj._to_temp_dataset() @@ -372,36 +360,91 @@ def _get_with_standard_name( return varnames +def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitrary measures, or standard names + """ + all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name) + results = apply_mapper(all_mappers, obj, key, error=False, default=None) + return results + + +def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitrary measures, or standard names present in .dims + """ + return [k for k in _get_all(obj, key) if k in obj.dims] + + +def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitrary measures, or standard names present in .indexes + """ + return [k for k in _get_all(obj, key) if k in obj.indexes] + + +def _get_coords(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitrary measures, or standard names present in .coords + """ + return [k for k in _get_all(obj, key) if k in obj.coords] + + +def _variables(func: F) -> F: + @functools.wraps(func) + def wrapper(obj: Union[DataArray, Dataset], key: str) -> List[DataArray]: + return [obj[k] for k in func(obj, key)] + + return cast(F, wrapper) + + +def _single(func: F) -> F: + @functools.wraps(func) + def wrapper(obj: Union[DataArray, Dataset], key: str): + results = func(obj, key) + if len(results) > 1: + raise KeyError( + f"Multiple results for {key!r} found: {results!r}. I expected only one." + ) + elif len(results) == 0: + raise KeyError(f"No results found for {key!r}.") + return results + + wrapper.__doc__ = ( + func.__doc__.replace("One or more of", "One of") + if func.__doc__ + else func.__doc__ + ) + + return cast(F, wrapper) + + #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_axis_coord, _get_with_standard_name), - "dims": (_get_axis_coord, _get_with_standard_name), # transpose - "drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims - "dimensions": (_get_axis_coord, _get_with_standard_name), # stack - "dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims - "shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll - "pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll - "names": ( - _get_axis_coord, - _get_measure, - _get_with_standard_name, - ), # set_coords, reset_coords, drop_vars - "labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop - "coords": (_get_axis_coord, _get_with_standard_name), # interp - "indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex - # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index - "window": (_get_axis_coord, _get_with_standard_name), # rolling_exp - "coord": (_get_axis_coord_single,), # differentiate, integrate - "group": ( - _get_axis_coord_single, - _get_groupby_time_accessor, - _get_with_standard_name, - ), - "indexer": (_get_axis_coord_single,), # resample - "variables": (_get_axis_coord, _get_with_standard_name), # sortby - "weights": (_get_measure_variable,), # type: ignore - "chunks": (_get_axis_coord, _get_with_standard_name), # chunk + "dim": (_get_dims,), + "dims": (_get_dims,), # transpose + "drop_dims": (_get_dims,), # drop_dims + "dimensions": (_get_dims,), # stack + "dims_dict": (_get_dims,), # swap_dims, rename_dims + "shifts": (_get_dims,), # shift, roll + "pad_width": (_get_dims,), # shift, roll + "names": (_get_all,), # set_coords, reset_coords, drop_vars + "labels": (_get_indexes,), # drop_sel + "coords": (_get_dims,), # interp + "indexers": (_get_dims,), # sel, isel, reindex + # "indexes": (_single(_get_dims),), # set_index this decodes keys but not values + "dims_or_levels": (_get_dims,), # reset_index + "window": (_get_dims,), # rolling_exp + "coord": (_single(_get_coords),), # differentiate, integrate + "group": (_single(_get_all), _get_groupby_time_accessor), # groupby + "indexer": (_single(_get_indexes),), # resample + "variables": (_get_all,), # sortby + "weights": (_variables(_single(_get_all)),), # type: ignore + "chunks": (_get_dims,), # chunk } @@ -430,28 +473,19 @@ def _build_docstring(func): can be used for arguments. """ - # this list will need to be updated any time a new mapper is added - mapper_docstrings = { - _get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}", - _get_axis_coord_single: f"One of {(_AXIS_NAMES + _COORD_NAMES)!r}", - _get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'", - _get_with_standard_name: "Standard names", - _get_measure_variable: f"One of {_CELL_MEASURES!r}", - } - sig = inspect.signature(func) string = "" for k in set(sig.parameters.keys()) & set(_DEFAULT_KEY_MAPPERS): mappers = _DEFAULT_KEY_MAPPERS.get(k, []) docstring = ";\n\t\t\t".join( - mapper_docstrings.get(mapper, "unknown. please open an issue.") + mapper.__doc__ if mapper.__doc__ else "unknown. please open an issue." for mapper in mappers ) string += f"\t\t{k}: {docstring} \n" for param in sig.parameters: if sig.parameters[param].kind is inspect.Parameter.VAR_KEYWORD: - string += f"\t\t{param}: {mapper_docstrings[_get_axis_coord]} \n\n" + string += f"\t\t{param}: {_get_all.__doc__} \n\n" return ( f"\n\tThe following arguments will be processed by cf_xarray: \n{string}" "\n\t----\n\t" @@ -583,12 +617,12 @@ def check_results(names, k): successful = dict.fromkeys(key, False) for k in key: if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES: - names = _get_axis_coord(obj, k) + names = _get_all(obj, k) check_results(names, k) successful[k] = bool(names) coords.extend(names) elif "measures" not in skip and k in accessor._get_all_cell_measures(): - measure = _get_measure(obj, k) + measure = _get_all(obj, k) check_results(measure, k) successful[k] = bool(measure) if measure: @@ -738,9 +772,9 @@ def __init__(self, obj, accessor): def _plot_decorator(self, func): """ This decorator is used to set default kwargs on plotting functions. - - For now, this is setting ``xincrease`` and ``yincrease``. It could set - other arguments in the future. + For now, this can + 1. set ``xincrease`` and ``yincrease``. + 2. automatically set ``x`` or ``y``. """ valid_keys = self.accessor.keys() @@ -764,7 +798,8 @@ def _process_x_or_y(kwargs, key): return kwargs is_line_plot = (func.__name__ == "line") or ( - func.__name__ == "wrapper" and kwargs.get("hue") + func.__name__ == "wrapper" + and (kwargs.get("hue") or self._obj.ndim == 1) ) if is_line_plot: if not kwargs.get("hue"): @@ -787,7 +822,7 @@ def __call__(self, *args, **kwargs): obj=self._obj, attr="plot", accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), ) return self._plot_decorator(plot)(*args, **kwargs) @@ -799,7 +834,7 @@ def __getattr__(self, attr): obj=self._obj.plot, attr=attr, accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), # TODO: "extra_decorator" is more complex than I would like it to be. # Not sure if there is a better way though extra_decorator=self._plot_decorator, @@ -907,7 +942,7 @@ def _rewrite_values( # these are valid for .sel, .isel, .coarsen all_mappers = ChainMap( key_mappers, - dict.fromkeys(var_kws, (_get_axis_coord, _get_with_standard_name)), + dict.fromkeys(var_kws, (_get_all,)), ) for key in set(all_mappers) & set(kwargs): @@ -961,8 +996,6 @@ def _rewrite_values( for vkw in var_kws: if vkw in kwargs: maybe_update = { - # TODO: this is assuming key_mappers[k] is always - # _get_axis_coord_single k: apply_mapper( key_mappers[k], self._obj, v, error=False, default=[v] )[0] @@ -1105,17 +1138,14 @@ def axes(self) -> Dict[str, List[str]]: This is useful for checking whether a key is valid for indexing, i.e. that the attributes necessary to allow indexing by that key exist. - However, it will only return the Axis names, not Coordinate names. + However, it will only return the Axis names present in ``.coords``, not Coordinate names. Returns ------- Dictionary of valid Axis names that can be used with ``__getitem__`` or ``.cf[key]``. Will be ("X", "Y", "Z", "T") or a subset thereof. """ - vardict = { - key: apply_mapper(_get_axis_coord, self._obj, key, error=False) - for key in _AXIS_NAMES - } + vardict = {key: _get_coords(self._obj, key) for key in _AXIS_NAMES} return {k: sorted(v) for k, v in vardict.items() if v} @@ -1127,17 +1157,14 @@ def coordinates(self) -> Dict[str, List[str]]: This is useful for checking whether a key is valid for indexing, i.e. that the attributes necessary to allow indexing by that key exist. - However, it will only return the Coordinate names, not Axis names. + However, it will only return the Coordinate names present in ``.coords``, not Axis names. Returns ------- Dictionary of valid Coordinate names that can be used with ``__getitem__`` or ``.cf[key]``. Will be ("longitude", "latitude", "vertical", "time") or a subset thereof. """ - vardict = { - key: apply_mapper(_get_axis_coord, self._obj, key, error=False) - for key in _COORD_NAMES - } + vardict = {key: _get_coords(self._obj, key) for key in _COORD_NAMES} return {k: sorted(v) for k, v in vardict.items() if v} @@ -1164,10 +1191,10 @@ def cell_measures(self) -> Dict[str, List[str]]: da.attrs.get("cell_measures", "") for da in obj.data_vars.values() ] - measures: Dict[str, List[str]] = {} + keys = {} for attr in all_attrs: - for key, value in parse_cell_methods_attr(attr).items(): - measures[key] = measures.setdefault(key, []) + [value] + keys.update(parse_cell_methods_attr(attr)) + measures = {key: _get_all(self._obj, key) for key in keys} return {k: sorted(set(v)) for k, v in measures.items() if v} @@ -1310,16 +1337,11 @@ def rename_like( ourkeys = self.keys() theirkeys = other.cf.keys() - good_keys = set(_COORD_NAMES) & ourkeys & theirkeys - if not good_keys: - raise ValueError( - "No common coordinate variables between these two objects." - ) - + good_keys = ourkeys & theirkeys renamer = {} for key in good_keys: - ours = _get_axis_coord_single(self._obj, key)[0] - theirs = _get_axis_coord_single(other, key)[0] + ours = _single(_get_all)(self._obj, key)[0] + theirs = _single(_get_all)(other, key)[0] renamer[ours] = theirs newobj = self._obj.rename(renamer) @@ -1374,6 +1396,12 @@ def guess_coord_axis(self, verbose: bool = False) -> Union[DataArray, Dataset]: obj[var].attrs = dict(ChainMap(obj[var].attrs, ATTRS[axis])) return obj + def drop(self, *args, **kwargs): + raise NotImplementedError( + "cf-xarray does not support .drop." + "Please use .cf.drop_vars or .cf.drop_sel as appropriate." + ) + @xr.register_dataset_accessor("cf") class CFDatasetAccessor(CFAccessor): @@ -1422,7 +1450,7 @@ def get_bounds(self, key: str) -> DataArray: DataArray """ name = apply_mapper( - _get_axis_coord_single, self._obj, key, error=False, default=[key] + _single(_get_all), self._obj, key, error=False, default=[key] )[0] bounds = self._obj[name].attrs["bounds"] obj = self._maybe_to_dataset() @@ -1599,7 +1627,7 @@ def decode_vertical_coords(self, prefix="z"): import re ds = self._obj - dims = _get_axis_coord(ds, "Z") + dims = _get_dims(ds, "Z") requirements = { "ocean_s_coordinate_g1": {"depth_c", "depth", "s", "C", "eta"}, diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index fea548bb..de593471 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -121,7 +121,7 @@ romsds["temp"] = ( ("ocean_time", "s_rho"), [np.linspace(20, 30, 30)] * 2, - {"coordinates": "z_rho_dummy"}, + {"coordinates": "z_rho_dummy", "standard_name": "sea_water_potential_temperature"}, ) romsds["temp"].encoding["coordinates"] = "s_rho" romsds.coords["z_rho_dummy"] = ( diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 2682ccf9..e54e2199 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -215,10 +215,13 @@ def test_getitem_ancillary_variables(): def test_rename_like(): original = popds.copy(deep=True) - with pytest.raises(KeyError): - popds.cf.rename_like(airds) + # it'll match for axis: X (lon, nlon) and coordinate="longitude" (lon, TLONG) + # so delete the axis attributes + newair = airds.copy(deep=True) + del newair.lon.attrs["axis"] + del newair.lat.attrs["axis"] - renamed = popds.cf["TEMP"].cf.rename_like(airds) + renamed = popds.cf["TEMP"].cf.rename_like(newair) for k in ["TLONG", "TLAT"]: assert k not in renamed.coords assert k in original.coords @@ -228,6 +231,20 @@ def test_rename_like(): assert "lat" in renamed.coords assert renamed.attrs["coordinates"] == "lon lat" + # standard name matching + newroms = romsds.expand_dims(latitude=[1], longitude=[1]).cf.guess_coord_axis() + renamed = popds.cf["UVEL"].cf.rename_like(newroms) + assert renamed.attrs["coordinates"] == "longitude latitude" + assert "longitude" in renamed.coords + assert "latitude" in renamed.coords + assert "ULON" not in renamed.coords + assert "ULAT" not in renamed.coords + + # should change "temp" to "TEMP" + renamed = romsds.cf.rename_like(popds) + assert "temp" not in renamed + assert "TEMP" in renamed + @pytest.mark.parametrize("obj", objects) @pytest.mark.parametrize( @@ -413,6 +430,10 @@ def test_dataarray_plot(): np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data) plt.close() + rv = obj.isel(time=0, lon=1).cf.plot(x="lat") + np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data) + plt.close() + # various line plots and automatic guessing rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line() np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data) @@ -625,7 +646,25 @@ def test_get_bounds_dim_name(): def test_docstring(): assert "One of ('X'" in airds.cf.groupby.__doc__ + assert "Time variable accessor e.g. 'T.month'" in airds.cf.groupby.__doc__ assert "One or more of ('X'" in airds.cf.mean.__doc__ + assert "present in .dims" in airds.cf.drop_dims.__doc__ + assert "present in .coords" in airds.cf.integrate.__doc__ + assert "present in .indexes" in airds.cf.resample.__doc__ + + # Make sure docs are up to date + get_all_doc = cf_xarray.accessor._get_all.__doc__ + all_keys = ( + cf_xarray.accessor._AXIS_NAMES + + cf_xarray.accessor._COORD_NAMES + + cf_xarray.accessor._CELL_MEASURES + ) + expected = f"One or more of {all_keys!r}, or arbitrary measures, or standard names" + assert get_all_doc.split() == expected.split() + for name in ["dims", "indexes", "coords"]: + actual = getattr(cf_xarray.accessor, f"_get_{name}").__doc__ + expected = get_all_doc + f" present in .{name}" + assert actual.split() == expected.split() def _make_names(prefixes): @@ -847,10 +886,11 @@ def test_standard_name_mapper(): expected = da.sortby("label") assert_identical(actual, expected) + assert cf_xarray.accessor._get_with_standard_name(da, None) == [] + @pytest.mark.parametrize("obj", objects) -@pytest.mark.parametrize("attr", ["drop", "drop_vars", "set_coords"]) -@pytest.mark.filterwarnings("ignore:dropping .* using `drop` .* deprecated") +@pytest.mark.parametrize("attr", ["drop_vars", "set_coords"]) def test_drop_vars_and_set_coords(obj, attr): # DataArray object has no attribute set_coords @@ -893,11 +933,30 @@ def test_drop_sel_and_reset_coords(obj): @pytest.mark.parametrize("ds", datasets) def test_drop_dims(ds): + # Add data_var and coord to test _get_dims + ds["lon_var"] = ds["lon"] + ds = ds.assign_coords(lon_coord=ds["lon"]) + # Axis and coordinate for cf_name in ["X", "longitude"]: assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name)) +@pytest.mark.parametrize("ds", datasets) +def test_differentiate(ds): + + # Add data_var and coord to test _get_coords + ds["lon_var"] = ds["lon"] + ds = ds.assign_coords(lon_coord=ds["lon"]) + + # Coordinate + assert_identical(ds.differentiate("lon"), ds.cf.differentiate("lon")) + + # Multiple coords (test error raised by _single) + with pytest.raises(KeyError, match=".*I expected only one."): + assert_identical(ds.differentiate("lon"), ds.cf.differentiate("X")) + + def test_new_standard_name_mappers(): assert_identical(forecast.cf.mean("realization"), forecast.mean("M")) assert_identical( diff --git a/doc/examples/introduction.ipynb b/doc/examples/introduction.ipynb index 3a7c1c94..87d0f3bb 100644 --- a/doc/examples/introduction.ipynb +++ b/doc/examples/introduction.ipynb @@ -982,7 +982,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.9.1" }, "toc": { "base_numbering": 1, diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e6be6541..787d51fb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -9,7 +9,7 @@ v0.4.1 (unreleased) - Replace ``cf.describe()`` with :py:meth:`Dataset.cf.__repr__`. By `Mattia Almansi`_. - Automatically set ``x`` or ``y`` for :py:attr:`DataArray.cf.plot`. By `Deepak Cherian`_. - Added scripts to document :ref:`criteria` with tables. By `Mattia Almansi`_. -- Support for ``.drop()``, ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_. +- Support for ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_. - Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_ - Allow :py:meth:`DataArray.cf.__getitem__` with standard names. By `Deepak Cherian`_ - Rewrite the ``values`` of :py:attr:`Dataset.coords` and :py:attr:`Dataset.data_vars` with objects returned