From f00da4579e2b00c3cb1577fc6fe78aca5b87423e Mon Sep 17 00:00:00 2001 From: malmans2 Date: Thu, 11 Feb 2021 18:56:15 +0000 Subject: [PATCH 01/22] add get dims, coords, indexes --- cf_xarray/accessor.py | 96 ++++++++++++++++++++++++++++++++----------- 1 file changed, 72 insertions(+), 24 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 01c2ce4d..ced60af4 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -249,7 +249,9 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List return [] -def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: +def _get_axis_coord( + var: Union[DataArray, Dataset], key: str, error: bool = True +) -> List[str]: """ Translate from axis or coord name to variable name @@ -283,7 +285,7 @@ def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: """ valid_keys = _COORD_NAMES + _AXIS_NAMES - if key not in valid_keys: + if error and key not in valid_keys: raise KeyError( f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}" ) @@ -313,6 +315,56 @@ def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) +def _get_all(var: Union[DataArray, Dataset], key: str) -> List[str]: + + results = set(_get_axis_coord(var, key, error=False)) + for func in (_get_measure, _get_with_standard_name): + results.update(func(var, key)) + + return list(results) + + +def _get_dims(var: Union[DataArray, Dataset], key: str) -> List[str]: + results = set(_get_all(var, key)).intersection(var.dims) + return list(results) + + +def _get_single_dim(var: Union[DataArray, Dataset], key: str) -> List[str]: + return _get_single(var, key, _get_dims) + + +def _get_coords(var: Union[DataArray, Dataset], key: str) -> List[str]: + results = set(_get_all(var, key)).intersection(var.coords) + return list(results) + + +def _get_single_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: + return _get_single(var, key, _get_coords) + + +def _get_indexes(var: Union[DataArray, Dataset], key: str) -> List[str]: + results = set(_get_all(var, key)).intersection(var.indexes) + return list(results) + + +def _get_single_index(var: Union[DataArray, Dataset], key: str) -> List[str]: + return _get_single(var, key, _get_indexes) + + +def _get_single(var: Union[DataArray, Dataset], key: str, func): + + results = func(var, key) + + if len(results) > 1: + raise KeyError( + f"Multiple results for {key!r} found: {results!r}. I expected only one." + ) + elif len(results) == 0: + raise KeyError(f"No results found for {key!r}.") + + return results + + def _get_measure_variable( da: DataArray, key: str, error: bool = True, default: str = None ) -> List[DataArray]: @@ -374,34 +426,30 @@ def _get_with_standard_name( #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_axis_coord, _get_with_standard_name), - "dims": (_get_axis_coord, _get_with_standard_name), # transpose - "drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims - "dimensions": (_get_axis_coord, _get_with_standard_name), # stack - "dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims - "shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll - "pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll - "names": ( - _get_axis_coord, - _get_measure, - _get_with_standard_name, - ), # set_coords, reset_coords, drop_vars - "labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop - "coords": (_get_axis_coord, _get_with_standard_name), # interp - "indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex + "dim": (_get_dims,), + "dims": (_get_dims,), # transpose + "drop_dims": (_get_dims,), # drop_dims + "dimensions": (_get_dims,), # stack + "dims_dict": (_get_dims,), # swap_dims, rename_dims + "shifts": (_get_dims,), # shift, roll + "pad_width": (_get_dims,), # shift, roll + "names": (_get_all,), # set_coords, reset_coords, drop_vars + "labels": (_get_all,), # drop + "coords": (_get_coords,), # interp + "indexers": (_get_indexes,), # sel, isel, reindex # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index - "window": (_get_axis_coord, _get_with_standard_name), # rolling_exp - "coord": (_get_axis_coord_single,), # differentiate, integrate + "dims_or_levels": (_get_indexes,), # reset_index + "window": (_get_dims,), # rolling_exp + "coord": (_get_coords,), # differentiate, integrate "group": ( _get_axis_coord_single, _get_groupby_time_accessor, _get_with_standard_name, ), - "indexer": (_get_axis_coord_single,), # resample - "variables": (_get_axis_coord, _get_with_standard_name), # sortby + "indexer": (_get_single_index,), # resample + "variables": (_get_all,), # sortby "weights": (_get_measure_variable,), # type: ignore - "chunks": (_get_axis_coord, _get_with_standard_name), # chunk + "chunks": (_get_dims,), # chunk } @@ -907,7 +955,7 @@ def _rewrite_values( # these are valid for .sel, .isel, .coarsen all_mappers = ChainMap( key_mappers, - dict.fromkeys(var_kws, (_get_axis_coord, _get_with_standard_name)), + dict.fromkeys(var_kws, (_get_all,)), ) for key in set(all_mappers) & set(kwargs): From 365b7e800ed930b0c6993d48f37a409821bf0fb5 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 11 Feb 2021 12:21:17 -0700 Subject: [PATCH 02/22] better mappers draft --- cf_xarray/accessor.py | 73 +++++++++++++++++++++----------- cf_xarray/tests/test_accessor.py | 3 +- doc/whats-new.rst | 2 +- 3 files changed, 50 insertions(+), 28 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 01c2ce4d..6accdc56 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -372,36 +372,46 @@ def _get_with_standard_name( return varnames +def _get_all(obj: Union[DataArray, Dataset], key: Union[str, List[str]]) -> List[str]: + all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name, _get_measure) + results = apply_mapper(all_mappers, obj, key, error=False, default=None) + return results + + +def _get_dims(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): + return [k for k in _get_all(obj, key) if k in obj.dims] + + +def _get_indexes(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): + return [k for k in _get_all(obj, key) if k in obj.indexes] + + +def _get_coords(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): + return [k for k in _get_all(obj, key) if k in obj.coords] + + #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_axis_coord, _get_with_standard_name), - "dims": (_get_axis_coord, _get_with_standard_name), # transpose - "drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims - "dimensions": (_get_axis_coord, _get_with_standard_name), # stack - "dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims - "shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll - "pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll - "names": ( - _get_axis_coord, - _get_measure, - _get_with_standard_name, - ), # set_coords, reset_coords, drop_vars - "labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop - "coords": (_get_axis_coord, _get_with_standard_name), # interp - "indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex + "dim": (_get_dims,), # (_get_axis_coord, _get_with_standard_name), + "dims": (_get_dims,), # (_get_axis_coord, _get_with_standard_name), # transpose + "drop_dims": (_get_dims,), # drop_dims + "dimensions": (_get_dims,), # stack + "dims_dict": (_get_dims,), # swap_dims, rename_dims + "shifts": (_get_dims,), # shift, roll + "pad_width": (_get_dims,), # shift, roll + "names": (_get_all,), # set_coords, reset_coords, drop_vars + "labels": (_get_indexes,), # drop_sel + "coords": (_get_dims,), # interp + "indexers": (_get_indexes,), # sel, isel, reindex # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index - "window": (_get_axis_coord, _get_with_standard_name), # rolling_exp + "dims_or_levels": (_get_dims,), # reset_index + "window": (_get_dims,), # rolling_exp "coord": (_get_axis_coord_single,), # differentiate, integrate - "group": ( - _get_axis_coord_single, - _get_groupby_time_accessor, - _get_with_standard_name, - ), - "indexer": (_get_axis_coord_single,), # resample - "variables": (_get_axis_coord, _get_with_standard_name), # sortby + "group": (_get_all, _get_groupby_time_accessor), # groupby + "indexer": (_get_indexes,), # resample + "variables": (_get_all,), # sortby "weights": (_get_measure_variable,), # type: ignore - "chunks": (_get_axis_coord, _get_with_standard_name), # chunk + "chunks": (_get_dims,), # chunk } @@ -430,6 +440,9 @@ def _build_docstring(func): can be used for arguments. """ + get_all_docstring = ( + f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r};\n\t\t\tor standard names" + ) # this list will need to be updated any time a new mapper is added mapper_docstrings = { _get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}", @@ -437,6 +450,10 @@ def _build_docstring(func): _get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'", _get_with_standard_name: "Standard names", _get_measure_variable: f"One of {_CELL_MEASURES!r}", + _get_all: get_all_docstring, + _get_indexes: get_all_docstring + "present in .indexes", + _get_dims: get_all_docstring + "present in .dims", + _get_coords: get_all_docstring + "present in .coords", } sig = inspect.signature(func) @@ -1374,6 +1391,12 @@ def guess_coord_axis(self, verbose: bool = False) -> Union[DataArray, Dataset]: obj[var].attrs = dict(ChainMap(obj[var].attrs, ATTRS[axis])) return obj + def drop(self, *args, **kwargs): + raise NotImplementedError( + "cf-xarray does not support .drop." + "Please use .cf.drop_vars or .cf.drop_sel as appropriate." + ) + @xr.register_dataset_accessor("cf") class CFDatasetAccessor(CFAccessor): diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 2682ccf9..af8e5570 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -849,8 +849,7 @@ def test_standard_name_mapper(): @pytest.mark.parametrize("obj", objects) -@pytest.mark.parametrize("attr", ["drop", "drop_vars", "set_coords"]) -@pytest.mark.filterwarnings("ignore:dropping .* using `drop` .* deprecated") +@pytest.mark.parametrize("attr", ["drop_vars", "set_coords"]) def test_drop_vars_and_set_coords(obj, attr): # DataArray object has no attribute set_coords diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e6be6541..787d51fb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -9,7 +9,7 @@ v0.4.1 (unreleased) - Replace ``cf.describe()`` with :py:meth:`Dataset.cf.__repr__`. By `Mattia Almansi`_. - Automatically set ``x`` or ``y`` for :py:attr:`DataArray.cf.plot`. By `Deepak Cherian`_. - Added scripts to document :ref:`criteria` with tables. By `Mattia Almansi`_. -- Support for ``.drop()``, ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_. +- Support for ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_. - Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_ - Allow :py:meth:`DataArray.cf.__getitem__` with standard names. By `Deepak Cherian`_ - Rewrite the ``values`` of :py:attr:`Dataset.coords` and :py:attr:`Dataset.data_vars` with objects returned From 6f0b2d4c65985f7f0eb8b6bd87bed273ac4fbc52 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Thu, 11 Feb 2021 21:27:20 +0000 Subject: [PATCH 03/22] remove malmans2 _get_funcs --- cf_xarray/accessor.py | 50 ++++++------------------------------------- 1 file changed, 7 insertions(+), 43 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index e0416dec..f9eaf6db 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -315,42 +315,6 @@ def _get_axis_coord( return list(results) -def _get_all(var: Union[DataArray, Dataset], key: str) -> List[str]: - - results = set(_get_axis_coord(var, key, error=False)) - for func in (_get_measure, _get_with_standard_name): - results.update(func(var, key)) - - return list(results) - - -def _get_dims(var: Union[DataArray, Dataset], key: str) -> List[str]: - results = set(_get_all(var, key)).intersection(var.dims) - return list(results) - - -def _get_single_dim(var: Union[DataArray, Dataset], key: str) -> List[str]: - return _get_single(var, key, _get_dims) - - -def _get_coords(var: Union[DataArray, Dataset], key: str) -> List[str]: - results = set(_get_all(var, key)).intersection(var.coords) - return list(results) - - -def _get_single_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: - return _get_single(var, key, _get_coords) - - -def _get_indexes(var: Union[DataArray, Dataset], key: str) -> List[str]: - results = set(_get_all(var, key)).intersection(var.indexes) - return list(results) - - -def _get_single_index(var: Union[DataArray, Dataset], key: str) -> List[str]: - return _get_single(var, key, _get_indexes) - - def _get_single(var: Union[DataArray, Dataset], key: str, func): results = func(var, key) @@ -424,28 +388,28 @@ def _get_with_standard_name( return varnames -def _get_all(obj: Union[DataArray, Dataset], key: Union[str, List[str]]) -> List[str]: - all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name, _get_measure) +def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: + all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name) results = apply_mapper(all_mappers, obj, key, error=False, default=None) return results -def _get_dims(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): +def _get_dims(obj: Union[DataArray, Dataset], key: str): return [k for k in _get_all(obj, key) if k in obj.dims] -def _get_indexes(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): +def _get_indexes(obj: Union[DataArray, Dataset], key: str): return [k for k in _get_all(obj, key) if k in obj.indexes] -def _get_coords(obj: Union[DataArray, Dataset], key: Union[str, List[str]]): +def _get_coords(obj: Union[DataArray, Dataset], key: str): return [k for k in _get_all(obj, key) if k in obj.coords] #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_dims,), # (_get_axis_coord, _get_with_standard_name), - "dims": (_get_dims,), # (_get_axis_coord, _get_with_standard_name), # transpose + "dim": (_get_dims,), + "dims": (_get_dims,), # transpose "drop_dims": (_get_dims,), # drop_dims "dimensions": (_get_dims,), # stack "dims_dict": (_get_dims,), # swap_dims, rename_dims From f46222655d58b80c1015b5d081680041f75414b8 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Fri, 12 Feb 2021 19:23:02 +0000 Subject: [PATCH 04/22] add single decorator --- cf_xarray/accessor.py | 43 ++++++++++++++++---------------- cf_xarray/tests/test_accessor.py | 1 + 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index f9eaf6db..a81792c1 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -3,6 +3,7 @@ import itertools import warnings from collections import ChainMap +from functools import wraps from typing import ( Any, Callable, @@ -249,9 +250,7 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List return [] -def _get_axis_coord( - var: Union[DataArray, Dataset], key: str, error: bool = True -) -> List[str]: +def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Translate from axis or coord name to variable name @@ -285,7 +284,7 @@ def _get_axis_coord( """ valid_keys = _COORD_NAMES + _AXIS_NAMES - if error and key not in valid_keys: + if key not in valid_keys: raise KeyError( f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}" ) @@ -315,20 +314,6 @@ def _get_axis_coord( return list(results) -def _get_single(var: Union[DataArray, Dataset], key: str, func): - - results = func(var, key) - - if len(results) > 1: - raise KeyError( - f"Multiple results for {key!r} found: {results!r}. I expected only one." - ) - elif len(results) == 0: - raise KeyError(f"No results found for {key!r}.") - - return results - - def _get_measure_variable( da: DataArray, key: str, error: bool = True, default: str = None ) -> List[DataArray]: @@ -406,6 +391,21 @@ def _get_coords(obj: Union[DataArray, Dataset], key: str): return [k for k in _get_all(obj, key) if k in obj.coords] +def _single(func): + @wraps(func) + def get_single(obj: Union[DataArray, Dataset], key: str): + results = func(obj, key) + if len(results) > 1: + raise KeyError( + f"Multiple results for {key!r} found: {results!r}. I expected only one." + ) + elif len(results) == 0: + raise KeyError(f"No results found for {key!r}.") + return results + + return get_single + + #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { "dim": (_get_dims,), @@ -423,8 +423,8 @@ def _get_coords(obj: Union[DataArray, Dataset], key: str): "dims_or_levels": (_get_dims,), # reset_index "window": (_get_dims,), # rolling_exp "coord": (_get_axis_coord_single,), # differentiate, integrate - "group": (_get_all, _get_groupby_time_accessor), # groupby - "indexer": (_get_indexes,), # resample + "group": (_single(_get_all), _get_groupby_time_accessor), # groupby + "indexer": (_single(_get_indexes),), # resample "variables": (_get_all,), # sortby "weights": (_get_measure_variable,), # type: ignore "chunks": (_get_dims,), # chunk @@ -466,6 +466,7 @@ def _build_docstring(func): _get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'", _get_with_standard_name: "Standard names", _get_measure_variable: f"One of {_CELL_MEASURES!r}", + _single(_get_all): get_all_docstring.replace("One or more of", "One of"), _get_all: get_all_docstring, _get_indexes: get_all_docstring + "present in .indexes", _get_dims: get_all_docstring + "present in .dims", @@ -477,7 +478,7 @@ def _build_docstring(func): for k in set(sig.parameters.keys()) & set(_DEFAULT_KEY_MAPPERS): mappers = _DEFAULT_KEY_MAPPERS.get(k, []) docstring = ";\n\t\t\t".join( - mapper_docstrings.get(mapper, "unknown. please open an issue.") + mapper_docstrings.get(id(mapper), "unknown. please open an issue.") for mapper in mappers ) string += f"\t\t{k}: {docstring} \n" diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index af8e5570..4392261f 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -624,6 +624,7 @@ def test_get_bounds_dim_name(): def test_docstring(): + print(airds.cf.groupby.__doc__) assert "One of ('X'" in airds.cf.groupby.__doc__ assert "One or more of ('X'" in airds.cf.mean.__doc__ From be74e466e9c6d4577a791224349f3a67ee750d39 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Fri, 12 Feb 2021 19:25:31 +0000 Subject: [PATCH 05/22] remove id() --- cf_xarray/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index a81792c1..e59c4020 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -478,7 +478,7 @@ def _build_docstring(func): for k in set(sig.parameters.keys()) & set(_DEFAULT_KEY_MAPPERS): mappers = _DEFAULT_KEY_MAPPERS.get(k, []) docstring = ";\n\t\t\t".join( - mapper_docstrings.get(id(mapper), "unknown. please open an issue.") + mapper_docstrings.get(mapper, "unknown. please open an issue.") for mapper in mappers ) string += f"\t\t{k}: {docstring} \n" From 4cab0c801a35627d53f57ddd901ebe5529e44e09 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 15:54:05 +0000 Subject: [PATCH 06/22] fix docstring and more decorators --- cf_xarray/accessor.py | 117 +++++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index e59c4020..364d607c 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -3,7 +3,6 @@ import itertools import warnings from collections import ChainMap -from functools import wraps from typing import ( Any, Callable, @@ -250,6 +249,9 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List return [] +_get_groupby_time_accessor.__doc__ = "Time variable accessor e.g. 'T.month'" + + def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Translate from axis or coord name to variable name @@ -314,16 +316,6 @@ def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) -def _get_measure_variable( - da: DataArray, key: str, error: bool = True, default: str = None -) -> List[DataArray]: - """ tiny wrapper since xarray does not support providing str for weights.""" - varnames = apply_mapper(_get_measure, da, key, error, default) - if len(varnames) > 1: - raise KeyError(f"Multiple measures found for key {key!r}: {varnames!r}.") - return [da[varnames[0]]] - - def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ Translate from cell measures to appropriate variable name. @@ -358,6 +350,11 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) +_get_measure.__doc__ = ( + f"One or more of {_CELL_MEASURES!r};" "\n\t\t\tor arbitraty measures" +) + + def _get_with_standard_name( obj: Union[DataArray, Dataset], name: Union[str, List[str]] ) -> List[str]: @@ -379,20 +376,49 @@ def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: return results -def _get_dims(obj: Union[DataArray, Dataset], key: str): - return [k for k in _get_all(obj, key) if k in obj.dims] +_get_all.__doc__ = ( + f"One or more of {(_AXIS_NAMES + _COORD_NAMES + _CELL_MEASURES)!r};" + "\n\t\t\tor arbitraty measures, or standard names" +) + + +def _dims(func): + @functools.wraps(func) + def get_dims(obj: Union[DataArray, Dataset], key: str): + return [k for k in func(obj, key) if k in obj.dims] + + get_dims.__doc__ = func.__doc__ + " present in .dims" + return get_dims + + +def _indexes(func): + @functools.wraps(func) + def get_indexes(obj: Union[DataArray, Dataset], key: str): + return [k for k in func(obj, key) if k in obj.dims] + + get_indexes.__doc__ = func.__doc__ + " present in .indexes" + return get_indexes + + +def _coords(func): + @functools.wraps(func) + def get_coords(obj: Union[DataArray, Dataset], key: str): + return [k for k in func(obj, key) if k in obj.coords] + get_coords.__doc__ = func.__doc__ + " present in .coords" + return get_coords -def _get_indexes(obj: Union[DataArray, Dataset], key: str): - return [k for k in _get_all(obj, key) if k in obj.indexes] +def _variables(func): + @functools.wraps(func) + def get_variables(obj: Union[DataArray, Dataset], key: str): + return [obj[k] for k in func(obj, key)] -def _get_coords(obj: Union[DataArray, Dataset], key: str): - return [k for k in _get_all(obj, key) if k in obj.coords] + return get_variables def _single(func): - @wraps(func) + @functools.wraps(func) def get_single(obj: Union[DataArray, Dataset], key: str): results = func(obj, key) if len(results) > 1: @@ -403,31 +429,33 @@ def get_single(obj: Union[DataArray, Dataset], key: str): raise KeyError(f"No results found for {key!r}.") return results + get_single.__doc__ = func.__doc__.replace("One or more of", "One of") + return get_single #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_dims,), - "dims": (_get_dims,), # transpose - "drop_dims": (_get_dims,), # drop_dims - "dimensions": (_get_dims,), # stack - "dims_dict": (_get_dims,), # swap_dims, rename_dims - "shifts": (_get_dims,), # shift, roll - "pad_width": (_get_dims,), # shift, roll + "dim": (_dims(_get_all),), + "dims": (_dims(_get_all),), # transpose + "drop_dims": (_dims(_get_all),), # drop_dims + "dimensions": (_dims(_get_all),), # stack + "dims_dict": (_dims(_get_all),), # swap_dims, rename_dims + "shifts": (_dims(_get_all),), # shift, roll + "pad_width": (_dims(_get_all),), # shift, roll "names": (_get_all,), # set_coords, reset_coords, drop_vars - "labels": (_get_indexes,), # drop_sel - "coords": (_get_dims,), # interp - "indexers": (_get_indexes,), # sel, isel, reindex + "labels": (_indexes(_get_all),), # drop_sel + "coords": (_dims(_get_all),), # interp + "indexers": (_dims(_get_all),), # sel, isel, reindex # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_get_dims,), # reset_index - "window": (_get_dims,), # rolling_exp - "coord": (_get_axis_coord_single,), # differentiate, integrate + "dims_or_levels": (_dims(_get_all),), # reset_index + "window": (_dims(_get_all),), # rolling_exp + "coord": (_single(_coords(_get_all)),), # differentiate, integrate "group": (_single(_get_all), _get_groupby_time_accessor), # groupby - "indexer": (_single(_get_indexes),), # resample + "indexer": (_single(_indexes(_get_all)),), # resample "variables": (_get_all,), # sortby - "weights": (_get_measure_variable,), # type: ignore - "chunks": (_get_dims,), # chunk + "weights": (_variables(_single(_get_all)),), # type: ignore + "chunks": (_dims(_get_all),), # chunk } @@ -456,36 +484,19 @@ def _build_docstring(func): can be used for arguments. """ - get_all_docstring = ( - f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r};\n\t\t\tor standard names" - ) - # this list will need to be updated any time a new mapper is added - mapper_docstrings = { - _get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}", - _get_axis_coord_single: f"One of {(_AXIS_NAMES + _COORD_NAMES)!r}", - _get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'", - _get_with_standard_name: "Standard names", - _get_measure_variable: f"One of {_CELL_MEASURES!r}", - _single(_get_all): get_all_docstring.replace("One or more of", "One of"), - _get_all: get_all_docstring, - _get_indexes: get_all_docstring + "present in .indexes", - _get_dims: get_all_docstring + "present in .dims", - _get_coords: get_all_docstring + "present in .coords", - } - sig = inspect.signature(func) string = "" for k in set(sig.parameters.keys()) & set(_DEFAULT_KEY_MAPPERS): mappers = _DEFAULT_KEY_MAPPERS.get(k, []) docstring = ";\n\t\t\t".join( - mapper_docstrings.get(mapper, "unknown. please open an issue.") + mapper.__doc__ if mapper.__doc__ else "unknown. please open an issue." for mapper in mappers ) string += f"\t\t{k}: {docstring} \n" for param in sig.parameters: if sig.parameters[param].kind is inspect.Parameter.VAR_KEYWORD: - string += f"\t\t{param}: {mapper_docstrings[_get_axis_coord]} \n\n" + string += f"\t\t{param}: {_get_all.__doc__} \n\n" return ( f"\n\tThe following arguments will be processed by cf_xarray: \n{string}" "\n\t----\n\t" From f30e7fe62d18f3fa84ed27ac044b5d0559a6f761 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 16:58:28 +0000 Subject: [PATCH 07/22] back to undecorated --- cf_xarray/accessor.py | 60 +++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 364d607c..b4256d86 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -382,31 +382,25 @@ def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: ) -def _dims(func): - @functools.wraps(func) - def get_dims(obj: Union[DataArray, Dataset], key: str): - return [k for k in func(obj, key) if k in obj.dims] +def _get_dims(obj: Union[DataArray, Dataset], key: str): + return [k for k in _get_all(obj, key) if k in obj.dims] - get_dims.__doc__ = func.__doc__ + " present in .dims" - return get_dims +_get_dims.__doc__ = _get_all.__doc__ + " present in .dims" -def _indexes(func): - @functools.wraps(func) - def get_indexes(obj: Union[DataArray, Dataset], key: str): - return [k for k in func(obj, key) if k in obj.dims] - get_indexes.__doc__ = func.__doc__ + " present in .indexes" - return get_indexes +def _get_indexes(obj: Union[DataArray, Dataset], key: str): + return [k for k in _get_all(obj, key) if k in obj.indexes] -def _coords(func): - @functools.wraps(func) - def get_coords(obj: Union[DataArray, Dataset], key: str): - return [k for k in func(obj, key) if k in obj.coords] +_get_indexes.__doc__ = _get_all.__doc__ + " present in .indexes" + + +def _get_coords(obj: Union[DataArray, Dataset], key: str): + return [k for k in _get_all(obj, key) if k in obj.coords] + - get_coords.__doc__ = func.__doc__ + " present in .coords" - return get_coords +_get_coords.__doc__ = _get_all.__doc__ + " present in .coords" def _variables(func): @@ -436,26 +430,26 @@ def get_single(obj: Union[DataArray, Dataset], key: str): #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_dims(_get_all),), - "dims": (_dims(_get_all),), # transpose - "drop_dims": (_dims(_get_all),), # drop_dims - "dimensions": (_dims(_get_all),), # stack - "dims_dict": (_dims(_get_all),), # swap_dims, rename_dims - "shifts": (_dims(_get_all),), # shift, roll - "pad_width": (_dims(_get_all),), # shift, roll + "dim": (_get_dims,), + "dims": (_get_dims,), # transpose + "drop_dims": (_get_dims,), # drop_dims + "dimensions": (_get_dims,), # stack + "dims_dict": (_get_dims,), # swap_dims, rename_dims + "shifts": (_get_dims,), # shift, roll + "pad_width": (_get_dims,), # shift, roll "names": (_get_all,), # set_coords, reset_coords, drop_vars - "labels": (_indexes(_get_all),), # drop_sel - "coords": (_dims(_get_all),), # interp - "indexers": (_dims(_get_all),), # sel, isel, reindex + "labels": (_get_indexes,), # drop_sel + "coords": (_get_dims,), # interp + "indexers": (_get_dims,), # sel, isel, reindex # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_dims(_get_all),), # reset_index - "window": (_dims(_get_all),), # rolling_exp - "coord": (_single(_coords(_get_all)),), # differentiate, integrate + "dims_or_levels": (_get_dims,), # reset_index + "window": (_get_dims,), # rolling_exp + "coord": (_single(_get_coords),), # differentiate, integrate "group": (_single(_get_all), _get_groupby_time_accessor), # groupby - "indexer": (_single(_indexes(_get_all)),), # resample + "indexer": (_single(_get_indexes),), # resample "variables": (_get_all,), # sortby "weights": (_variables(_single(_get_all)),), # type: ignore - "chunks": (_dims(_get_all),), # chunk + "chunks": (_get_dims,), # chunk } From 7525a2a5472a603c1f6a1f88fdc0f59bd9b2dd28 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 18:24:04 +0000 Subject: [PATCH 08/22] replace _get_axis_coord_single --- cf_xarray/accessor.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index b4256d86..ce999158 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -203,18 +203,6 @@ def _apply_single_mapper(mapper): return results -def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str]: - """ Helper method for when we really want only one result per key. """ - results = _get_axis_coord(var, key) - if len(results) > 1: - raise KeyError( - f"Multiple results for {key!r} found: {results!r}. I expected only one." - ) - elif len(results) == 0: - raise KeyError(f"No results found for {key!r}.") - return results - - def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Helper method for when our key name is of the nature "T.month" and we want to @@ -826,7 +814,7 @@ def __call__(self, *args, **kwargs): obj=self._obj, attr="plot", accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), ) return self._plot_decorator(plot)(*args, **kwargs) @@ -838,7 +826,7 @@ def __getattr__(self, attr): obj=self._obj.plot, attr=attr, accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), # TODO: "extra_decorator" is more complex than I would like it to be. # Not sure if there is a better way though extra_decorator=self._plot_decorator, @@ -1001,7 +989,7 @@ def _rewrite_values( if vkw in kwargs: maybe_update = { # TODO: this is assuming key_mappers[k] is always - # _get_axis_coord_single + # _single(_get_all) k: apply_mapper( key_mappers[k], self._obj, v, error=False, default=[v] )[0] @@ -1357,8 +1345,8 @@ def rename_like( renamer = {} for key in good_keys: - ours = _get_axis_coord_single(self._obj, key)[0] - theirs = _get_axis_coord_single(other, key)[0] + ours = _single(_get_all)(self._obj, key)[0] + theirs = _single(_get_all)(other, key)[0] renamer[ours] = theirs newobj = self._obj.rename(renamer) @@ -1467,7 +1455,7 @@ def get_bounds(self, key: str) -> DataArray: DataArray """ name = apply_mapper( - _get_axis_coord_single, self._obj, key, error=False, default=[key] + _single(_get_all), self._obj, key, error=False, default=[key] )[0] bounds = self._obj[name].attrs["bounds"] obj = self._maybe_to_dataset() From 11b3ef7053780f30999538142df9673799daaa2d Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 18:57:42 +0000 Subject: [PATCH 09/22] use _get_coords for plot() --- cf_xarray/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index ce999158..82c8ef2b 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -814,7 +814,7 @@ def __call__(self, *args, **kwargs): obj=self._obj, attr="plot", accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_coords),)), ) return self._plot_decorator(plot)(*args, **kwargs) From f5ff3c9f34f054d70544567adab8a6a6de915c05 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 19:35:53 +0000 Subject: [PATCH 10/22] better type check --- cf_xarray/accessor.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 82c8ef2b..6565c842 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -15,7 +15,9 @@ Optional, Set, Tuple, + TypeVar, Union, + cast ) import xarray as xr @@ -145,6 +147,9 @@ # Type for Mapper functions Mapper = Callable[[Union[DataArray, Dataset], str], List[str]] +# Type for decorators +F = TypeVar("F", bound=Callable[..., Any]) + def apply_mapper( mappers: Union[Mapper, Tuple[Mapper, ...]], @@ -370,38 +375,38 @@ def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: ) -def _get_dims(obj: Union[DataArray, Dataset], key: str): +def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.dims] _get_dims.__doc__ = _get_all.__doc__ + " present in .dims" -def _get_indexes(obj: Union[DataArray, Dataset], key: str): +def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.indexes] _get_indexes.__doc__ = _get_all.__doc__ + " present in .indexes" -def _get_coords(obj: Union[DataArray, Dataset], key: str): +def _get_coords(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.coords] _get_coords.__doc__ = _get_all.__doc__ + " present in .coords" -def _variables(func): +def _variables(func: F) -> F: @functools.wraps(func) - def get_variables(obj: Union[DataArray, Dataset], key: str): + def wrapper(obj: Union[DataArray, Dataset], key: str) -> List[DataArray]: return [obj[k] for k in func(obj, key)] - return get_variables + return cast(F, wrapper) -def _single(func): +def _single(func: F) -> F: @functools.wraps(func) - def get_single(obj: Union[DataArray, Dataset], key: str): + def wrapper(obj: Union[DataArray, Dataset], key: str): results = func(obj, key) if len(results) > 1: raise KeyError( @@ -411,9 +416,9 @@ def get_single(obj: Union[DataArray, Dataset], key: str): raise KeyError(f"No results found for {key!r}.") return results - get_single.__doc__ = func.__doc__.replace("One or more of", "One of") + wrapper.__doc__ = func.__doc__.replace("One or more of", "One of") if func.__doc__ else func.__doc__ - return get_single + return cast(F, wrapper) #: Default mappers for common keys. From 8c91cf9dc3dc378ffd07a93466f992f0d226ec22 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 19:38:10 +0000 Subject: [PATCH 11/22] run pre-commit --- cf_xarray/accessor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 6565c842..ec074e1a 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -17,7 +17,7 @@ Tuple, TypeVar, Union, - cast + cast, ) import xarray as xr @@ -416,7 +416,11 @@ def wrapper(obj: Union[DataArray, Dataset], key: str): raise KeyError(f"No results found for {key!r}.") return results - wrapper.__doc__ = func.__doc__.replace("One or more of", "One of") if func.__doc__ else func.__doc__ + wrapper.__doc__ = ( + func.__doc__.replace("One or more of", "One of") + if func.__doc__ + else func.__doc__ + ) return cast(F, wrapper) From 49fe808766fc4b0b576248985dcc62b834f3e521 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sat, 20 Feb 2021 22:27:38 +0000 Subject: [PATCH 12/22] wrap doc and use coords for plot.* --- cf_xarray/accessor.py | 49 +++++++++++++++------------------ doc/examples/introduction.ipynb | 2 +- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index ec074e1a..322e02b0 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -208,6 +208,19 @@ def _apply_single_mapper(mapper): return results +def _set_doc(doc): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper.__doc__ = doc + return cast(F, wrapper) + + return decorator + + +@_set_doc("Time variable accessor e.g. 'T.month'") def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Helper method for when our key name is of the nature "T.month" and we want to @@ -231,9 +244,7 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List if "." in key: key, ext = key.split(".", 1) - results = apply_mapper( - (_get_axis_coord, _get_with_standard_name), var, key, error=False - ) + results = apply_mapper((_get_all,), var, key, error=False) if len(results) > 1: raise KeyError(f"Multiple results received for {key}.") return [v + "." + ext for v in results] @@ -242,9 +253,6 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List return [] -_get_groupby_time_accessor.__doc__ = "Time variable accessor e.g. 'T.month'" - - def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Translate from axis or coord name to variable name @@ -343,11 +351,6 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]: return list(results) -_get_measure.__doc__ = ( - f"One or more of {_CELL_MEASURES!r};" "\n\t\t\tor arbitraty measures" -) - - def _get_with_standard_name( obj: Union[DataArray, Dataset], name: Union[str, List[str]] ) -> List[str]: @@ -363,39 +366,31 @@ def _get_with_standard_name( return varnames +@_set_doc( + f"One or more of {(_AXIS_NAMES + _COORD_NAMES + _CELL_MEASURES)!r}," + "\n\t\t\tor arbitraty measures, or standard names" +) def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name) results = apply_mapper(all_mappers, obj, key, error=False, default=None) return results -_get_all.__doc__ = ( - f"One or more of {(_AXIS_NAMES + _COORD_NAMES + _CELL_MEASURES)!r};" - "\n\t\t\tor arbitraty measures, or standard names" -) - - +@_set_doc(_get_all.__doc__ + " present in .dims") def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.dims] -_get_dims.__doc__ = _get_all.__doc__ + " present in .dims" - - +@_set_doc(_get_all.__doc__ + " present in .indexes") def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.indexes] -_get_indexes.__doc__ = _get_all.__doc__ + " present in .indexes" - - +@_set_doc(_get_all.__doc__ + " present in .coords") def _get_coords(obj: Union[DataArray, Dataset], key: str) -> List[str]: return [k for k in _get_all(obj, key) if k in obj.coords] -_get_coords.__doc__ = _get_all.__doc__ + " present in .coords" - - def _variables(func: F) -> F: @functools.wraps(func) def wrapper(obj: Union[DataArray, Dataset], key: str) -> List[DataArray]: @@ -835,7 +830,7 @@ def __getattr__(self, attr): obj=self._obj.plot, attr=attr, accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_coords),)), # TODO: "extra_decorator" is more complex than I would like it to be. # Not sure if there is a better way though extra_decorator=self._plot_decorator, diff --git a/doc/examples/introduction.ipynb b/doc/examples/introduction.ipynb index 3a7c1c94..87d0f3bb 100644 --- a/doc/examples/introduction.ipynb +++ b/doc/examples/introduction.ipynb @@ -982,7 +982,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.9.1" }, "toc": { "base_numbering": 1, From 425c4720d83c97a17f509c3ce4977bcec301ef64 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sun, 21 Feb 2021 07:43:54 +0000 Subject: [PATCH 13/22] explicitly write docs --- cf_xarray/accessor.py | 44 +++++++++++++++----------------- cf_xarray/tests/test_accessor.py | 19 +++++++++++++- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 322e02b0..d46da8b0 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -208,20 +208,10 @@ def _apply_single_mapper(mapper): return results -def _set_doc(doc): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - wrapper.__doc__ = doc - return cast(F, wrapper) - - return decorator - - -@_set_doc("Time variable accessor e.g. 'T.month'") def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]: + """ + Time variable accessor e.g. 'T.month' + """ """ Helper method for when our key name is of the nature "T.month" and we want to isolate the "T" for coordinate mapping @@ -241,6 +231,7 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List ----- Returns an empty list if there is no frequency extension specified. """ + if "." in key: key, ext = key.split(".", 1) @@ -366,28 +357,37 @@ def _get_with_standard_name( return varnames -@_set_doc( - f"One or more of {(_AXIS_NAMES + _COORD_NAMES + _CELL_MEASURES)!r}," - "\n\t\t\tor arbitraty measures, or standard names" -) def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitraty measures, or standard names + """ all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name) results = apply_mapper(all_mappers, obj, key, error=False, default=None) return results -@_set_doc(_get_all.__doc__ + " present in .dims") def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitraty measures, or standard names present in .dims + """ return [k for k in _get_all(obj, key) if k in obj.dims] -@_set_doc(_get_all.__doc__ + " present in .indexes") def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitraty measures, or standard names present in .indexes + """ return [k for k in _get_all(obj, key) if k in obj.indexes] -@_set_doc(_get_all.__doc__ + " present in .coords") def _get_coords(obj: Union[DataArray, Dataset], key: str) -> List[str]: + """ + One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', + 'area', 'volume'), or arbitraty measures, or standard names present in .coords + """ return [k for k in _get_all(obj, key) if k in obj.coords] @@ -830,7 +830,7 @@ def __getattr__(self, attr): obj=self._obj.plot, attr=attr, accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_single(_get_coords),)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), # TODO: "extra_decorator" is more complex than I would like it to be. # Not sure if there is a better way though extra_decorator=self._plot_decorator, @@ -992,8 +992,6 @@ def _rewrite_values( for vkw in var_kws: if vkw in kwargs: maybe_update = { - # TODO: this is assuming key_mappers[k] is always - # _single(_get_all) k: apply_mapper( key_mappers[k], self._obj, v, error=False, default=[v] )[0] diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 4392261f..1e697cbe 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -624,9 +624,26 @@ def test_get_bounds_dim_name(): def test_docstring(): - print(airds.cf.groupby.__doc__) assert "One of ('X'" in airds.cf.groupby.__doc__ + assert "Time variable accessor e.g. 'T.month'" in airds.cf.groupby.__doc__ assert "One or more of ('X'" in airds.cf.mean.__doc__ + assert "present in .dims" in airds.cf.drop_dims.__doc__ + assert "present in .coords" in airds.cf.integrate.__doc__ + assert "present in .indexes" in airds.cf.resample.__doc__ + + # Make sure docs are up to date + get_all_doc = cf_xarray.accessor._get_all.__doc__ + all_keys = ( + cf_xarray.accessor._AXIS_NAMES + + cf_xarray.accessor._COORD_NAMES + + cf_xarray.accessor._CELL_MEASURES + ) + expected = f"One or more of {all_keys!r}, or arbitraty measures, or standard names" + assert get_all_doc.split() == expected.split() + for name in ["dims", "indexes", "coords"]: + actual = getattr(cf_xarray.accessor, f"_get_{name}").__doc__ + expected = get_all_doc + f" present in .{name}" + assert actual.split() == expected.split() def _make_names(prefixes): From e8b673e2b702706301069efae46267c78cbc64af Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sun, 21 Feb 2021 08:00:15 +0000 Subject: [PATCH 14/22] fix typo --- cf_xarray/accessor.py | 8 ++++---- cf_xarray/tests/test_accessor.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index d46da8b0..447f5dd1 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -360,7 +360,7 @@ def _get_with_standard_name( def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', - 'area', 'volume'), or arbitraty measures, or standard names + 'area', 'volume'), or arbitrary measures, or standard names """ all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name) results = apply_mapper(all_mappers, obj, key, error=False, default=None) @@ -370,7 +370,7 @@ def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]: def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', - 'area', 'volume'), or arbitraty measures, or standard names present in .dims + 'area', 'volume'), or arbitrary measures, or standard names present in .dims """ return [k for k in _get_all(obj, key) if k in obj.dims] @@ -378,7 +378,7 @@ def _get_dims(obj: Union[DataArray, Dataset], key: str) -> List[str]: def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', - 'area', 'volume'), or arbitraty measures, or standard names present in .indexes + 'area', 'volume'), or arbitrary measures, or standard names present in .indexes """ return [k for k in _get_all(obj, key) if k in obj.indexes] @@ -386,7 +386,7 @@ def _get_indexes(obj: Union[DataArray, Dataset], key: str) -> List[str]: def _get_coords(obj: Union[DataArray, Dataset], key: str) -> List[str]: """ One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time', - 'area', 'volume'), or arbitraty measures, or standard names present in .coords + 'area', 'volume'), or arbitrary measures, or standard names present in .coords """ return [k for k in _get_all(obj, key) if k in obj.coords] diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 1e697cbe..1c0527b8 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -638,7 +638,7 @@ def test_docstring(): + cf_xarray.accessor._COORD_NAMES + cf_xarray.accessor._CELL_MEASURES ) - expected = f"One or more of {all_keys!r}, or arbitraty measures, or standard names" + expected = f"One or more of {all_keys!r}, or arbitrary measures, or standard names" assert get_all_doc.split() == expected.split() for name in ["dims", "indexes", "coords"]: actual = getattr(cf_xarray.accessor, f"_get_{name}").__doc__ From ad88f3f4b12acfa7cbdbd9aacc444b2ac046187c Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 21 Feb 2021 09:24:34 -0700 Subject: [PATCH 15/22] Fix line plots with .plot --- cf_xarray/accessor.py | 14 +++++++++----- cf_xarray/tests/test_accessor.py | 6 ++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 447f5dd1..c9dab76c 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -346,6 +346,9 @@ def _get_with_standard_name( obj: Union[DataArray, Dataset], name: Union[str, List[str]] ) -> List[str]: """ returns a list of variable names with standard name == name. """ + if name is None: + return [] + varnames = [] if isinstance(obj, DataArray): obj = obj._to_temp_dataset() @@ -769,9 +772,9 @@ def __init__(self, obj, accessor): def _plot_decorator(self, func): """ This decorator is used to set default kwargs on plotting functions. - - For now, this is setting ``xincrease`` and ``yincrease``. It could set - other arguments in the future. + For now, this can + 1. set ``xincrease`` and ``yincrease``. + 2. automatically set ``x`` or ``y``. """ valid_keys = self.accessor.keys() @@ -795,7 +798,8 @@ def _process_x_or_y(kwargs, key): return kwargs is_line_plot = (func.__name__ == "line") or ( - func.__name__ == "wrapper" and kwargs.get("hue") + func.__name__ == "wrapper" + and (kwargs.get("hue") or self._obj.ndim == 1) ) if is_line_plot: if not kwargs.get("hue"): @@ -818,7 +822,7 @@ def __call__(self, *args, **kwargs): obj=self._obj, attr="plot", accessor=self.accessor, - key_mappers=dict.fromkeys(self._keys, (_single(_get_coords),)), + key_mappers=dict.fromkeys(self._keys, (_single(_get_all),)), ) return self._plot_decorator(plot)(*args, **kwargs) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 1c0527b8..5fb34bbb 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -413,6 +413,10 @@ def test_dataarray_plot(): np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data) plt.close() + rv = obj.isel(time=0, lon=1).cf.plot(x="lat") + np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data) + plt.close() + # various line plots and automatic guessing rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line() np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data) @@ -865,6 +869,8 @@ def test_standard_name_mapper(): expected = da.sortby("label") assert_identical(actual, expected) + assert cf_xarray.accessor._get_with_standard_name(da, None) == [] + @pytest.mark.parametrize("obj", objects) @pytest.mark.parametrize("attr", ["drop_vars", "set_coords"]) From 54b8cc6782e21d06776bd8e2b3786bff8dc3e89b Mon Sep 17 00:00:00 2001 From: malmans2 Date: Sun, 21 Feb 2021 18:03:03 +0000 Subject: [PATCH 16/22] replace _get_axis_coord --- cf_xarray/accessor.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index c9dab76c..ed5f982e 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -436,7 +436,7 @@ def wrapper(obj: Union[DataArray, Dataset], key: str): "labels": (_get_indexes,), # drop_sel "coords": (_get_dims,), # interp "indexers": (_get_dims,), # sel, isel, reindex - # "indexes": (_get_axis_coord,), # set_index + "indexes": (_get_dims,), # set_index TODO: cf_xarray decodes keys, not values "dims_or_levels": (_get_dims,), # reset_index "window": (_get_dims,), # rolling_exp "coord": (_single(_get_coords),), # differentiate, integrate @@ -617,12 +617,12 @@ def check_results(names, k): successful = dict.fromkeys(key, False) for k in key: if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES: - names = _get_axis_coord(obj, k) + names = _get_all(obj, k) check_results(names, k) successful[k] = bool(names) coords.extend(names) elif "measures" not in skip and k in accessor._get_all_cell_measures(): - measure = _get_measure(obj, k) + measure = _get_all(obj, k) check_results(measure, k) successful[k] = bool(measure) if measure: @@ -1145,10 +1145,7 @@ def axes(self) -> Dict[str, List[str]]: Dictionary of valid Axis names that can be used with ``__getitem__`` or ``.cf[key]``. Will be ("X", "Y", "Z", "T") or a subset thereof. """ - vardict = { - key: apply_mapper(_get_axis_coord, self._obj, key, error=False) - for key in _AXIS_NAMES - } + vardict = {key: _get_coords(self._obj, key) for key in _AXIS_NAMES} return {k: sorted(v) for k, v in vardict.items() if v} @@ -1167,10 +1164,7 @@ def coordinates(self) -> Dict[str, List[str]]: Dictionary of valid Coordinate names that can be used with ``__getitem__`` or ``.cf[key]``. Will be ("longitude", "latitude", "vertical", "time") or a subset thereof. """ - vardict = { - key: apply_mapper(_get_axis_coord, self._obj, key, error=False) - for key in _COORD_NAMES - } + vardict = {key: _get_coords(self._obj, key) for key in _COORD_NAMES} return {k: sorted(v) for k, v in vardict.items() if v} @@ -1197,10 +1191,10 @@ def cell_measures(self) -> Dict[str, List[str]]: da.attrs.get("cell_measures", "") for da in obj.data_vars.values() ] - measures: Dict[str, List[str]] = {} + keys = {} for attr in all_attrs: - for key, value in parse_cell_methods_attr(attr).items(): - measures[key] = measures.setdefault(key, []) + [value] + keys.update(parse_cell_methods_attr(attr)) + measures = {key: _get_all(self._obj, key) for key in keys} return {k: sorted(set(v)) for k, v in measures.items() if v} @@ -1638,7 +1632,7 @@ def decode_vertical_coords(self, prefix="z"): import re ds = self._obj - dims = _get_axis_coord(ds, "Z") + dims = _get_dims(ds, "Z") requirements = { "ocean_s_coordinate_g1": {"depth_c", "depth", "s", "C", "eta"}, From a3178b501fb0617f02809fd358773b2889787612 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 08:26:06 +0000 Subject: [PATCH 17/22] update properties dosctring --- cf_xarray/accessor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index ed5f982e..9d0f0b2b 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1138,7 +1138,7 @@ def axes(self) -> Dict[str, List[str]]: This is useful for checking whether a key is valid for indexing, i.e. that the attributes necessary to allow indexing by that key exist. - However, it will only return the Axis names, not Coordinate names. + However, it will only return the Axis names present in ``.coords``, not Coordinate names. Returns ------- @@ -1157,7 +1157,7 @@ def coordinates(self) -> Dict[str, List[str]]: This is useful for checking whether a key is valid for indexing, i.e. that the attributes necessary to allow indexing by that key exist. - However, it will only return the Coordinate names, not Axis names. + However, it will only return the Coordinate names present in ``.coords``, not Axis names. Returns ------- From 7991beeff89191ec991c66402f18dd9db5565986 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 09:01:01 +0000 Subject: [PATCH 18/22] Copy over #165 tests --- cf_xarray/accessor.py | 7 +------ cf_xarray/datasets.py | 3 ++- cf_xarray/tests/test_accessor.py | 23 ++++++++++++++++++++--- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 9d0f0b2b..f82cfbfa 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1337,12 +1337,7 @@ def rename_like( ourkeys = self.keys() theirkeys = other.cf.keys() - good_keys = set(_COORD_NAMES) & ourkeys & theirkeys - if not good_keys: - raise ValueError( - "No common coordinate variables between these two objects." - ) - + good_keys = ourkeys & theirkeys renamer = {} for key in good_keys: ours = _single(_get_all)(self._obj, key)[0] diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index fea548bb..5e6df446 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -70,6 +70,7 @@ anc["q_detection_limit"] = xr.DataArray( 1e-3, attrs=dict(standard_name="specific_humidity detection_minimum", units="g/g") ) +anc multiple = xr.Dataset() @@ -121,7 +122,7 @@ romsds["temp"] = ( ("ocean_time", "s_rho"), [np.linspace(20, 30, 30)] * 2, - {"coordinates": "z_rho_dummy"}, + {"coordinates": "z_rho_dummy", "standard_name": "sea_water_potential_temperature"}, ) romsds["temp"].encoding["coordinates"] = "s_rho" romsds.coords["z_rho_dummy"] = ( diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 5fb34bbb..6a57b449 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -215,10 +215,13 @@ def test_getitem_ancillary_variables(): def test_rename_like(): original = popds.copy(deep=True) - with pytest.raises(KeyError): - popds.cf.rename_like(airds) + # it'll match for axis: X (lon, nlon) and coordinate="longitude" (lon, TLONG) + # so delete the axis attributes + newair = airds.copy(deep=True) + del newair.lon.attrs["axis"] + del newair.lat.attrs["axis"] - renamed = popds.cf["TEMP"].cf.rename_like(airds) + renamed = popds.cf["TEMP"].cf.rename_like(newair) for k in ["TLONG", "TLAT"]: assert k not in renamed.coords assert k in original.coords @@ -228,6 +231,20 @@ def test_rename_like(): assert "lat" in renamed.coords assert renamed.attrs["coordinates"] == "lon lat" + # standard name matching + newroms = romsds.expand_dims(latitude=[1], longitude=[1]).cf.guess_coord_axis() + renamed = popds.cf["UVEL"].cf.rename_like(newroms) + assert renamed.attrs["coordinates"] == "longitude latitude" + assert "longitude" in renamed.coords + assert "latitude" in renamed.coords + assert "ULON" not in renamed.coords + assert "ULAT" not in renamed.coords + + # should change "temp" to "TEMP" + renamed = romsds.cf.rename_like(popds) + assert "temp" not in renamed + assert "TEMP" in renamed + @pytest.mark.parametrize("obj", objects) @pytest.mark.parametrize( From 2211c97c6ddd2daf916945e44f77be80ed652d97 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 09:16:42 +0000 Subject: [PATCH 19/22] remove anc --- cf_xarray/datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index 5e6df446..de593471 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -70,7 +70,6 @@ anc["q_detection_limit"] = xr.DataArray( 1e-3, attrs=dict(standard_name="specific_humidity detection_minimum", units="g/g") ) -anc multiple = xr.Dataset() From bc2b14d08a860e30efb3db12891d25f5231a2747 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 12:03:30 +0000 Subject: [PATCH 20/22] comment out indexes --- cf_xarray/accessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index f82cfbfa..3b6e7b5d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -436,7 +436,7 @@ def wrapper(obj: Union[DataArray, Dataset], key: str): "labels": (_get_indexes,), # drop_sel "coords": (_get_dims,), # interp "indexers": (_get_dims,), # sel, isel, reindex - "indexes": (_get_dims,), # set_index TODO: cf_xarray decodes keys, not values + # "indexes": (_single(_get_dims),), # set_index this decodes keys but not values "dims_or_levels": (_get_dims,), # reset_index "window": (_get_dims,), # rolling_exp "coord": (_single(_get_coords),), # differentiate, integrate From 29984e7487293bec43032a9a9521b4602f08aa95 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 17:20:08 +0000 Subject: [PATCH 21/22] test at least once new functions --- cf_xarray/accessor.py | 2 +- cf_xarray/tests/test_accessor.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 3b6e7b5d..739fad8e 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -173,7 +173,7 @@ def _apply_single_mapper(mapper): try: results = mapper(obj, key) except KeyError as e: - if error: + if error or "I expected only one." in repr(e): raise e else: results = [] diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 6a57b449..148f6aed 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -932,12 +932,33 @@ def test_drop_sel_and_reset_coords(obj): @pytest.mark.parametrize("ds", datasets) def test_drop_dims(ds): + # Testing _get_dims + + # Add data_var and coord to test _get_dims + ds["lon_var"] = ds["lon"] + ds = ds.assign_coords(lon_coord=ds["lon"]) # Axis and coordinate for cf_name in ["X", "longitude"]: assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name)) +@pytest.mark.parametrize("ds", datasets) +def test_differentiate(ds): + # Testing _single(_get_coords) + + # Add data_var and coord to test _get_dims + ds["lon_var"] = ds["lon"] + ds = ds.assign_coords(lon_coord=ds["lon"]) + + # Axis + assert_identical(ds.differentiate("lon"), ds.cf.differentiate("lon")) + + # Multiple keys + with pytest.raises(KeyError, match=".*I expected only one."): + assert_identical(ds.differentiate("lon"), ds.cf.differentiate("X")) + + def test_new_standard_name_mappers(): assert_identical(forecast.cf.mean("realization"), forecast.mean("M")) assert_identical( From 035c8e83c8a0f0379d2a3c494fbb98343917b59d Mon Sep 17 00:00:00 2001 From: malmans2 Date: Tue, 23 Feb 2021 17:24:04 +0000 Subject: [PATCH 22/22] fix comments --- cf_xarray/tests/test_accessor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 148f6aed..e54e2199 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -932,7 +932,6 @@ def test_drop_sel_and_reset_coords(obj): @pytest.mark.parametrize("ds", datasets) def test_drop_dims(ds): - # Testing _get_dims # Add data_var and coord to test _get_dims ds["lon_var"] = ds["lon"] @@ -945,16 +944,15 @@ def test_drop_dims(ds): @pytest.mark.parametrize("ds", datasets) def test_differentiate(ds): - # Testing _single(_get_coords) - # Add data_var and coord to test _get_dims + # Add data_var and coord to test _get_coords ds["lon_var"] = ds["lon"] ds = ds.assign_coords(lon_coord=ds["lon"]) - # Axis + # Coordinate assert_identical(ds.differentiate("lon"), ds.cf.differentiate("lon")) - # Multiple keys + # Multiple coords (test error raised by _single) with pytest.raises(KeyError, match=".*I expected only one."): assert_identical(ds.differentiate("lon"), ds.cf.differentiate("X"))