From 62018a7274dfcd70bc5654a5c3822d62c0bf3bab Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Mon, 19 Dec 2022 22:06:46 +0100
Subject: [PATCH 1/9] complex cov

---
 xarray/core/computation.py       | 28 ++++++++++++++++++----------
 xarray/tests/test_computation.py |  6 ++++++
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index c0aa36aa3d2..fc4fb219477 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -15,6 +15,7 @@
     Callable,
     Hashable,
     Iterable,
+    Literal,
     Mapping,
     Sequence,
     TypeVar,
@@ -1217,7 +1218,9 @@ def apply_ufunc(
         return apply_array_ufunc(func, *args, dask=dask)
 
 
-def cov(da_a, da_b, dim=None, ddof=1):
+def cov(
+    da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None, ddof: int = 1
+) -> T_DataArray:
     """
     Compute covariance between two DataArray objects along a shared
     dimension.
@@ -1227,9 +1230,9 @@ def cov(da_a, da_b, dim=None, ddof=1):
         Array to compute.
     da_b : DataArray
         Array to compute.
-    dim : str, optional
+    dim : Hashable, optional
         The dimension along which the covariance will be computed
-    ddof : int, optional
+    ddof : int, default: 1
         If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
         else normalization is by N.
 
@@ -1297,7 +1300,9 @@ def cov(da_a, da_b, dim=None, ddof=1):
     return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
 
 
-def corr(da_a, da_b, dim=None):
+def corr(
+    da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None
+) -> T_DataArray:
     """
     Compute the Pearson correlation coefficient between
     two DataArray objects along a shared dimension.
@@ -1308,7 +1313,7 @@ def corr(da_a, da_b, dim=None):
         Array to compute.
     da_b : DataArray
         Array to compute.
-    dim : str, optional
+    dim : Hashable, optional
         The dimension along which the correlation will be computed
 
     Returns
@@ -1376,7 +1381,11 @@ def corr(da_a, da_b, dim=None):
 
 
 def _cov_corr(
-    da_a: T_DataArray, da_b: T_DataArray, dim=None, ddof=0, method=None
+    da_a: T_DataArray,
+    da_b: T_DataArray,
+    dim: Hashable | None = None,
+    ddof: int = 0,
+    method: Literal["cov", "corr", None] = None,
 ) -> T_DataArray:
     """
     Internal method for xr.cov() and xr.corr() so only have to
@@ -1396,12 +1405,11 @@ def _cov_corr(
     demeaned_da_b = da_b - da_b.mean(dim=dim)
 
     # 4. Compute covariance along the given dim
-    #
     # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
     # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
-    cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / (
-        valid_count
-    )
+    cov = (demeaned_da_a.conjugate() * demeaned_da_b).sum(
+        dim=dim, skipna=True, min_count=1
+    ) / (valid_count)
 
     if method == "cov":
         return cov
 
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index d70fd9d0d8d..78df480b446 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1598,6 +1598,12 @@ def test_autocov(da_a, dim) -> None:
     assert_allclose(actual, expected)
 
 
+def test_complex_cov() -> None:
+    da = xr.DataArray([1j, -1j])
+    actual = xr.cov(da, da)
+    assert abs(actual.item()) == 2
+
+
 @requires_dask
 def test_vectorize_dask_new_output_dims() -> None:
     # regression test for GH3574

From 4778b38e378d5d744a186993e42564b0591cfba1 Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Mon, 19 Dec 2022 22:21:52 +0100
Subject: [PATCH 2/9] fix mypy

---
 xarray/core/computation.py       | 5 +++--
 xarray/tests/test_computation.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index fc4fb219477..a8d04e5231e 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1391,6 +1391,7 @@ def _cov_corr(
     Internal method for xr.cov() and xr.corr() so only have to
     sanitize the input arrays once and we don't repeat code.
     """
+    dim = None if dim is None else (dim,)
     # 1. Broadcast the two arrays
     da_a, da_b = align(da_a, da_b, join="inner", copy=False)
 
@@ -1412,14 +1413,14 @@ def _cov_corr(
     ) / (valid_count)
 
     if method == "cov":
-        return cov
+        return cov  # type: ignore[return-value]
 
     else:
         # compute std + corr
         da_a_std = da_a.std(dim=dim)
         da_b_std = da_b.std(dim=dim)
         corr = cov / (da_a_std * da_b_std)
-        return corr
+        return corr  # type: ignore[return-value]
 
 
 def cross(
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 78df480b446..898d8da61ec 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1378,7 +1378,7 @@ def test_vectorize_exclude_dims_dask() -> None:
 
 def test_corr_only_dataarray() -> None:
     with pytest.raises(TypeError, match="Only xr.DataArray is supported"):
-        xr.corr(xr.Dataset(), xr.Dataset())
+        xr.corr(xr.Dataset(), xr.Dataset())  # type: ignore[type-var]
 
 
 def arrays_w_tuples():

From 498ae092251f024d26abaec8b2a5f241633c78ce Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Mon, 19 Dec 2022 22:23:56 +0100
Subject: [PATCH 3/9] update whats-new

---
 doc/whats-new.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index cd5ecd83978..5b69c84c459 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -45,6 +45,8 @@ Bug fixes
 - add a ``keep_attrs`` parameter to :py:meth:`Dataset.pad`, :py:meth:`DataArray.pad`, and
   :py:meth:`Variable.pad` (:pull:`7267`).
   By `Justus Magin `_.
+- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` for complex valued arrays (:issue:`7340`, :pull:`7392`).
+  By `Michael Niklas `_.
 
 Documentation
 ~~~~~~~~~~~~~

From cc2edd2f94e6df412a7417b05ca6284fa8006e35 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Sun, 15 Jan 2023 11:09:49 -0700
Subject: [PATCH 4/9] Update xarray/core/computation.py

---
 xarray/core/computation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index a8d04e5231e..0b7d012c54e 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1408,7 +1408,7 @@ def _cov_corr(
     # 4. Compute covariance along the given dim
     # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
     # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
-    cov = (demeaned_da_a.conjugate() * demeaned_da_b).sum(
+    cov = (demeaned_da_a.conj() * demeaned_da_b).sum(
         dim=dim, skipna=True, min_count=1
     ) / (valid_count)
 

From d9d4098e6a4c1cd310480cc7a233f7f2d5cd0c7c Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Fri, 10 Feb 2023 22:17:07 +0100
Subject: [PATCH 5/9] slight improvements to tests

---
 xarray/tests/test_computation.py | 86 ++++++++++++++++++++++----------------------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index 42466e3f184..8e45a07d6a2 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1390,14 +1390,15 @@ def test_corr_only_dataarray() -> None:
     xr.corr(xr.Dataset(), xr.Dataset())  # type: ignore[type-var]
 
 
-def arrays_w_tuples():
+@pytest.fixture(scope="module")
+def arrays():
     da = xr.DataArray(
         np.random.random((3, 21, 4)),
         coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
         dims=("a", "time", "x"),
     )
 
-    arrays = [
+    return [
         da.isel(time=range(0, 18)),
         da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
         xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
@@ -1405,7 +1406,10 @@ def arrays_w_tuples():
         xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
         xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
     ]
-    array_tuples = [
+
+@pytest.fixture(scope="module")
+def array_tuples(arrays):
+    return [
         (arrays[0], arrays[0]),
         (arrays[0], arrays[1]),
         (arrays[1], arrays[1]),
@@ -1417,27 +1421,19 @@ def arrays_w_tuples():
         (arrays[4], arrays[4]),
     ]
 
-    return arrays, array_tuples
-
 
 @pytest.mark.parametrize("ddof", [0, 1])
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [
-        arrays_w_tuples()[1][3],
-        arrays_w_tuples()[1][4],
-        arrays_w_tuples()[1][5],
-        arrays_w_tuples()[1][6],
-        arrays_w_tuples()[1][7],
-        arrays_w_tuples()[1][8],
-    ],
-)
+@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8])
 @pytest.mark.parametrize("dim", [None, "x", "time"])
 @requires_dask
-def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
+def test_lazy_corrcov(
+    n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
     # GH 5284
     from dask import is_dask_collection
 
+    da_a, da_b = array_tuples[n]
+
     with raise_if_dask_computes():
         cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
         assert is_dask_collection(cov)
@@ -1447,12 +1443,13 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
 
 
 @pytest.mark.parametrize("ddof", [0, 1])
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
-)
+@pytest.mark.parametrize("n", [0, 1, 2])
 @pytest.mark.parametrize("dim", [None, "time"])
-def test_cov(da_a, da_b, dim, ddof) -> None:
+def test_cov(
+    n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
+
     if dim is not None:
 
         def np_cov_ind(ts1, ts2, a, x):
@@ -1499,12 +1496,13 @@ def np_cov(ts1, ts2):
     assert_allclose(actual, expected)
 
 
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
-)
+@pytest.mark.parametrize("n", [0, 1, 2])
 @pytest.mark.parametrize("dim", [None, "time"])
-def test_corr(da_a, da_b, dim) -> None:
+def test_corr(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
+
     if dim is not None:
 
         def np_corr_ind(ts1, ts2, a, x):
@@ -1547,12 +1545,12 @@ def np_corr(ts1, ts2):
     assert_allclose(actual, expected)
 
 
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    arrays_w_tuples()[1],
-)
+@pytest.mark.parametrize("n", range(9))
 @pytest.mark.parametrize("dim", [None, "time", "x"])
-def test_covcorr_consistency(da_a, da_b, dim) -> None:
+def test_covcorr_consistency(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
     # Testing that xr.corr and xr.cov are consistent with each other
     # 1. Broadcast the two arrays
     da_a, da_b = broadcast(da_a, da_b)
@@ -1569,10 +1567,13 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
 
 
 @requires_dask
-@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
+@pytest.mark.parametrize("n", range(9))
 @pytest.mark.parametrize("dim", [None, "time", "x"])
 @pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide")
-def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
+def test_corr_lazycorr_consistency(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
     da_al = da_a.chunk()
     da_bl = da_b.chunk()
     c_abl = xr.corr(da_al, da_bl, dim=dim)
@@ -1591,19 +1592,18 @@ def test_corr_dtype_error():
     xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))
 
 
-@pytest.mark.parametrize(
-    "da_a",
-    arrays_w_tuples()[0],
-)
+@pytest.mark.parametrize("n", range(5))
 @pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]])
-def test_autocov(da_a, dim) -> None:
+def test_autocov(n: int, dim: str | None, arrays) -> None:
+    da = arrays[n]
+
     # Testing that the autocovariance*(N-1) is ~=~ to the variance matrix
     # 1. Ignore the nans
-    valid_values = da_a.notnull()
+    valid_values = da.notnull()
     # Because we're using ddof=1, this requires > 1 value in each sample
-    da_a = da_a.where(valid_values.sum(dim=dim) > 1)
-    expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
-    actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1)
+    da = da.where(valid_values.sum(dim=dim) > 1)
+    expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
+    actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1)
     assert_allclose(actual, expected)

From 73d6c789d883e92df6b7b5d5c5e02e2259491dd9 Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Fri, 10 Feb 2023 22:18:08 +0100
Subject: [PATCH 6/9] bugfix in corr_cov for multiple dims

---
 xarray/core/computation.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 4a8ef617416..2305e753cee 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -30,7 +30,7 @@
 from xarray.core.merge import merge_attrs, merge_coordinates_without_align
 from xarray.core.options import OPTIONS, _get_keep_attrs
 from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.types import T_DataArray
+from xarray.core.types import Dims, T_DataArray
 from xarray.core.utils import is_dict_like, is_scalar
 from xarray.core.variable import Variable
 
@@ -1219,7 +1219,7 @@ def apply_ufunc(
 
 def cov(
-    da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None, ddof: int = 1
+    da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
 ) -> T_DataArray:
     """
     Compute covariance between two DataArray objects along a shared
     dimension.
@@ -1230,7 +1230,7 @@ def cov(
         Array to compute.
     da_b : DataArray
         Array to compute.
-    dim : Hashable, optional
+    dim : str, iterable of hashable, "..." or None, optional
         The dimension along which the covariance will be computed
     ddof : int, default: 1
         If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
@@ -1300,9 +1300,7 @@ def cov(
     return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
 
 
-def corr(
-    da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None
-) -> T_DataArray:
+def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
     """
     Compute the Pearson correlation coefficient between
     two DataArray objects along a shared dimension.
@@ -1313,7 +1311,7 @@ def corr(
         Array to compute.
     da_b : DataArray
         Array to compute.
-    dim : Hashable, optional
+    dim : str, iterable of hashable, "..." or None, optional
         The dimension along which the correlation will be computed
 
     Returns
@@ -1383,7 +1381,7 @@ def corr(
 
 def _cov_corr(
     da_a: T_DataArray,
     da_b: T_DataArray,
-    dim: Hashable | None = None,
+    dim: Dims = None,
     ddof: int = 0,
     method: Literal["cov", "corr", None] = None,
 ) -> T_DataArray:
@@ -1391,7 +1389,6 @@ def _cov_corr(
     Internal method for xr.cov() and xr.corr() so only have to
     sanitize the input arrays once and we don't repeat code.
     """
-    dim = None if dim is None else (dim,)
     # 1. Broadcast the two arrays
     da_a, da_b = align(da_a, da_b, join="inner", copy=False)
 
@@ -1633,7 +1630,7 @@ def cross(
 
 def dot(
     *arrays,
-    dims: str | Iterable[Hashable] | ellipsis | None = None,
+    dims: Dims = None,
     **kwargs: Any,
 ):
     """Generalized dot product for xarray objects. Like np.einsum, but
@@ -1643,7 +1640,7 @@ def dot(
     ----------
     *arrays : DataArray or Variable
         Arrays to compute.
-    dims : ..., str or tuple of str, optional
+    dims : str, iterable of hashable, "..." or None, optional
         Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
         If not specified, then all the common dimensions are summed over.
     **kwargs : dict

From 5f766b16eddfc08196d97695f0b42adaccf37a94 Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Fri, 10 Feb 2023 22:19:36 +0100
Subject: [PATCH 7/9] fix whats-new

---
 doc/whats-new.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ce112d243c4..c6615278852 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -23,6 +23,8 @@ v2023.03.0 (unreleased)
 
 New Features
 ~~~~~~~~~~~~
+- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`).
+  By `Michael Niklas `_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
@@ -128,8 +130,6 @@ Bug fixes
 - add a ``keep_attrs`` parameter to :py:meth:`Dataset.pad`, :py:meth:`DataArray.pad`, and
   :py:meth:`Variable.pad` (:pull:`7267`).
   By `Justus Magin `_.
-- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` for complex valued arrays (:issue:`7340`, :pull:`7392`).
-  By `Michael Niklas `_.
 - Fixed performance regression in alignment between indexed and
   non-indexed objects of the same shape (:pull:`7382`).
   By `BenoƮt Bovy `_.

From 576692bce54b3ae9bc4af79646a0168587002379 Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Sun, 12 Feb 2023 17:02:54 +0100
Subject: [PATCH 8/9] allow refreshing of backends

---
 xarray/backends/cfgrib_.py       |  5 +++--
 xarray/backends/common.py        | 29 +++++++++++++++++++++++++++--
 xarray/backends/h5netcdf_.py     |  5 +++--
 xarray/backends/netCDF4_.py      |  5 +++--
 xarray/backends/plugins.py       | 28 ++++++++++++++++++--------
 xarray/backends/pseudonetcdf_.py |  5 +++--
 xarray/backends/pydap_.py        |  5 +++--
 xarray/backends/pynio_.py        |  5 +++--
 xarray/backends/scipy_.py        |  5 +++--
 xarray/backends/zarr.py          |  5 +++--
 10 files changed, 71 insertions(+), 26 deletions(-)

diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index 4c7d6a65e8f..e75119cb1aa 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -9,7 +9,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    BackendEntrypoint,
+    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.locks import SerializableLock, ensure_lock
@@ -90,7 +90,8 @@ def get_encoding(self):
         return {"unlimited_dims": {k for k, v in dims.items() if v is None}}
 
 
-class CfgribfBackendEntrypoint(BackendEntrypoint):
+class CfgribfBackendEntrypoint(_InternalBackendEntrypoint):
+    _module_name = "cfgrib"
     available = module_available("cfgrib")
 
     def guess_can_open(self, filename_or_obj):
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 050493e3034..27411b58062 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -12,7 +12,12 @@
 from xarray.conventions import cf_encoder
 from xarray.core import indexing
 from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
+from xarray.core.utils import (
+    FrozenDict,
+    NdimSizeLenMixin,
+    is_remote_uri,
+    module_available,
+)
 
 if TYPE_CHECKING:
     from io import BufferedIOBase
@@ -428,4 +433,24 @@ def guess_can_open(
         return False
 
 
-BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {}
+class _InternalBackendEntrypoint:
+    """
+    Wrapper class for BackendEntrypoints that ship with xarray.
+
+
+    Additional attributes
+    ---------------------
+
+    _module_name : str
+        Name of the module that is required to enable the backend.
+    """
+
+    _module_name: ClassVar[str]
+
+    @classmethod
+    def _set_availability(cls) -> None:
+        """Resets the backend's availability."""
+        cls.available = module_available(cls._module_name)
+
+
+BACKEND_ENTRYPOINTS: dict[str, type[_InternalBackendEntrypoint]] = {}
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index c4f75672173..ec561bd7e66 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -8,8 +8,8 @@
 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
-    BackendEntrypoint,
     WritableCFDataStore,
+    _InternalBackendEntrypoint,
     _normalize_path,
     find_root_and_group,
 )
@@ -343,7 +343,7 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)
 
 
-class H5netcdfBackendEntrypoint(BackendEntrypoint):
+class H5netcdfBackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for netCDF files based on the h5netcdf package.
 
@@ -365,6 +365,7 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint):
     backends.ScipyBackendEntrypoint
     """
 
+    _module_name = "h5netcdf"
     available = module_available("h5netcdf")
     description = (
         "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray"
     )
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 0c6e083158d..1a185aee388 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -11,8 +11,8 @@
 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
     BackendArray,
-    BackendEntrypoint,
     WritableCFDataStore,
+    _InternalBackendEntrypoint,
     _normalize_path,
     find_root_and_group,
     robust_getitem,
@@ -513,7 +513,7 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)
 
 
-class NetCDF4BackendEntrypoint(BackendEntrypoint):
+class NetCDF4BackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for netCDF files based on the netCDF4 package.
 
@@ -535,6 +535,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
     backends.ScipyBackendEntrypoint
     """
 
+    _module_name = "netCDF4"
     available = module_available("netCDF4")
     description = (
         "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray"
     )
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index bae1dcd2225..701f1a0fedd 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -6,12 +6,13 @@
 import sys
 import warnings
 from importlib.metadata import entry_points
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable
 
 from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint
 
 if TYPE_CHECKING:
     import os
+    from importlib.metadata import EntryPoint, EntryPoints
     from io import BufferedIOBase
 
     from xarray.backends.common import AbstractDataStore
@@ -19,7 +20,7 @@
 STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]
 
 
-def remove_duplicates(entrypoints):
+def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]:
     # sort and group entrypoints by name
     entrypoints = sorted(entrypoints, key=lambda ep: ep.name)
     entrypoints_grouped = itertools.groupby(entrypoints, key=lambda ep: ep.name)
@@ -42,7 +43,7 @@ def remove_duplicates(entrypoints):
     return unique_entrypoints
 
 
-def detect_parameters(open_dataset):
+def detect_parameters(open_dataset: Callable) -> tuple[str, ...]:
     signature = inspect.signature(open_dataset)
     parameters = signature.parameters
     parameters_list = []
@@ -60,7 +61,9 @@ def detect_parameters(open_dataset):
     return tuple(parameters_list)
 
 
-def backends_dict_from_pkg(entrypoints):
+def backends_dict_from_pkg(
+    entrypoints: list[EntryPoint],
+) -> dict[str, BackendEntrypoint]:
     backend_entrypoints = {}
     for entrypoint in entrypoints:
         name = entrypoint.name
@@ -72,14 +75,16 @@ def backends_dict_from_pkg(entrypoints):
     return backend_entrypoints
 
 
-def set_missing_parameters(backend_entrypoints):
-    for name, backend in backend_entrypoints.items():
+def set_missing_parameters(backend_entrypoints: dict[str, BackendEntrypoint]):
+    for _, backend in backend_entrypoints.items():
         if backend.open_dataset_parameters is None:
             open_dataset = backend.open_dataset
             backend.open_dataset_parameters = detect_parameters(open_dataset)
 
 
-def sort_backends(backend_entrypoints):
+def sort_backends(
+    backend_entrypoints: dict[str, BackendEntrypoint]
+) -> dict[str, BackendEntrypoint]:
     ordered_backends_entrypoints = {}
     for be_name in STANDARD_BACKENDS_ORDER:
         if backend_entrypoints.get(be_name, None) is not None:
@@ -90,7 +95,7 @@ def sort_backends(backend_entrypoints):
     return ordered_backends_entrypoints
 
 
-def build_engines(entrypoints) -> dict[str, BackendEntrypoint]:
+def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]:
     backend_entrypoints = {}
     for backend_name, backend in BACKEND_ENTRYPOINTS.items():
         if backend.available:
@@ -126,6 +131,13 @@ def list_engines() -> dict[str, BackendEntrypoint]:
     return build_engines(entrypoints)
 
 
+def refresh_engines() -> None:
+    """Refreshes the backend engines based on installed packages."""
+    list_engines.cache_clear()
+    for backend_entrypoint in BACKEND_ENTRYPOINTS.values():
+        backend_entrypoint._set_availability()
+
+
 def guess_engine(
     store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
 ):
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index ae8f90e3a44..d64f99a5220 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -6,7 +6,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    BackendEntrypoint,
+    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager
@@ -96,7 +96,7 @@ def close(self):
         self._manager.close()
 
 
-class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
+class PseudoNetCDFBackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for netCDF-like data formats in the air quality field
     based on the PseudoNetCDF package.
@@ -121,6 +121,7 @@ class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
     backends.PseudoNetCDFDataStore
     """
 
+    _module_name = "PseudoNetCDF"
     available = module_available("PseudoNetCDF")
     description = (
         "Open many atmospheric science data formats using PseudoNetCDF in Xarray"
     )
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index df26a03d790..7452a23a174 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -7,7 +7,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    BackendEntrypoint,
+    _InternalBackendEntrypoint,
     robust_getitem,
 )
 from xarray.backends.store import StoreBackendEntrypoint
@@ -138,7 +138,7 @@ def get_dimensions(self):
         return Frozen(self.ds.dimensions)
 
 
-class PydapBackendEntrypoint(BackendEntrypoint):
+class PydapBackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for steaming datasets over the internet using
     the Data Access Protocol, also known as DODS or OPeNDAP
@@ -154,6 +154,7 @@ class PydapBackendEntrypoint(BackendEntrypoint):
     backends.PydapDataStore
     """
 
+    _module_name = "pydap"
     available = module_available("pydap")
     description = "Open remote datasets via OPeNDAP using pydap in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html"
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 611ea978990..469df3d6ba9 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -8,7 +8,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    BackendEntrypoint,
+    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager
@@ -107,7 +107,7 @@ def close(self):
         self._manager.close()
 
 
-class PynioBackendEntrypoint(BackendEntrypoint):
+class PynioBackendEntrypoint(_InternalBackendEntrypoint):
     """
     PyNIO backend
 
@@ -117,6 +117,7 @@ class PynioBackendEntrypoint(BackendEntrypoint):
     https://github.com/pydata/xarray/issues/4491 for more information
     """
 
+    _module_name = "Nio"
     available = module_available("Nio")
 
     def open_dataset(
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index 651aebce2ce..e34618c7d57 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -9,8 +9,8 @@
 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
     BackendArray,
-    BackendEntrypoint,
     WritableCFDataStore,
+    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager, DummyFileManager
@@ -240,7 +240,7 @@ def close(self):
         self._manager.close()
 
 
-class ScipyBackendEntrypoint(BackendEntrypoint):
+class ScipyBackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for netCDF files based on the scipy package.
 
@@ -261,6 +261,7 @@ class ScipyBackendEntrypoint(BackendEntrypoint):
     backends.H5netcdfBackendEntrypoint
     """
 
+    _module_name = "scipy"
     available = module_available("scipy")
     description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html"
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 6686d67ed4d..5f96491e104 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -11,8 +11,8 @@
     BACKEND_ENTRYPOINTS,
     AbstractWritableDataStore,
     BackendArray,
-    BackendEntrypoint,
     _encode_variable_name,
+    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.store import StoreBackendEntrypoint
@@ -845,7 +845,7 @@ def open_zarr(
     return ds
 
 
-class ZarrBackendEntrypoint(BackendEntrypoint):
+class ZarrBackendEntrypoint(_InternalBackendEntrypoint):
     """
     Backend for ".zarr" files based on the zarr package.
 
@@ -857,6 +857,7 @@ class ZarrBackendEntrypoint(BackendEntrypoint):
     backends.ZarrStore
     """
 
+    _module_name = "zarr"
     available = module_available("zarr")
     description = "Open zarr files (.zarr) using zarr in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html"

From b2f1cfb21a95c17a6a1547c21d77c89531c5db32 Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Sun, 12 Feb 2023 17:05:33 +0100
Subject: [PATCH 9/9] Revert "allow refreshing of backends"

This reverts commit 576692bce54b3ae9bc4af79646a0168587002379.
---
 xarray/backends/cfgrib_.py       |  5 ++---
 xarray/backends/common.py        | 29 ++---------------------------
 xarray/backends/h5netcdf_.py     |  5 ++---
 xarray/backends/netCDF4_.py      |  5 ++---
 xarray/backends/plugins.py       | 28 ++++++++--------------------
 xarray/backends/pseudonetcdf_.py |  5 ++---
 xarray/backends/pydap_.py        |  5 ++---
 xarray/backends/pynio_.py        |  5 ++---
 xarray/backends/scipy_.py        |  5 ++---
 xarray/backends/zarr.py          |  5 ++---
 10 files changed, 26 insertions(+), 71 deletions(-)

diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index e75119cb1aa..4c7d6a65e8f 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -9,7 +9,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    _InternalBackendEntrypoint,
+    BackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.locks import SerializableLock, ensure_lock
@@ -90,8 +90,7 @@ def get_encoding(self):
         return {"unlimited_dims": {k for k, v in dims.items() if v is None}}
 
 
-class CfgribfBackendEntrypoint(_InternalBackendEntrypoint):
-    _module_name = "cfgrib"
+class CfgribfBackendEntrypoint(BackendEntrypoint):
     available = module_available("cfgrib")
 
     def guess_can_open(self, filename_or_obj):
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 27411b58062..050493e3034 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -12,12 +12,7 @@
 from xarray.conventions import cf_encoder
 from xarray.core import indexing
 from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.utils import (
-    FrozenDict,
-    NdimSizeLenMixin,
-    is_remote_uri,
-    module_available,
-)
+from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
 
 if TYPE_CHECKING:
     from io import BufferedIOBase
@@ -433,24 +428,4 @@ def guess_can_open(
         return False
 
 
-class _InternalBackendEntrypoint:
-    """
-    Wrapper class for BackendEntrypoints that ship with xarray.
-
-
-    Additional attributes
-    ---------------------
-
-    _module_name : str
-        Name of the module that is required to enable the backend.
- """ - - _module_name: ClassVar[str] - - @classmethod - def _set_availability(cls) -> None: - """Resets the backends availability.""" - cls.available = module_available(cls._module_name) - - -BACKEND_ENTRYPOINTS: dict[str, type[_InternalBackendEntrypoint]] = {} +BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {} diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index ec561bd7e66..c4f75672173 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,8 +8,8 @@ from xarray.backends.common import ( BACKEND_ENTRYPOINTS, + BackendEntrypoint, WritableCFDataStore, - _InternalBackendEntrypoint, _normalize_path, find_root_and_group, ) @@ -343,7 +343,7 @@ def close(self, **kwargs): self._manager.close(**kwargs) -class H5netcdfBackendEntrypoint(_InternalBackendEntrypoint): +class H5netcdfBackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the h5netcdf package. @@ -365,7 +365,6 @@ class H5netcdfBackendEntrypoint(_InternalBackendEntrypoint): backends.ScipyBackendEntrypoint """ - _module_name = "h5netcdf" available = module_available("h5netcdf") description = ( "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray" diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 1a185aee388..0c6e083158d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -11,8 +11,8 @@ from xarray.backends.common import ( BACKEND_ENTRYPOINTS, BackendArray, + BackendEntrypoint, WritableCFDataStore, - _InternalBackendEntrypoint, _normalize_path, find_root_and_group, robust_getitem, @@ -513,7 +513,7 @@ def close(self, **kwargs): self._manager.close(**kwargs) -class NetCDF4BackendEntrypoint(_InternalBackendEntrypoint): +class NetCDF4BackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the netCDF4 package. 
 
@@ -535,7 +535,6 @@ class NetCDF4BackendEntrypoint(_InternalBackendEntrypoint):
     backends.ScipyBackendEntrypoint
     """
 
-    _module_name = "netCDF4"
     available = module_available("netCDF4")
     description = (
         "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray"
     )
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index 701f1a0fedd..bae1dcd2225 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -6,13 +6,12 @@
 import sys
 import warnings
 from importlib.metadata import entry_points
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint
 
 if TYPE_CHECKING:
     import os
-    from importlib.metadata import EntryPoint, EntryPoints
     from io import BufferedIOBase
 
     from xarray.backends.common import AbstractDataStore
@@ -20,7 +19,7 @@
 STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]
 
 
-def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]:
+def remove_duplicates(entrypoints):
     # sort and group entrypoints by name
     entrypoints = sorted(entrypoints, key=lambda ep: ep.name)
     entrypoints_grouped = itertools.groupby(entrypoints, key=lambda ep: ep.name)
@@ -43,7 +42,7 @@ def remove_duplicates(entrypoints):
     return unique_entrypoints
 
 
-def detect_parameters(open_dataset: Callable) -> tuple[str, ...]:
+def detect_parameters(open_dataset):
     signature = inspect.signature(open_dataset)
     parameters = signature.parameters
     parameters_list = []
@@ -61,9 +60,7 @@ def detect_parameters(open_dataset):
     return tuple(parameters_list)
 
 
-def backends_dict_from_pkg(
-    entrypoints: list[EntryPoint],
-) -> dict[str, BackendEntrypoint]:
+def backends_dict_from_pkg(entrypoints):
     backend_entrypoints = {}
     for entrypoint in entrypoints:
         name = entrypoint.name
@@ -75,16 +72,14 @@ def backends_dict_from_pkg(entrypoints):
     return backend_entrypoints
 
 
-def set_missing_parameters(backend_entrypoints: dict[str, BackendEntrypoint]):
-    for _, backend in backend_entrypoints.items():
+def set_missing_parameters(backend_entrypoints):
+    for name, backend in backend_entrypoints.items():
         if backend.open_dataset_parameters is None:
             open_dataset = backend.open_dataset
             backend.open_dataset_parameters = detect_parameters(open_dataset)
 
 
-def sort_backends(
-    backend_entrypoints: dict[str, BackendEntrypoint]
-) -> dict[str, BackendEntrypoint]:
+def sort_backends(backend_entrypoints):
     ordered_backends_entrypoints = {}
     for be_name in STANDARD_BACKENDS_ORDER:
         if backend_entrypoints.get(be_name, None) is not None:
@@ -95,7 +90,7 @@ def sort_backends(backend_entrypoints):
     return ordered_backends_entrypoints
 
 
-def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]:
+def build_engines(entrypoints) -> dict[str, BackendEntrypoint]:
     backend_entrypoints = {}
     for backend_name, backend in BACKEND_ENTRYPOINTS.items():
         if backend.available:
@@ -131,13 +126,6 @@ def list_engines() -> dict[str, BackendEntrypoint]:
     return build_engines(entrypoints)
 
 
-def refresh_engines() -> None:
-    """Refreshes the backend engines based on installed packages."""
-    list_engines.cache_clear()
-    for backend_entrypoint in BACKEND_ENTRYPOINTS.values():
-        backend_entrypoint._set_availability()
-
-
 def guess_engine(
     store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
 ):
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index d64f99a5220..ae8f90e3a44 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -6,7 +6,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    _InternalBackendEntrypoint,
+    BackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager
@@ -96,7 +96,7 @@ def close(self):
         self._manager.close()
 
 
-class PseudoNetCDFBackendEntrypoint(_InternalBackendEntrypoint):
+class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
     """
     Backend for netCDF-like data formats in the air quality field
     based on the PseudoNetCDF package.
@@ -121,7 +121,6 @@ class PseudoNetCDFBackendEntrypoint(_InternalBackendEntrypoint):
     backends.PseudoNetCDFDataStore
     """
 
-    _module_name = "PseudoNetCDF"
     available = module_available("PseudoNetCDF")
     description = (
         "Open many atmospheric science data formats using PseudoNetCDF in Xarray"
     )
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 7452a23a174..df26a03d790 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -7,7 +7,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    _InternalBackendEntrypoint,
+    BackendEntrypoint,
     robust_getitem,
 )
 from xarray.backends.store import StoreBackendEntrypoint
@@ -138,7 +138,7 @@ def get_dimensions(self):
         return Frozen(self.ds.dimensions)
 
 
-class PydapBackendEntrypoint(_InternalBackendEntrypoint):
+class PydapBackendEntrypoint(BackendEntrypoint):
     """
     Backend for steaming datasets over the internet using
     the Data Access Protocol, also known as DODS or OPeNDAP
@@ -154,7 +154,6 @@ class PydapBackendEntrypoint(_InternalBackendEntrypoint):
     backends.PydapDataStore
     """
 
-    _module_name = "pydap"
     available = module_available("pydap")
     description = "Open remote datasets via OPeNDAP using pydap in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html"
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 469df3d6ba9..611ea978990 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -8,7 +8,7 @@
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendArray,
-    _InternalBackendEntrypoint,
+    BackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager
@@ -107,7 +107,7 @@ def close(self):
         self._manager.close()
 
 
-class PynioBackendEntrypoint(_InternalBackendEntrypoint):
+class PynioBackendEntrypoint(BackendEntrypoint):
     """
     PyNIO backend
 
@@ -117,7 +117,6 @@ class PynioBackendEntrypoint(_InternalBackendEntrypoint):
     https://github.com/pydata/xarray/issues/4491 for more information
     """
 
-    _module_name = "Nio"
     available = module_available("Nio")
 
     def open_dataset(
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index e34618c7d57..651aebce2ce 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -9,8 +9,8 @@
 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
     BackendArray,
+    BackendEntrypoint,
     WritableCFDataStore,
-    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.file_manager import CachingFileManager, DummyFileManager
@@ -240,7 +240,7 @@ def close(self):
         self._manager.close()
 
 
-class ScipyBackendEntrypoint(_InternalBackendEntrypoint):
+class ScipyBackendEntrypoint(BackendEntrypoint):
     """
     Backend for netCDF files based on the scipy package.
 
@@ -261,7 +261,6 @@ class ScipyBackendEntrypoint(_InternalBackendEntrypoint):
     backends.H5netcdfBackendEntrypoint
     """
 
-    _module_name = "scipy"
     available = module_available("scipy")
     description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html"
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 5f96491e104..6686d67ed4d 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -11,8 +11,8 @@
     BACKEND_ENTRYPOINTS,
     AbstractWritableDataStore,
     BackendArray,
+    BackendEntrypoint,
     _encode_variable_name,
-    _InternalBackendEntrypoint,
     _normalize_path,
 )
 from xarray.backends.store import StoreBackendEntrypoint
@@ -845,7 +845,7 @@ def open_zarr(
     return ds
 
 
-class ZarrBackendEntrypoint(_InternalBackendEntrypoint):
+class ZarrBackendEntrypoint(BackendEntrypoint):
     """
     Backend for ".zarr" files based on the zarr package.
 
@@ -857,7 +857,6 @@ class ZarrBackendEntrypoint(_InternalBackendEntrypoint):
     backends.ZarrStore
     """
 
-    _module_name = "zarr"
     available = module_available("zarr")
     description = "Open zarr files (.zarr) using zarr in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html"
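
A quick way to see the behaviour patches 1-5 fix, as a minimal sketch against the patched xarray (the dimension name "x" is only illustrative):

    import numpy as np
    import xarray as xr

    # A purely imaginary series whose variance should be real and positive.
    da = xr.DataArray([1j, -1j], dims="x")

    # With the conjugate in place, the auto-covariance is
    # sum(conj(a - mean) * (b - mean)) / (N - ddof) = (1 + 1) / (2 - 1) = 2.
    print(xr.cov(da, da).item())  # (2+0j)

    # Without the conjugate the same sum would be 1j*1j + (-1j)*(-1j) = -2,
    # a negative "variance"; that is the bug reported in issue 7340.

    # NumPy applies the same Hermitian convention to complex inputs:
    print(np.cov(np.array([1j, -1j])))  # (2+0j)

This mirrors the new test_complex_cov test above, which asserts that abs(xr.cov(da, da).item()) == 2.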
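
Patch 6 widens the `dim` parameter from a single Hashable to the shared `Dims` alias, so reductions over several dimensions now type-check and work; a sketch of the resulting call signature (the array contents and dimension names are illustrative):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.default_rng(0).random((3, 4)), dims=("time", "x"))

    # Reduce over a single named dimension, as before...
    print(xr.cov(da, da, dim="time").dims)  # ('x',)

    # ...or over a list of dimensions, which _cov_corr previously broke
    # by wrapping the whole list in a one-element tuple.
    print(xr.cov(da, da, dim=["time", "x"]).dims)  # ()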
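
Patches 8 and 9 add and then revert a hook for re-evaluating backend availability at runtime. Had the revert not landed, usage would have looked roughly like this sketch (hypothetical session, assuming patch 8 applied on its own):

    from xarray.backends.plugins import list_engines, refresh_engines

    print(sorted(list_engines()))  # e.g. no "zarr" entry if zarr is not installed

    # ...after installing zarr into the same environment:
    refresh_engines()  # clears the list_engines cache and re-checks each
                       # backend's module via module_available()
    print(sorted(list_engines()))  # now includes "zarr"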