diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 1d538bf94ed..ddfe4847225 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -39,7 +39,7 @@ from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk from xarray.core.indexes import Index from xarray.core.parallelcompat import guess_chunkmanager -from xarray.core.types import ZarrWriteModes +from xarray.core.types import ZarrOpenModes from xarray.core.utils import is_remote_uri if TYPE_CHECKING: @@ -55,20 +55,21 @@ CompatOptions, JoinOptions, NestedSequence, + NetcdfFormats, T_Chunks, + T_XarrayCanOpen, ) T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ T_NetcdfEngine, Literal["pydap", "pynio", "zarr"], + BackendEntrypoint, type[BackendEntrypoint], str, # no nice typing support for custom backends None, ] - T_NetcdfTypes = Literal[ - "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" - ] + DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -390,7 +391,7 @@ def _dataset_from_backend_dataset( def open_dataset( - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, engine: T_Engine = None, chunks: T_Chunks = None, @@ -421,11 +422,10 @@ def open_dataset( objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ "zarr", None}, installed backend \ - or subclass of xarray.backends.BackendEntrypoint, optional + or instance or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for - "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) - can also be used. + "netcdf4". chunks : int, dict, 'auto' or None, optional If chunks is provided, it is used to load the new dataset into dask arrays. ``chunks=-1`` loads the dataset with dask using a single @@ -595,8 +595,8 @@ def open_dataset( def open_dataarray( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, - engine: T_Engine | None = None, - chunks: T_Chunks | None = None, + engine: T_Engine = None, + chunks: T_Chunks = None, cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, @@ -628,7 +628,7 @@ def open_dataarray( objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ "zarr", None}, installed backend \ - or subclass of xarray.backends.BackendEntrypoint, optional + or instance or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4". @@ -707,16 +707,20 @@ def open_dataarray( in the values of the task graph. See :py:func:`dask.array.from_array`. chunked_array_type: str, optional Which chunked array type to coerce the underlying data array to. - Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Defaults to 'dask' if installed, else whatever is registered via the + `ChunkManagerEnetryPoint` system. Experimental API that should not be relied upon. 
from_array_kwargs: dict - Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create - chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. - For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed - to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` + method used to create chunked arrays, via whichever chunk manager is + specified through the `chunked_array_type` kwarg. + For example, if :py:func:`dask.array.Array` objects are used for chunking, + additional kwargs will be passed to :py:func:`dask.array.from_array`. + Experimental API that should not be relied upon. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, - equivalent to `**kwargs`. + equivalent to `**kwargs`. Alternatively, pass a configured backend instance + (a ``BackendEntrypoint`` object) as ``engine``. **kwargs: dict Additional keyword arguments passed on to the engine open function. For example: @@ -729,7 +733,8 @@ currently active dask scheduler. Supported by "netcdf4", "h5netcdf", "scipy", "pynio". - See engine open function for kwargs accepted by each specific engine. + See the engine's open function for the kwargs accepted by each specific engine, + or create an instance of the backend and configure it via its constructor. Notes ----- @@ -790,7 +795,7 @@ def open_mfdataset( paths: str | NestedSequence[str | os.PathLike], - chunks: T_Chunks | None = None, + chunks: T_Chunks = None, concat_dim: str | DataArray | Index @@ -800,7 +805,7 @@ | None = None, compat: CompatOptions = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, - engine: T_Engine | None = None, + engine: T_Engine = None, data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords="different", combine: Literal["by_coords", "nested"] = "by_coords", @@ -868,7 +873,7 @@ ``ds.encoding["source"]``. engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ "zarr", None}, installed backend \ - or subclass of xarray.backends.BackendEntrypoint, optional + or instance or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4".
@@ -1092,7 +1097,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1111,7 +1116,7 @@ def to_netcdf( dataset: Dataset, path_or_file: None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1129,7 +1134,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1148,7 +1153,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1167,7 +1172,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1186,7 +1191,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1204,7 +1209,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1220,7 +1225,7 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -1633,14 +1638,14 @@ def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, *, compute: Literal[True] = True, consolidated: bool | None = None, - append_dim: Hashable | None = None, + append_dim: str | None = None, region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, @@ -1657,14 +1662,14 @@ def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping 
| None = None, *, compute: Literal[False], consolidated: bool | None = None, - append_dim: Hashable | None = None, + append_dim: str | None = None, region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, @@ -1679,14 +1684,14 @@ def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, *, compute: bool = True, consolidated: bool | None = None, - append_dim: Hashable | None = None, + append_dim: str | None = None, region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 5b8f9a6840f..04d3896d00a 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -6,7 +6,7 @@ import traceback from collections.abc import Iterable from glob import glob -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, ClassVar, TypeVar, overload import numpy as np @@ -14,13 +14,12 @@ from xarray.core import indexing from xarray.core.parallelcompat import get_chunked_array_type from xarray.core.pycompat import is_chunked_array +from xarray.core.types import T_BackendDatasetLike from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri if TYPE_CHECKING: - from io import BufferedIOBase - from xarray.core.dataset import Dataset - from xarray.core.types import NestedSequence + from xarray.core.types import NestedSequence, T_XarrayCanOpen # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -28,8 +27,20 @@ NONE_VAR_NAME = "__values__" +T = TypeVar("T") + + +@overload +def _normalize_path(path: os.PathLike) -> str: # type: ignore[overload-overlap] + ... + + +@overload +def _normalize_path(path: T) -> T: + ... -def _normalize_path(path): + +def _normalize_path(path: os.PathLike | T) -> str | T: """ Normalize pathlikes to string. @@ -52,9 +63,9 @@ def _normalize_path(path): path = os.fspath(path) if isinstance(path, str) and not is_remote_uri(path): - path = os.path.abspath(os.path.expanduser(path)) + return os.path.abspath(os.path.expanduser(path)) - return path + return path # type: ignore[return-value] def _find_absolute_paths( @@ -127,9 +138,9 @@ def _decode_variable_name(name): return name -def find_root_and_group(ds): +def find_root_and_group(ds: T_BackendDatasetLike) -> tuple[T_BackendDatasetLike, str]: """Find the root and group name of a netCDF4/h5netcdf dataset.""" - hierarchy = () + hierarchy: tuple[str, ...] = () while ds.parent is not None: hierarchy = (ds.name.split("/")[-1],) + hierarchy ds = ds.parent @@ -456,26 +467,27 @@ class BackendEntrypoint: ``drop_variables`` keyword argument. For more details see :ref:`RST open_dataset`. - ``guess_can_open`` method: it shall return ``True`` if the backend is able to open - ``filename_or_obj``, ``False`` otherwise. The implementation of this + ``filename_or_obj`` , ``False`` otherwise. The implementation of this method is not mandatory. Attributes ---------- - open_dataset_parameters : tuple, default: None - A list of ``open_dataset`` method parameters. - The setting of this attribute is not mandatory. 
description : str, default: "" A short string describing the engine. The setting of this attribute is not mandatory. url : str, default: "" A string with the URL to the backend's documentation. The setting of this attribute is not mandatory. + open_dataset_parameters : tuple of str or None, optional + A list of ``open_dataset`` method parameters. + The setting of this attribute is only mandatory if the + open_dataset method contains ``*args`` or ``**kwargs``. """ - open_dataset_parameters: ClassVar[tuple | None] = None description: ClassVar[str] = "" url: ClassVar[str] = "" + open_dataset_parameters: ClassVar[tuple[str, ...] | None] = None def __repr__(self) -> str: txt = f"<{type(self).__name__}>" @@ -487,10 +499,9 @@ def __repr__(self) -> str: def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, drop_variables: str | Iterable[str] | None = None, - **kwargs: Any, ) -> Dataset: """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. @@ -500,7 +511,7 @@ def open_dataset( def guess_can_open( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, ) -> bool: """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index df901f9a1d9..34d0f8f5d02 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -2,30 +2,30 @@ import atexit import contextlib -import io import threading import uuid import warnings -from collections.abc import Hashable -from typing import Any +from collections.abc import Hashable, Iterator, MutableMapping, Sequence +from typing import Any, Callable, Generic, Literal, Union, cast from xarray.backends.locks import acquire from xarray.backends.lru_cache import LRUCache from xarray.core import utils from xarray.core.options import OPTIONS +from xarray.core.types import FileLike, LockLike, T_FileLike, TypeAlias # Global cache for storing open files. -FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache( +FILE_CACHE: LRUCache[Hashable, FileLike] = LRUCache( maxsize=OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close() ) assert FILE_CACHE.maxsize, "file cache must be at least size one" -REF_COUNTS: dict[Any, int] = {} +REF_COUNTS: dict[Hashable, int] = {} _DEFAULT_MODE = utils.ReprObject("") -class FileManager: +class FileManager(Generic[T_FileLike]): """Manager for acquiring and closing a file object. Use FileManager subclasses (CachingFileManager in particular) on backend @@ -33,11 +33,12 @@ class FileManager: many open files and transferring them between multiple processes. """ - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_FileLike: """Acquire the file object from this manager.""" raise NotImplementedError() - def acquire_context(self, needs_lock=True): + @contextlib.contextmanager + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_FileLike]: """Context manager for acquiring a file. Yields a file object. 
The context manager unwinds any actions taken as part of acquisition @@ -46,12 +47,22 @@ def acquire_context(self, needs_lock=True): """ raise NotImplementedError() - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: """Close the file object associated with this manager, if needed.""" raise NotImplementedError() -class CachingFileManager(FileManager): +_CachingFileManagerState: TypeAlias = tuple[ + Callable[..., T_FileLike], + tuple[Any, ...], + Union[str, utils.ReprObject], + dict[str, Any], + Union[LockLike, None], + Hashable, +] + + +class CachingFileManager(FileManager, Generic[T_FileLike]): """Wrapper for automatically opening and closing file objects. Unlike files, CachingFileManager objects can be safely pickled and passed @@ -79,17 +90,28 @@ class CachingFileManager(FileManager): """ + _opener: Callable[..., T_FileLike] + _args: tuple[Any, ...] + _mode: str | utils.ReprObject + _kwargs: dict[str, Any] + _use_default_lock: bool + _lock: LockLike + _cache: MutableMapping[Hashable, FileLike] + _manager_id: Hashable + _key: Hashable + _ref_counter: _RefCounter + def __init__( self, - opener, - *args, - mode=_DEFAULT_MODE, - kwargs=None, - lock=None, - cache=None, + opener: Callable[..., T_FileLike], + *args: Any, + mode: str | utils.ReprObject = _DEFAULT_MODE, + kwargs: dict[str, Any] | None = None, + lock: Literal[False] | LockLike | None = None, + cache: MutableMapping[Hashable, FileLike] | None = None, manager_id: Hashable | None = None, - ref_counts=None, - ): + ref_counts: dict[Hashable, int] | None = None, + ) -> None: """Initialize a CachingFileManager. The cache, manager_id and ref_counts arguments exist solely to @@ -135,16 +157,12 @@ def __init__( self._kwargs = {} if kwargs is None else dict(kwargs) self._use_default_lock = lock is None or lock is False - self._lock = threading.Lock() if self._use_default_lock else lock + self._lock = threading.Lock() if lock is None or lock is False else lock # cache[self._key] stores the file associated with this object. - if cache is None: - cache = FILE_CACHE - self._cache = cache - if manager_id is None: - # Each call to CachingFileManager should separately open files. - manager_id = str(uuid.uuid4()) - self._manager_id = manager_id + self._cache = FILE_CACHE if cache is None else cache + # Each call to CachingFileManager should separately open files. + self._manager_id = str(uuid.uuid4()) if manager_id is None else manager_id self._key = self._make_key() # ref_counts[self._key] stores the number of CachingFileManager objects @@ -155,7 +173,7 @@ def __init__( self._ref_counter = _RefCounter(ref_counts) self._ref_counter.increment(self._key) - def _make_key(self): + def _make_key(self) -> Hashable: """Make a key for caching files in the LRU cache.""" value = ( self._opener, @@ -167,7 +185,7 @@ def _make_key(self): return _HashedSequence(value) @contextlib.contextmanager - def _optional_lock(self, needs_lock): + def _optional_lock(self, needs_lock: bool) -> Iterator[None]: """Context manager for optionally acquiring a lock.""" if needs_lock: with self._lock: @@ -175,7 +193,7 @@ def _optional_lock(self, needs_lock): else: yield - def acquire(self, needs_lock=True): + def acquire(self, needs_lock: bool = True) -> T_FileLike: """Acquire a file object from the manager. 
A new file is only opened if it has expired from the @@ -194,7 +212,7 @@ def acquire(self, needs_lock=True): return file @contextlib.contextmanager - def acquire_context(self, needs_lock=True): + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_FileLike]: """Context manager for acquiring a file.""" file, cached = self._acquire_with_cache_info(needs_lock) try: @@ -204,7 +222,9 @@ def acquire_context(self, needs_lock=True): self.close(needs_lock) raise - def _acquire_with_cache_info(self, needs_lock=True): + def _acquire_with_cache_info( + self, needs_lock: bool = True + ) -> tuple[T_FileLike, bool]: """Acquire a file, returning the file and whether it was cached.""" with self._optional_lock(needs_lock): try: @@ -221,9 +241,9 @@ def _acquire_with_cache_info(self, needs_lock=True): self._cache[self._key] = file return file, False else: - return file, True + return cast(T_FileLike, file), True - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: """Explicitly close any associated file object (if necessary).""" # TODO: remove needs_lock if/when we have a reentrant lock in # dask.distributed: https://github.com/dask/dask/issues/3832 @@ -259,7 +279,7 @@ def __del__(self) -> None: stacklevel=2, ) - def __getstate__(self): + def __getstate__(self) -> _CachingFileManagerState: """State for pickling.""" # cache is intentionally omitted: we don't want to try to serialize # these global objects. @@ -273,7 +293,7 @@ def __getstate__(self): self._manager_id, ) - def __setstate__(self, state) -> None: + def __setstate__(self, state: _CachingFileManagerState) -> None: """Restore from a pickle.""" opener, args, mode, kwargs, lock, manager_id = state self.__init__( # type: ignore @@ -300,16 +320,19 @@ def _remove_del_method(): class _RefCounter: """Class for keeping track of reference counts.""" - def __init__(self, counts): + _counts: dict[Hashable, int] + _lock: threading.Lock + + def __init__(self, counts: dict[Hashable, int]) -> None: self._counts = counts self._lock = threading.Lock() - def increment(self, name): + def increment(self, name: Hashable) -> int: with self._lock: count = self._counts[name] = self._counts.get(name, 0) + 1 return count - def decrement(self, name): + def decrement(self, name: Hashable) -> int: with self._lock: count = self._counts[name] - 1 if count: @@ -328,29 +351,33 @@ class _HashedSequence(list): https://bugs.python.org/issue1462796 """ - def __init__(self, tuple_value): + hashvalue: int + + def __init__(self, tuple_value: Sequence[Hashable]): self[:] = tuple_value self.hashvalue = hash(tuple_value) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return self.hashvalue -class DummyFileManager(FileManager): +class DummyFileManager(FileManager, Generic[T_FileLike]): """FileManager that simply wraps an open file in the FileManager interface.""" - def __init__(self, value): + _value: T_FileLike + + def __init__(self, value: T_FileLike) -> None: self._value = value - def acquire(self, needs_lock=True): + def acquire(self, needs_lock=True) -> T_FileLike: del needs_lock # ignored return self._value @contextlib.contextmanager - def acquire_context(self, needs_lock=True): + def acquire_context(self, needs_lock: bool = True) -> Iterator[T_FileLike]: del needs_lock yield self._value - def close(self, needs_lock=True): + def close(self, needs_lock: bool = True) -> None: del needs_lock # ignored self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 
d9385fc68a9..5d25f94f190 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -3,8 +3,9 @@ import functools import io import os -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal from xarray.backends.common import ( BACKEND_ENTRYPOINTS, @@ -27,6 +28,7 @@ from xarray.core import indexing from xarray.core.utils import ( FrozenDict, + hashable, is_remote_uri, read_magic_number_from_file, try_read_magic_number_from_file_or_path, @@ -34,10 +36,11 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - from io import BufferedIOBase + import h5netcdf - from xarray.backends.common import AbstractDataStore + from xarray.backends.file_manager import FileManager from xarray.core.dataset import Dataset + from xarray.core.types import H5netcdfOpenModes, LockLike, Self, T_XarrayCanOpen class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -92,17 +95,35 @@ class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf""" __slots__ = ( + "_manager", + "_group", + "_mode", + "_filename", "autoclose", "format", "is_remote", "lock", - "_filename", - "_group", - "_manager", - "_mode", ) - def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + _manager: FileManager[h5netcdf.File | h5netcdf.Group] + _group: str | None + _mode: H5netcdfOpenModes + _filename: str + autoclose: bool + format: None + is_remote: bool + lock: LockLike + + def __init__( + self, + manager: h5netcdf.File + | h5netcdf.Group + | FileManager[h5netcdf.File | h5netcdf.Group], + group: str | None = None, + mode: H5netcdfOpenModes = "r", + lock: Literal[False] | LockLike | None = HDF5_LOCK, + autoclose: bool = False, + ): import h5netcdf if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): @@ -131,18 +152,18 @@ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=Fal @classmethod def open( cls, - filename, - mode="r", - format=None, - group=None, - lock=None, - autoclose=False, - invalid_netcdf=None, - phony_dims=None, - decode_vlen_strings=True, - driver=None, - driver_kwds=None, - ): + filename: T_XarrayCanOpen, + mode: H5netcdfOpenModes = "r", + format: str | None = None, + group: str | None = None, + lock: Literal[False] | LockLike | None = None, + autoclose: bool = False, + invalid_netcdf: bool | None = None, + phony_dims: Literal["sort", "access", None] = None, + decode_vlen_strings: bool = True, + driver: str | None = None, + driver_kwds: Mapping[str, Any] | None = None, + ) -> Self: import h5netcdf if isinstance(filename, bytes): @@ -154,7 +175,7 @@ def open( magic_number = read_magic_number_from_file(filename) if not magic_number.startswith(b"\211HDF\r\n\032\n"): raise ValueError( - f"{magic_number} is not the signature of a valid netCDF4 file" + f"{magic_number.decode()} is not the signature of a valid netCDF4 file" ) if format not in [None, "NETCDF4"]: @@ -174,12 +195,13 @@ def open( if mode == "r": lock = HDF5_LOCK else: + assert hashable(filename) lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) - def _acquire(self, needs_lock=True): + def _acquire(self, needs_lock: bool = True) -> h5netcdf.File | h5netcdf.Group: with self._manager.acquire_context(needs_lock) as root: ds = _nc4_require_group( 
root, self._group, self._mode, create_group=_h5netcdf_create_group @@ -187,7 +209,7 @@ def _acquire(self, needs_lock=True): return ds @property - def ds(self): + def ds(self) -> h5netcdf.File | h5netcdf.Group: return self._acquire() def open_store_variable(self, name, var): @@ -335,6 +357,7 @@ def close(self, **kwargs): self._manager.close(**kwargs) +@dataclass(repr=False) class H5netcdfBackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the h5netcdf package. @@ -350,6 +373,55 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): For more information about the underlying library, visit: https://h5netcdf.org + Parameters + ---------- + group: str or None, optional + Path to the netCDF4 group in the given file to open. None (default) uses + the root group. + mode: {"w", "a", "r+", "r"}, default: "r" + Access mode of the NetCDF file. "r" means read-only; no data can be + modified. "w" means write; a new file is created, an existing file with + the same name is deleted. "a" and "r+" mean append; an existing file is + opened for reading and writing; if the file does not already exist, one is + created. + format: "NETCDF4" or None, optional + Format of the NetCDF file. Only "NETCDF4" is supported by h5netcdf. + lock: False, None or Lock-like, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. If None (default) appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. + autoclose: bool, default: False + If True, automatically close files to avoid OS Error of too many files + being open. However, this option doesn't work with streams, e.g., + BytesIO. + invalid_netcdf: bool or None, optional + Allow writing netCDF4 with data types and attributes that would + otherwise not generate netCDF4 files that can be read by other + applications. See https://h5netcdf.org/#invalid-netcdf-files for + more details. + phony_dims: {"sort", "access"} or None, optional + Change how variables with no dimension scales associated with + one of their axes are accessed. + + - None: raises a ValueError (default) + + - "sort": invent phony dimensions according to netCDF behaviour. + Note that this iterates once over the whole group hierarchy. + This affects performance if you rely on lazy + group access. + + - "access": defer phony dimension creation to group access time. + The created phony dimension naming will differ from netCDF behaviour. + + decode_vlen_strings: bool, default: True + Return vlen string data as str instead of bytes. + driver: str or None, optional + Name of the driver to use. Legal values are None (default, + recommended), "core", "sec2", "direct", "stdio", "mpio", "ros3". + driver_kwds: Mapping or None, optional + Additional driver options. See h5py.File for more details.
+ See Also -------- backends.H5NetCDFStore @@ -361,11 +433,28 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray" ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html" + open_dataset_parameters = ( + "drop_variables", + "mask_and_scale", + "decode_times", + "concat_characters", + "use_cftime", + "decode_timedelta", + "decode_coords", + ) - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - ) -> bool: + group: str | None = None + mode: H5netcdfOpenModes = "r" + format: str | None = "NETCDF4" + lock: Literal[False] | LockLike | None = None + autoclose: bool = False + invalid_netcdf: bool | None = None + phony_dims: Literal["sort", "access", None] = None + decode_vlen_strings: bool = True + driver: str | None = None + driver_kwds: Mapping[str, Any] | None = None + + def guess_can_open(self, filename_or_obj: T_XarrayCanOpen) -> bool: magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None: return magic_number.startswith(b"\211HDF\r\n\032\n") @@ -376,41 +465,39 @@ def guess_can_open( return False - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, - format=None, - group=None, - lock=None, - invalid_netcdf=None, - phony_dims=None, - decode_vlen_strings=True, - driver=None, - driver_kwds=None, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, + decode_coords: bool | Literal["coordinates", "all"] = True, + **kwargs: Any, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = H5NetCDFStore.open( filename_or_obj, - format=format, - group=group, - lock=lock, - invalid_netcdf=invalid_netcdf, - phony_dims=phony_dims, - decode_vlen_strings=decode_vlen_strings, - driver=driver, - driver_kwds=driver_kwds, + mode=kwargs.pop("mode", self.mode), + format=kwargs.pop("format", self.format), + group=kwargs.pop("group", self.group), + lock=kwargs.pop("lock", self.lock), + autoclose=kwargs.pop("autoclose", self.autoclose), + invalid_netcdf=kwargs.pop("invalid_netcdf", self.invalid_netcdf), + phony_dims=kwargs.pop("phony_dims", self.phony_dims), + decode_vlen_strings=kwargs.pop( + "decode_vlen_strings", self.decode_vlen_strings + ), + driver=kwargs.pop("driver", self.driver), + driver_kwds=kwargs.pop("driver_kwds", self.driver_kwds), ) + if kwargs: + raise ValueError(f"Unsupported kwargs: {kwargs.values()}") store_entrypoint = StoreBackendEntrypoint() - ds = store_entrypoint.open_dataset( store, mask_and_scale=mask_and_scale, diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index 045ee522fa8..2265373bc85 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -4,9 +4,15 @@ import threading import uuid import weakref -from collections.abc import Hashable, MutableMapping -from typing import Any, ClassVar -from weakref import WeakValueDictionary +from collections.abc import Hashable, Iterable +from typing import TYPE_CHECKING, Callable, ClassVar, Literal, 
overload + +if TYPE_CHECKING: + from xarray.core.types import LockLike, T_LockLike, TypeAlias + + SchedulerOptions: TypeAlias = Literal[ + "threaded", "multiprocessing", "distributed", None + ] # SerializableLock is adapted from Dask: @@ -41,12 +47,15 @@ class SerializableLock: """ _locks: ClassVar[ - WeakValueDictionary[Hashable, threading.Lock] - ] = WeakValueDictionary() + weakref.WeakValueDictionary[Hashable, threading.Lock] + ] = weakref.WeakValueDictionary() token: Hashable lock: threading.Lock - def __init__(self, token: Hashable | None = None): + def __init__(self, token: Hashable = None): + self._set_token_and_lock(token) + + def _set_token_and_lock(self, token: Hashable) -> None: self.token = token or str(uuid.uuid4()) if self.token in SerializableLock._locks: self.lock = SerializableLock._locks[self.token] @@ -54,31 +63,32 @@ def __init__(self, token: Hashable | None = None): self.lock = threading.Lock() SerializableLock._locks[self.token] = self.lock - def acquire(self, *args, **kwargs): + def acquire(self, *args, **kwargs) -> bool: return self.lock.acquire(*args, **kwargs) - def release(self, *args, **kwargs): - return self.lock.release(*args, **kwargs) + def release(self, *args, **kwargs) -> None: + self.lock.release(*args, **kwargs) - def __enter__(self): - self.lock.__enter__() + def __enter__(self) -> bool: + return self.lock.__enter__() - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.lock.__exit__(*args) - def locked(self): + def locked(self) -> bool: return self.lock.locked() - def __getstate__(self): + def __getstate__(self) -> Hashable: return self.token - def __setstate__(self, token): - self.__init__(token) + def __setstate__(self, token: Hashable) -> None: + self._set_token_and_lock(token) - def __str__(self): - return f"<{self.__class__.__name__}: {self.token}>" + def __str__(self) -> str: + return f"<{type(self).__name__}: {self.token}>" - __repr__ = __str__ + def __repr__(self) -> str: + return f"{type(self).__name__}({self.token!r})" # Locks used by multiple backends. @@ -87,10 +97,12 @@ def __str__(self): NETCDFC_LOCK = SerializableLock() -_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary() +_FILE_LOCKS: weakref.WeakValueDictionary[ + Hashable, threading.Lock +] = weakref.WeakValueDictionary() -def _get_threaded_lock(key): +def _get_threaded_lock(key: Hashable) -> threading.Lock: try: lock = _FILE_LOCKS[key] except KeyError: @@ -98,14 +110,17 @@ def _get_threaded_lock(key): return lock -def _get_multiprocessing_lock(key): +def _get_multiprocessing_lock(key: Hashable) -> LockLike: # TODO: make use of the key -- maybe use locket.py? # https://github.com/mwilliamson/locket.py del key # unused - return multiprocessing.Lock() + # multiprocessing.Lock is missing the "locked" method??? + return multiprocessing.Lock() # type: ignore[return-value] -def _get_lock_maker(scheduler=None): +def _get_lock_maker( + scheduler: SchedulerOptions = None, +) -> Callable[[Hashable], LockLike] | None: """Returns an appropriate function for creating resource locks. 
Parameters @@ -120,23 +135,23 @@ def _get_lock_maker(scheduler=None): if scheduler is None: return _get_threaded_lock - elif scheduler == "threaded": + if scheduler == "threaded": return _get_threaded_lock - elif scheduler == "multiprocessing": + if scheduler == "multiprocessing": return _get_multiprocessing_lock - elif scheduler == "distributed": + if scheduler == "distributed": # Lazy import distributed since it is can add a significant # amount of time to import try: from dask.distributed import Lock as DistributedLock + + return DistributedLock except ImportError: - DistributedLock = None - return DistributedLock - else: - raise KeyError(scheduler) + return None + raise KeyError(scheduler) -def _get_scheduler(get=None, collection=None) -> str | None: +def _get_scheduler(get=None, collection=None) -> SchedulerOptions: """Determine the dask scheduler that is being used. None is returned if no dask scheduler is active. @@ -174,12 +189,12 @@ def _get_scheduler(get=None, collection=None) -> str | None: return "threaded" -def get_write_lock(key): +def get_write_lock(key: Hashable) -> LockLike: """Get a scheduler appropriate lock for writing to the given resource. Parameters ---------- - key : str + key : hashable Name of the resource for which to acquire a lock. Typically a filename. Returns @@ -188,10 +203,11 @@ def get_write_lock(key): """ scheduler = _get_scheduler() lock_maker = _get_lock_maker(scheduler) + assert lock_maker is not None return lock_maker(key) -def acquire(lock, blocking=True): +def acquire(lock: LockLike, blocking: bool = True) -> bool: """Acquire a lock, possibly in a non-blocking fashion. Includes backwards compatibility hacks for old versions of Python, dask @@ -216,29 +232,30 @@ class CombinedLock: locks are locked. """ - def __init__(self, locks): + locks: tuple[LockLike, ...] + + def __init__(self, locks: Iterable[LockLike]): self.locks = tuple(set(locks)) # remove duplicates - def acquire(self, blocking=True): + def acquire(self, blocking: bool = True) -> bool: return all(acquire(lock, blocking=blocking) for lock in self.locks) - def release(self): + def release(self) -> None: for lock in self.locks: lock.release() - def __enter__(self): - for lock in self.locks: - lock.__enter__() + def __enter__(self) -> bool: + return all(lock.__enter__() for lock in self.locks) - def __exit__(self, *args): + def __exit__(self, *args) -> None: for lock in self.locks: lock.__exit__(*args) - def locked(self): + def locked(self) -> bool: return any(lock.locked for lock in self.locks) - def __repr__(self): - return f"CombinedLock({list(self.locks)!r})" + def __repr__(self) -> str: + return f"{type(self).__name__}({list(self.locks)!r})" class DummyLock: @@ -260,9 +277,9 @@ def locked(self): return False -def combine_locks(locks): +def combine_locks(locks: Iterable[LockLike]) -> LockLike: """Combine a sequence of locks into a single lock.""" - all_locks = [] + all_locks: list[LockLike] = [] for lock in locks: if isinstance(lock, CombinedLock): all_locks.extend(lock.locks) @@ -272,13 +289,22 @@ def combine_locks(locks): num_locks = len(all_locks) if num_locks > 1: return CombinedLock(all_locks) - elif num_locks == 1: + if num_locks == 1: return all_locks[0] - else: - return DummyLock() + return DummyLock() + + +@overload +def ensure_lock(lock: Literal[False] | None) -> DummyLock: + ... + + +@overload +def ensure_lock(lock: T_LockLike) -> T_LockLike: + ... 
-def ensure_lock(lock): +def ensure_lock(lock: Literal[False] | T_LockLike | None) -> T_LockLike | DummyLock: """Ensure that the given object is a lock.""" if lock is None or lock is False: return DummyLock() diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index cf753828242..e9aa044746e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,7 +5,8 @@ import os from collections.abc import Iterable from contextlib import suppress -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np @@ -31,6 +32,7 @@ from xarray.backends.store import StoreBackendEntrypoint from xarray.coding.variables import pop_to from xarray.core import indexing +from xarray.core.types import Self from xarray.core.utils import ( FrozenDict, close_on_error, @@ -40,10 +42,16 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - from io import BufferedIOBase + import netCDF4 - from xarray.backends.common import AbstractDataStore + from xarray.backends.file_manager import FileManager from xarray.core.dataset import Dataset + from xarray.core.types import ( + LockLike, + NetcdfFormats, + NetCDFOpenModes, + T_XarrayCanOpen, + ) # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. @@ -168,11 +176,20 @@ def _nc4_dtype(var): return dtype -def _netcdf4_create_group(dataset, name): +def _netcdf4_create_group( + dataset: netCDF4.Dataset | netCDF4.Group, name: str +) -> netCDF4.Group: return dataset.createGroup(name) -def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): +def _nc4_require_group( + ds: netCDF4.Dataset, + group: str | None, + mode: str | None, + create_group: Callable[ + [netCDF4.Dataset | netCDF4.Group, str], netCDF4.Group + ] = _netcdf4_create_group, +) -> netCDF4.Dataset | netCDF4.Group: if group in {None, "", "/"}: # use the root group return ds @@ -320,18 +337,32 @@ class NetCDF4DataStore(WritableCFDataStore): """ __slots__ = ( - "autoclose", + "_manager", + "_group", + "_mode", + "_filename", "format", "is_remote", "lock", - "_filename", - "_group", - "_manager", - "_mode", + "autoclose", ) + _manager: FileManager[netCDF4.Dataset] + _group: str | None + _mode: NetCDFOpenModes + _filename: str + format: NetcdfFormats + is_remote: bool + lock: Literal[False] | LockLike | None + autoclose: bool + def __init__( - self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False + self, + manager: netCDF4.Dataset | FileManager[netCDF4.Dataset], + group: str | None = None, + mode: NetCDFOpenModes = "r", + lock: Literal[False] | LockLike | None = NETCDF4_PYTHON_LOCK, + autoclose: bool = False, ): import netCDF4 @@ -359,17 +390,17 @@ def __init__( @classmethod def open( cls, - filename, - mode="r", - format="NETCDF4", - group=None, - clobber=True, - diskless=False, - persist=False, - lock=None, + filename: T_XarrayCanOpen, + mode: NetCDFOpenModes = "r", + format: NetcdfFormats | None = "NETCDF4", + group: str | None = None, + clobber: bool = True, + diskless: bool = False, + persist: bool = False, + lock: Literal[False] | LockLike | None = None, lock_maker=None, - autoclose=False, - ): + autoclose: bool = False, + ) -> Self: import netCDF4 if isinstance(filename, os.PathLike): @@ -405,13 +436,13 @@ def open( ) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) - def _acquire(self, needs_lock=True): + def _acquire(self, needs_lock: bool = True) -> netCDF4.Dataset | 
netCDF4.Group: with self._manager.acquire_context(needs_lock) as root: ds = _nc4_require_group(root, self._group, self._mode) return ds @property - def ds(self): + def ds(self) -> netCDF4.Dataset | netCDF4.Group: return self._acquire() def open_store_variable(self, name, var): @@ -534,21 +565,57 @@ def close(self, **kwargs): self._manager.close(**kwargs) +@dataclass(repr=False) class NetCDF4BackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the netCDF4 package. - It can open ".nc", ".nc4", ".cdf" files and will be chosen - as default for these files. + It can open ".nc", ".nc4", ".cdf" files and will be chosen as default for + these files. Additionally it can open valid HDF5 files, see - https://h5netcdf.org/#invalid-netcdf-files for more info. - It will not be detected as valid backend for such files, so make - sure to specify ``engine="netcdf4"`` in ``open_dataset``. + https://h5netcdf.org/#invalid-netcdf-files for more info. It will not be + detected as valid backend for such files, so make sure to specify + ``engine="netcdf4"`` in ``open_dataset``. For more information about the underlying library, visit: https://unidata.github.io/netcdf4-python + Parameters + ---------- + group: str or None, optional + Path to the netCDF4 group in the given file to open. None (default) uses + the root group. + mode: {"w", "x", "a", "r+", "r"}, default: "r" + Access mode of the NetCDF file. "r" means read-only; no data can be + modified. "w" means write; a new file is created, an existing file with + the same name is deleted. "x" means write, but fail if an existing file + with the same name already exists. "a" and "r+" mean append; an existing + file is opened for reading and writing; if the file does not already exist, + one is created. + format: {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_64BIT_OFFSET", "NETCDF3_64BIT_DATA", "NETCDF3_CLASSIC"} \ + or None, optional + Format of the NetCDF file, defaults to "NETCDF4". + lock: False, None or Lock-like, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. If None (default) appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. + autoclose: bool, default: False + If True, automatically close files to avoid OS Error of too many files + being open. However, this option doesn't work with streams, e.g., + BytesIO. + clobber: bool, default: True + If True, opening a file with mode="w" will clobber an existing file with + the same name. If False, an exception will be raised if a file with the + same name already exists. mode="x" is identical to mode="w" with + clobber=False. + diskless: bool, default: False + If True, create diskless (in-core) file. + persist: bool, default: False + If True, persist file to disk when closed.
+ See Also -------- backends.NetCDF4DataStore @@ -560,11 +627,26 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray" ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html" + open_dataset_parameters = ( + "drop_variables", + "mask_and_scale", + "decode_times", + "concat_characters", + "use_cftime", + "decode_timedelta", + "decode_coords", + ) - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - ) -> bool: + group: str | None = None + mode: NetCDFOpenModes = "r" + format: NetcdfFormats | None = "NETCDF4" + lock: Literal[False] | LockLike | None = None + autoclose: bool = False + clobber: bool = True + diskless: bool = False + persist: bool = False + + def guess_can_open(self, filename_or_obj: T_XarrayCanOpen) -> bool: if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True magic_number = try_read_magic_number_from_path(filename_or_obj) @@ -578,38 +660,33 @@ def guess_can_open( return False - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - format="NETCDF4", - clobber=True, - diskless=False, - persist=False, - lock=None, - autoclose=False, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, + decode_coords: bool | Literal["coordinates", "all"] = True, + **kwargs: Any, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( filename_or_obj, - mode=mode, - format=format, - group=group, - clobber=clobber, - diskless=diskless, - persist=persist, - lock=lock, - autoclose=autoclose, + mode=kwargs.pop("mode", self.mode), + format=kwargs.pop("format", self.format), + group=kwargs.pop("group", self.group), + clobber=kwargs.pop("clobber", self.clobber), + diskless=kwargs.pop("diskless", self.diskless), + persist=kwargs.pop("persist", self.persist), + lock=kwargs.pop("lock", self.lock), + autoclose=kwargs.pop("autoclose", self.autoclose), ) + if kwargs: + raise ValueError(f"Unsupported kwargs: {kwargs.values()}") store_entrypoint = StoreBackendEntrypoint() with close_on_error(store): diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index a62ca6c9862..16810e0cd41 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -6,22 +6,20 @@ import sys import warnings from importlib.metadata import entry_points -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint from xarray.core.utils import module_available if TYPE_CHECKING: - import os from importlib.metadata import EntryPoint if sys.version_info >= (3, 10): from importlib.metadata import EntryPoints else: EntryPoints = list[EntryPoint] - from io import BufferedIOBase - from xarray.backends.common import AbstractDataStore + from xarray.core.types import T_XarrayCanOpen STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"] @@ -129,9 +127,8 @@ def 
list_engines() -> dict[str, BackendEntrypoint]: ----- This function lives in the backends namespace (``engs=xr.backends.list_engines()``). If available, more information is available about each backend via ``engs["eng_name"]``. - - # New selection mechanism introduced with Python 3.10. See GH6514. """ + # New selection mechanism introduced with Python 3.10. See GH6514. if sys.version_info >= (3, 10): entrypoints = entry_points(group="xarray.backends") else: @@ -144,9 +141,7 @@ def refresh_engines() -> None: list_engines.cache_clear() -def guess_engine( - store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, -) -> str | type[BackendEntrypoint]: +def guess_engine(store_spec: T_XarrayCanOpen) -> str | type[BackendEntrypoint]: engines = list_engines() for engine, backend in engines.items(): @@ -197,21 +192,23 @@ def guess_engine( raise ValueError(error_msg) -def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint: +def get_backend( + engine: str | BackendEntrypoint | type[BackendEntrypoint], +) -> BackendEntrypoint: """Select open_dataset method based on current engine.""" + if isinstance(engine, BackendEntrypoint): + return engine if isinstance(engine, str): engines = list_engines() if engine not in engines: raise ValueError( f"unrecognized engine {engine} must be one of: {list(engines)}" ) - backend = engines[engine] - elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint): - backend = engine() - else: - raise TypeError( - "engine must be a string or a subclass of " - f"xarray.backends.BackendEntrypoint: {engine}" - ) + return engines[engine] + if isinstance(engine, type) and issubclass(engine, BackendEntrypoint): + return engine() - return backend + raise TypeError( + "engine must be a string, a subclass of xarray.backends.BackendEntrypoint" + f" or an object of such a subclass, got {type(engine)}" + ) diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 9b5bcc82e6f..c2805f92a5a 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,7 +1,8 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -25,10 +26,11 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - import os - from io import BufferedIOBase + import pydap.model + import requests from xarray.core.dataset import Dataset + from xarray.core.types import Self, T_XarrayCanOpen class PydapArrayWrapper(BackendArray): @@ -88,7 +90,9 @@ class PydapDataStore(AbstractDataStore): be useful if the netCDF4 library is not available. 
""" - def __init__(self, ds): + ds: pydap.model.DatasetType + + def __init__(self, ds: pydap.model.DatasetType) -> None: """ Parameters ---------- @@ -99,14 +103,14 @@ def __init__(self, ds): @classmethod def open( cls, - url, - application=None, - session=None, - output_grid=None, - timeout=None, - verify=None, - user_charset=None, - ): + url: str, + application: Any = None, + session: requests.Session | None = None, + output_grid: bool | None = None, + timeout: float | None = None, + verify: bool | None = None, + user_charset: str | None = None, + ) -> Self: import pydap.client import pydap.lib @@ -145,6 +149,7 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) +@dataclass(repr=False) class PydapBackendEntrypoint(BackendEntrypoint): """ Backend for steaming datasets over the internet using @@ -156,6 +161,16 @@ class PydapBackendEntrypoint(BackendEntrypoint): For more information about the underlying library, visit: https://www.pydap.org + Parameters + ---------- + application: + session: + output_grid: + timeout: float or None, default: 120 + Timeout in seconds. + verify: + user_charset: + See Also -------- backends.PydapDataStore @@ -163,40 +178,54 @@ class PydapBackendEntrypoint(BackendEntrypoint): description = "Open remote datasets via OPeNDAP using pydap in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html" - - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - ) -> bool: + open_dataset_parameters = ( + "drop_variables", + "mask_and_scale", + "decode_times", + "concat_characters", + "use_cftime", + "decode_timedelta", + "decode_coords", + ) + + application: Any = None + session: requests.Session | None = None + output_grid: bool | None = None + timeout: float | None = None + verify: bool | None = None + user_charset: str | None = None + + def guess_can_open(self, filename_or_obj: T_XarrayCanOpen) -> bool: return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj) - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, - application=None, - session=None, - output_grid=None, - timeout=None, - verify=None, - user_charset=None, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, + decode_coords: bool | Literal["coordinates", "all"] = True, + **kwargs: Any, ) -> Dataset: + if not isinstance(filename_or_obj, str): + raise ValueError( + f"'filename_or_obj' must be a str (url), got {type(filename_or_obj)}." 
+ ) store = PydapDataStore.open( url=filename_or_obj, - application=application, - session=session, - output_grid=output_grid, - timeout=timeout, - verify=verify, - user_charset=user_charset, + application=kwargs.pop("application", self.application), + session=kwargs.pop("session", self.session), + output_grid=kwargs.pop("output_grid", self.output_grid), + timeout=kwargs.pop("timeout", self.timeout), + verify=kwargs.pop("verify", self.verify), + user_charset=kwargs.pop("user_charset", self.user_charset), ) + if kwargs: + raise ValueError(f"Unsupported kwargs: {kwargs.values()}") store_entrypoint = StoreBackendEntrypoint() with close_on_error(store): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 75e96ffdc0a..0c50596c397 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import numpy as np @@ -27,10 +27,8 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - import os - from io import BufferedIOBase - from xarray.core.dataset import Dataset + from xarray.core.types import T_XarrayCanOpen # PyNIO can invoke netCDF libraries internally # Add a dedicated lock just in case NCL as well isn't thread-safe. @@ -125,9 +123,9 @@ class PynioBackendEntrypoint(BackendEntrypoint): https://github.com/pydata/xarray/issues/4491 for more information """ - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, mask_and_scale=True, decode_times=True, diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 1ecc70cf376..c3097f9d052 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -4,7 +4,8 @@ import io import os from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -33,10 +34,16 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - from io import BufferedIOBase + import scipy.io - from xarray.backends.common import AbstractDataStore + from xarray.backends import FileManager from xarray.core.dataset import Dataset + from xarray.core.types import ( + LockLike, + ScipyFormats, + ScipyOpenModes, + T_XarrayCanOpen, + ) def _decode_string(s): @@ -84,7 +91,7 @@ def __setitem__(self, key, value): raise -def _open_scipy_netcdf(filename, mode, mmap, version): +def _open_scipy_netcdf(filename, mode, mmap, version) -> scipy.io.netcdf_file: import scipy.io # if the string ends with .gz, then gunzip and open as netcdf file @@ -131,13 +138,22 @@ class ScipyDataStore(WritableCFDataStore): It only supports the NetCDF3 file-format. 
""" + _manager: FileManager[scipy.io.netcdf_file] + lock: LockLike + def __init__( - self, filename_or_obj, mode="r", format=None, group=None, mmap=None, lock=None - ): + self, + filename_or_obj: T_XarrayCanOpen, + mode: ScipyOpenModes = "r", + format: ScipyFormats = None, + group: str | None = None, + mmap: bool | None = None, + lock: Literal[False] | LockLike | None = None, + ) -> None: if group is not None: raise ValueError("cannot save to a group with the scipy.io.netcdf backend") - if format is None or format == "NETCDF3_64BIT": + if format in (None, "NETCDF3_64BIT", "NETCDF3_64BIT_OFFSET"): version = 2 elif format == "NETCDF3_CLASSIC": version = 1 @@ -150,7 +166,7 @@ def __init__( self.lock = ensure_lock(lock) if isinstance(filename_or_obj, str): - manager = CachingFileManager( + self._manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, mode=mode, @@ -161,12 +177,10 @@ def __init__( scipy_dataset = _open_scipy_netcdf( filename_or_obj, mode=mode, mmap=mmap, version=version ) - manager = DummyFileManager(scipy_dataset) - - self._manager = manager + self._manager = DummyFileManager(scipy_dataset) @property - def ds(self): + def ds(self) -> scipy.io.netcdf_file: return self._manager.acquire() def open_store_variable(self, name, var): @@ -247,34 +261,72 @@ def close(self): self._manager.close() +@dataclass(repr=False) class ScipyBackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the scipy package. - It can open ".nc", ".nc4", ".cdf" and ".gz" files but will only be - selected as the default if the "netcdf4" and "h5netcdf" engines are - not available. It has the advantage that is is a lightweight engine - that has no system requirements (unlike netcdf4 and h5netcdf). + It can open ".nc", ".nc4", ".cdf" and ".gz" files but will only be selected + as the default if the "netcdf4" and "h5netcdf" engines are not available. It + has the advantage that is is a lightweight engine that has no system + requirements (unlike netcdf4 and h5netcdf). Additionally it can open gizp compressed (".gz") files. For more information about the underlying library, visit: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.netcdf_file.html + Parameters + ---------- + group: str or None, optional + Path to the netCDF4 group in the given file to open. None (default) uses + the root group. + mode: {"w", "a", "r"}, default: "r" + Access mode of the NetCDF file. "r" means read-only; no data can be + modified. "w" means write; a new file is created, an existing file with + the same name is deleted. "a" means append; an existing file is opened + for reading and writing, if file does not exist already, one is created. + format: {"NETCDF3_64BIT", "NETCDF3_64BIT_OFFSET", "NETCDF3_CLASSIC"} or \ + None, optional + Format of the NetCDF file. Only classic NetCDF files supported. For newer + NetCDF version use a different backend. + lock: False, None or Lock-like, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. If None (default) appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. + mmap: bool or None, optional + Whether to mmap filename when reading. Default is True when filename is + a file name, False when filename is a file-like object. Note that when + mmap is in use, data arrays returned refer directly to the mmapped data + on disk, and the file cannot be closed as long as references to it + exist. 
+
     See Also
     --------
     backends.ScipyDataStore
     backends.NetCDF4BackendEntrypoint
     backends.H5netcdfBackendEntrypoint
     """

     description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html"
-
-    def guess_can_open(
-        self,
-        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
-    ) -> bool:
+    open_dataset_parameters = (
+        "drop_variables",
+        "mask_and_scale",
+        "decode_times",
+        "concat_characters",
+        "use_cftime",
+        "decode_timedelta",
+        "decode_coords",
+    )
+
+    group: str | None = None
+    mode: ScipyOpenModes = "r"
+    format: ScipyFormats = None
+    lock: Literal[False] | LockLike | None = None
+    mmap: bool | None = None
+
+    def guess_can_open(self, filename_or_obj: T_XarrayCanOpen) -> bool:
         magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
         if magic_number is not None and magic_number.startswith(b"\x1f\x8b"):
             with gzip.open(filename_or_obj) as f:  # type: ignore[arg-type]
@@ -288,27 +340,30 @@ def guess_can_open(

         return False

-    def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporting **kwargs
+    def open_dataset(
         self,
-        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        filename_or_obj: T_XarrayCanOpen,
         *,
-        mask_and_scale=True,
-        decode_times=True,
-        concat_characters=True,
-        decode_coords=True,
         drop_variables: str | Iterable[str] | None = None,
-        use_cftime=None,
-        decode_timedelta=None,
-        mode="r",
-        format=None,
-        group=None,
-        mmap=None,
-        lock=None,
+        mask_and_scale: bool = True,
+        decode_times: bool = True,
+        concat_characters: bool = True,
+        use_cftime: bool | None = None,
+        decode_timedelta: bool | None = None,
+        decode_coords: bool | Literal["coordinates", "all"] = True,
+        **kwargs: Any,
     ) -> Dataset:
         filename_or_obj = _normalize_path(filename_or_obj)
         store = ScipyDataStore(
-            filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
+            filename_or_obj,
+            mode=kwargs.pop("mode", self.mode),
+            format=kwargs.pop("format", self.format),
+            group=kwargs.pop("group", self.group),
+            mmap=kwargs.pop("mmap", self.mmap),
+            lock=kwargs.pop("lock", self.lock),
         )
+        if kwargs:
+            raise ValueError(f"Unsupported kwargs: {list(kwargs)}")

         store_entrypoint = StoreBackendEntrypoint()
         with close_on_error(store):
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index a507ee37470..15ce0f74489 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Literal

 from xarray import conventions
 from xarray.backends.common import (
@@ -12,31 +12,36 @@
 from xarray.core.dataset import Dataset

 if TYPE_CHECKING:
-    import os
-    from io import BufferedIOBase
+    from xarray.core.types import T_XarrayCanOpen


 class StoreBackendEntrypoint(BackendEntrypoint):
     description = "Open AbstractDataStore instances in Xarray"
     url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html"
+    open_dataset_parameters = (
+        "drop_variables",
+        "mask_and_scale",
+        "decode_times",
+        "concat_characters",
+        "use_cftime",
+        "decode_timedelta",
+        "decode_coords",
+    )

-    def guess_can_open(
-        self,
-        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
-    ) -> bool:
+    def guess_can_open(self, filename_or_obj:
T_XarrayCanOpen) -> bool: return isinstance(filename_or_obj, AbstractDataStore) - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen, *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, + decode_coords: bool | Literal["coordinates", "all"] = True, ) -> Dataset: assert isinstance(filename_or_obj, AbstractDataStore) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 469bbf4c339..a412a97aba5 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -3,8 +3,9 @@ import json import os import warnings -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Mapping, MutableMapping +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -21,7 +22,7 @@ from xarray.core import indexing from xarray.core.parallelcompat import guess_chunkmanager from xarray.core.pycompat import integer_types -from xarray.core.types import ZarrWriteModes +from xarray.core.types import ZarrOpenModes from xarray.core.utils import ( FrozenDict, HiddenKeyDict, @@ -30,10 +31,10 @@ from xarray.core.variable import Variable if TYPE_CHECKING: - from io import BufferedIOBase + import zarr - from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset + from xarray.core.types import Self, T_Chunks, T_XarrayCanOpen # need some special secret attributes to tell us the dimensions @@ -319,7 +320,11 @@ def encode_zarr_variable(var, needs_copy=True, name=None): def _validate_and_transpose_existing_dims( - var_name, new_var, existing_var, region, append_dim + var_name, + new_var, + existing_var, + region: Mapping[str, slice] | None, + append_dim: str, ): if new_var.dims != existing_var.dims: if set(existing_var.dims) == set(new_var.dims): @@ -381,25 +386,38 @@ class ZarrStore(AbstractWritableDataStore): "_write_empty", "_close_store_on_close", ) + zarr_group: zarr.Group + _append_dim: str | None + _consolidate_on_close: bool + _group: str + _mode: ZarrOpenModes + _read_only: bool + _synchronizer: object | None + _write_region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None + _safe_chunks: bool + _write_empty: bool | None + _close_store_on_close: bool @classmethod def open_group( cls, - store, - mode: ZarrWriteModes = "r", - synchronizer=None, - group=None, - consolidated=False, - consolidate_on_close=False, - chunk_store=None, - storage_options=None, - append_dim=None, - write_region=None, - safe_chunks=True, - stacklevel=2, - zarr_version=None, + store: T_XarrayCanOpen | MutableMapping | None, + mode: ZarrOpenModes = "r", + synchronizer: object | None = None, + group: str | None = None, + consolidated: bool | None = False, + consolidate_on_close: bool = False, + chunk_store: MutableMapping | str | os.PathLike | None = None, + storage_options: Mapping[str, Any] | None = None, + append_dim: str | None = None, + write_region: Mapping[str, slice | Literal["auto"]] + | Literal["auto"] + | None = None, + safe_chunks: bool = True, + stacklevel: int = 2, + zarr_version: int | None = 
None, write_empty: bool | None = None, - ): + ) -> Self: import zarr # zarr doesn't support pathlib.Path objects yet. zarr-python#601 @@ -458,7 +476,7 @@ def open_group( stacklevel=stacklevel, ) except zarr.errors.GroupNotFoundError: - raise FileNotFoundError(f"No such file or directory: '{store}'") + raise FileNotFoundError(f"No such file or directory: '{store!r}'") elif consolidated: # TODO: an option to pass the metadata_key keyword zarr_group = zarr.open_consolidated(store, **open_kwargs) @@ -478,15 +496,17 @@ def open_group( def __init__( self, - zarr_group, - mode=None, - consolidate_on_close=False, - append_dim=None, - write_region=None, - safe_chunks=True, + zarr_group: zarr.Group, + mode: ZarrOpenModes = "r", + consolidate_on_close: bool = False, + append_dim: str | None = None, + write_region: Mapping[str, slice | Literal["auto"]] + | Literal["auto"] + | None = None, + safe_chunks: bool = True, write_empty: bool | None = None, close_store_on_close: bool = False, - ): + ) -> None: self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only self._synchronizer = self.zarr_group.synchronizer @@ -500,7 +520,7 @@ def __init__( self._close_store_on_close = close_store_on_close @property - def ds(self): + def ds(self) -> zarr.Group: # TODO: consider deprecating this in favor of zarr_group return self.zarr_group @@ -785,26 +805,26 @@ def close(self): def open_zarr( store, - group=None, - synchronizer=None, - chunks="auto", - decode_cf=True, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, - drop_variables=None, - consolidated=None, - overwrite_encoded_chunks=False, - chunk_store=None, - storage_options=None, - decode_timedelta=None, - use_cftime=None, - zarr_version=None, + group: str | None = None, + synchronizer: object | None = None, + chunks: T_Chunks = "auto", + decode_cf: bool = True, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, + drop_variables: str | Iterable[str] | None = None, + consolidated: bool | None = None, + overwrite_encoded_chunks: bool = False, + chunk_store: MutableMapping | str | os.PathLike | None = None, + storage_options: Mapping[str, Any] | None = None, + decode_timedelta: bool | None = None, + use_cftime: bool | None = None, + zarr_version: int | None = None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, - **kwargs, -): + **kwargs: Any, +) -> Dataset: """Load and decode a dataset from a Zarr store. The `store` object should be a valid store for a Zarr group. 
`store` @@ -931,15 +951,15 @@ def open_zarr( "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys()) ) - backend_kwargs = { - "synchronizer": synchronizer, - "consolidated": consolidated, - "overwrite_encoded_chunks": overwrite_encoded_chunks, - "chunk_store": chunk_store, - "storage_options": storage_options, - "stacklevel": 4, - "zarr_version": zarr_version, - } + zarr_backend = ZarrBackendEntrypoint( + group=group, + synchronizer=synchronizer, + consolidated=consolidated, + chunk_store=chunk_store, + storage_options=storage_options, + stacklevel=4, + zarr_version=zarr_version, + ) ds = open_dataset( filename_or_obj=store, @@ -949,19 +969,20 @@ def open_zarr( decode_times=decode_times, concat_characters=concat_characters, decode_coords=decode_coords, - engine="zarr", + engine=zarr_backend, chunks=chunks, drop_variables=drop_variables, chunked_array_type=chunked_array_type, from_array_kwargs=from_array_kwargs, - backend_kwargs=backend_kwargs, decode_timedelta=decode_timedelta, use_cftime=use_cftime, zarr_version=zarr_version, + overwrite_encoded_chunks=overwrite_encoded_chunks, ) return ds +@dataclass(repr=False) class ZarrBackendEntrypoint(BackendEntrypoint): """ Backend for ".zarr" files based on the zarr package. @@ -976,50 +997,64 @@ class ZarrBackendEntrypoint(BackendEntrypoint): description = "Open zarr files (.zarr) using zarr in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html" + open_dataset_parameters = ( + "drop_variables", + "mask_and_scale", + "decode_times", + "concat_characters", + "use_cftime", + "decode_timedelta", + "decode_coords", + ) - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - ) -> bool: + group: str | None = None + mode: ZarrOpenModes = "r" + synchronizer: object | None = None + consolidated: bool | None = None + chunk_store: MutableMapping | str | os.PathLike | None = None + storage_options: Mapping[str, Any] | None = None + stacklevel: int = 3 + zarr_version: int | None = None + + def guess_can_open(self, filename_or_obj: T_XarrayCanOpen) -> bool: if isinstance(filename_or_obj, (str, os.PathLike)): _, ext = os.path.splitext(filename_or_obj) return ext in {".zarr"} return False - def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs + def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + filename_or_obj: T_XarrayCanOpen | MutableMapping | None, *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, drop_variables: str | Iterable[str] | None = None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - synchronizer=None, - consolidated=None, - chunk_store=None, - storage_options=None, - stacklevel=3, - zarr_version=None, + mask_and_scale: bool = True, + decode_times: bool = True, + concat_characters: bool = True, + use_cftime: bool | None = None, + decode_timedelta: bool | None = None, + decode_coords: bool | Literal["coordinates", "all"] = True, + **kwargs: Any, ) -> Dataset: + if "auto_chunk" in kwargs: + raise TypeError( + "open_dataset got an unexpected keyword argument 'auto_chunk'" + ) filename_or_obj = _normalize_path(filename_or_obj) store = ZarrStore.open_group( filename_or_obj, - group=group, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, + group=kwargs.pop("group", self.group), + mode=kwargs.pop("mode", self.mode), + 
synchronizer=kwargs.pop("synchronizer", self.synchronizer), + consolidated=kwargs.pop("consolidated", self.consolidated), consolidate_on_close=False, - chunk_store=chunk_store, - storage_options=storage_options, - stacklevel=stacklevel + 1, - zarr_version=zarr_version, + chunk_store=kwargs.pop("chunk_store", self.chunk_store), + storage_options=kwargs.pop("storage_options", self.storage_options), + stacklevel=kwargs.pop("stacklevel", self.stacklevel) + 1, + zarr_version=kwargs.pop("zarr_version", self.zarr_version), ) + if kwargs: + raise ValueError(f"Unsupported kwargs: {kwargs.values()}") store_entrypoint = StoreBackendEntrypoint() with close_on_error(store): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d2ea0f8a1a4..9d07d516155 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -55,7 +55,7 @@ DaCompatible, T_DataArray, T_DataArrayOrSet, - ZarrWriteModes, + ZarrOpenModes, ) from xarray.core.utils import ( Default, @@ -82,7 +82,7 @@ from numpy.typing import ArrayLike from xarray.backends import ZarrStore - from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + from xarray.backends.api import NetcdfFormats, T_NetcdfEngine from xarray.core.groupby import DataArrayGroupBy from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DataArrayResample @@ -3908,7 +3908,7 @@ def to_netcdf( self, path: None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -3924,7 +3924,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -3941,7 +3941,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -3958,7 +3958,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -3972,7 +3972,7 @@ def to_netcdf( self, path: str | PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, @@ -4100,7 +4100,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, *, @@ -4121,7 +4121,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, @@ -4140,7 +4140,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | 
None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c83a56bb373..057919bf121 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -101,7 +101,7 @@ T_DataArray, T_DataArrayOrSet, T_Dataset, - ZarrWriteModes, + ZarrOpenModes, ) from xarray.core.utils import ( Default, @@ -135,7 +135,7 @@ from numpy.typing import ArrayLike from xarray.backends import AbstractDataStore, ZarrStore - from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + from xarray.backends.api import NetcdfFormats, T_NetcdfEngine from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult @@ -2145,7 +2145,7 @@ def to_netcdf( self, path: None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, @@ -2161,7 +2161,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, @@ -2178,7 +2178,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, @@ -2195,7 +2195,7 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, @@ -2209,7 +2209,7 @@ def to_netcdf( self, path: str | PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NetcdfTypes | None = None, + format: NetcdfFormats | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, @@ -2320,7 +2320,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, @@ -2343,7 +2343,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, + mode: ZarrOpenModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, @@ -2364,8 +2364,8 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: ZarrWriteModes | None = None, - synchronizer=None, + mode: ZarrOpenModes | None = None, + synchronizer: object | None = None, group: str | None = None, encoding: Mapping | None = None, *, diff --git a/xarray/core/types.py b/xarray/core/types.py index 06ad65679d8..fbd9a031412 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ 
-1,8 +1,10 @@ from __future__ import annotations import datetime +import os import sys from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -27,12 +29,17 @@ raise else: Self: Any = None + TypeAlias: Any = None if TYPE_CHECKING: + from io import BufferedIOBase + + from cftime import datetime as CFTimeDatetime + from dask.array import Array as DaskArray from numpy._typing import _SupportsDType from numpy.typing import ArrayLike - from xarray.backends.common import BackendEntrypoint + from xarray.backends.common import AbstractDataStore, BackendEntrypoint from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords from xarray.core.coordinates import Coordinates @@ -42,21 +49,6 @@ from xarray.core.utils import Frozen from xarray.core.variable import Variable - try: - from dask.array import Array as DaskArray - except ImportError: - DaskArray = np.ndarray # type: ignore - - try: - from cubed import Array as CubedArray - except ImportError: - CubedArray = np.ndarray - - try: - from zarr.core import Array as ZarrArray - except ImportError: - ZarrArray = np.ndarray - # Anything that can be coerced to a shape tuple _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _DTypeLikeNested = Any # TODO: wait for support for recursive types @@ -84,10 +76,6 @@ # anything with a dtype attribute _SupportsDType[np.dtype[Any]], ] - try: - from cftime import datetime as CFTimeDatetime - except ImportError: - CFTimeDatetime = Any DatetimeLike = Union[pd.Timestamp, datetime.datetime, np.datetime64, CFTimeDatetime] else: DTypeLikeSave: Any = None @@ -268,7 +256,7 @@ def copy( ] -QuantileMethods = Literal[ +QuantileMethods: TypeAlias = Literal[ "inverted_cdf", "averaged_inverted_cdf", "closest_observation", @@ -284,5 +272,65 @@ def copy( "nearest", ] +T_XarrayCanOpen: TypeAlias = Union[ + str, bytes, os.PathLike[Any], "BufferedIOBase", "AbstractDataStore" +] +NetcdfFormats = Literal[ + "NETCDF4", + "NETCDF4_CLASSIC", + "NETCDF3_64BIT", + "NETCDF3_64BIT_OFFSET", + "NETCDF3_64BIT_DATA", + "NETCDF3_CLASSIC", +] +ScipyFormats = Literal["NETCDF3_64BIT", "NETCDF3_64BIT_OFFSET", "NETCDF3_CLASSIC", None] +NetCDFOpenModes: TypeAlias = Literal["w", "x", "a", "r+", "r"] +H5netcdfOpenModes: TypeAlias = Literal["w", "a", "r+", "r"] +ScipyOpenModes: TypeAlias = Literal["w", "a", "r"] +ZarrOpenModes: TypeAlias = Literal["w", "w-", "a", "a-", "r+", "r"] + + +class FileLike(Protocol): + def close(self) -> None: + ... + + +T_FileLike = TypeVar("T_FileLike", bound=FileLike) + + +class LockLike(Protocol): + def acquire(self, blocking: bool = True) -> bool: + ... + + def release(self) -> None: + ... + + def locked(self) -> bool: + ... + + def __enter__(self) -> bool: + ... + + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + ... + + +T_LockLike = TypeVar("T_LockLike", bound=LockLike) + + +class BackendDatasetLike(Protocol): + @property + def parent(self) -> Self | None: + ... + + @property + def name(self) -> str: + ... 
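
The structural protocols added above (FileLike, LockLike, BackendDatasetLike) let existing objects type-check without inheriting from any xarray class. A minimal sketch of that idea, using a deliberately reduced LockLike so the example stays self-contained; threading.Lock satisfies it purely by shape:

    import threading
    from typing import Protocol


    class LockLike(Protocol):
        # Reduced illustration of the protocol in xarray/core/types.py; the
        # real one also declares locked(), __enter__ and __exit__.
        def acquire(self, blocking: bool = True) -> bool: ...
        def release(self) -> None: ...


    def guarded_read(lock: LockLike) -> str:
        # Any object exposing matching acquire/release methods is accepted;
        # no subclassing of LockLike is required.
        lock.acquire()
        try:
            return "read some bytes here"
        finally:
            lock.release()


    print(guarded_read(threading.Lock()))  # threading.Lock matches structurally
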
+ -ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] +T_BackendDatasetLike = TypeVar("T_BackendDatasetLike", bound=BackendDatasetLike) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 104b6d0867d..78b69ef3dcc 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -46,8 +46,9 @@ NetCDF4BackendEntrypoint, _extract_nc4_variable_encoding, ) -from xarray.backends.pydap_ import PydapDataStore +from xarray.backends.pydap_ import PydapBackendEntrypoint, PydapDataStore from xarray.backends.scipy_ import ScipyBackendEntrypoint +from xarray.backends.zarr import ZarrBackendEntrypoint from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates @@ -124,7 +125,10 @@ dask_array_type = array_type("dask") if TYPE_CHECKING: - from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes + import pydap.model + + from xarray.backends.api import NetcdfFormats, T_NetcdfEngine + from xarray.core.types import T_XarrayCanOpen def open_example_dataset(name, *args, **kwargs) -> Dataset: @@ -268,7 +272,7 @@ def __getitem__(self, key): class NetCDF3Only: - netcdf3_formats: tuple[T_NetcdfTypes, ...] = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") + netcdf3_formats: tuple[NetcdfFormats, ...] = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") @requires_scipy def test_dtype_coercion_error(self) -> None: @@ -293,15 +297,19 @@ def test_dtype_coercion_error(self) -> None: class DatasetIOBase: engine: T_NetcdfEngine | None = None - file_format: T_NetcdfTypes | None = None + file_format: NetcdfFormats | None = None def create_store(self): raise NotImplementedError() @contextlib.contextmanager def roundtrip( - self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False - ): + self, + data, + save_kwargs: dict[str, Any] | None = None, + open_kwargs: dict[str, Any] | None = None, + allow_cleanup_failure: bool = False, + ) -> Iterator[Dataset]: if save_kwargs is None: save_kwargs = {} if open_kwargs is None: @@ -313,8 +321,12 @@ def roundtrip( @contextlib.contextmanager def roundtrip_append( - self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False - ): + self, + data, + save_kwargs=None, + open_kwargs=None, + allow_cleanup_failure: bool = False, + ) -> Iterator[Dataset]: if save_kwargs is None: save_kwargs = {} if open_kwargs is None: @@ -333,7 +345,7 @@ def save(self, dataset, path, **kwargs): ) @contextlib.contextmanager - def open(self, path, **kwargs): + def open(self, path: T_XarrayCanOpen, **kwargs) -> Iterator[Dataset]: with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds @@ -639,13 +651,13 @@ def test_orthogonal_indexing(self) -> None: with self.roundtrip(in_memory) as on_disk: indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} expected = in_memory.isel(indexers) - actual = on_disk.isel(**indexers) + actual = on_disk.isel(indexers) # make sure the array is not yet loaded into memory assert not actual["var1"].variable._in_memory assert_identical(expected, actual) # do it twice, to make sure we're switched from orthogonal -> numpy # when we cached the values - actual = on_disk.isel(**indexers) + actual = on_disk.isel(indexers) assert_identical(expected, actual) def test_vectorized_indexing(self) -> None: @@ -656,13 +668,13 @@ def test_vectorized_indexing(self) -> None: "dim2": DataArray([0, 2, 3], dims="a"), } expected = in_memory.isel(indexers) - actual = on_disk.isel(**indexers) 
+ actual = on_disk.isel(indexers) # make sure the array is not yet loaded into memory assert not actual["var1"].variable._in_memory assert_identical(expected, actual.load()) # do it twice, to make sure we're switched from # vectorized -> numpy when we cached the values - actual = on_disk.isel(**indexers) + actual = on_disk.isel(indexers) assert_identical(expected, actual) def multiple_indexing(indexers): @@ -769,7 +781,7 @@ def test_isel_dataarray(self) -> None: actual = on_disk.isel(dim2=on_disk["dim2"] < 3) assert_identical(expected, actual) - def validate_array_type(self, ds): + def validate_array_type(self, ds: Dataset) -> None: # Make sure that only NumpyIndexingAdapter stores a bare np.ndarray. def find_and_validate_array(obj): # recursively called function. obj: array or array wrapper. @@ -795,12 +807,12 @@ def test_array_type_after_indexing(self) -> None: self.validate_array_type(on_disk) indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)} expected = in_memory.isel(indexers) - actual = on_disk.isel(**indexers) + actual = on_disk.isel(indexers) assert_identical(expected, actual) self.validate_array_type(actual) # do it twice, to make sure we're switched from orthogonal -> numpy # when we cached the values - actual = on_disk.isel(**indexers) + actual = on_disk.isel(indexers) assert_identical(expected, actual) self.validate_array_type(actual) @@ -1935,6 +1947,15 @@ def test_write_inconsistent_chunks(self) -> None: assert actual["y"].encoding["chunksizes"] == (100, 50) +@requires_netCDF4 +class TestNetCDF4Instance(TestNetCDF4Data): + @contextlib.contextmanager + def open(self, path: T_XarrayCanOpen, **kwargs) -> Iterator[Dataset]: + engine = NetCDF4BackendEntrypoint() + with open_dataset(path, engine=engine, **kwargs) as ds: + yield ds + + @requires_zarr class ZarrBase(CFEncodedBase): DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -3001,6 +3022,17 @@ def create_zarr_target(self): yield tmp +@requires_zarr +class TestZarrrInstance(TestZarrWriteEmpty): + @contextlib.contextmanager + def open(self, store_target, **kwargs): + engine = ZarrBackendEntrypoint() + with xr.open_dataset( + store_target, engine=engine, **kwargs, **self.version_kwargs + ) as ds: + yield ds + + @requires_zarr @requires_fsspec def test_zarr_storage_options() -> None: @@ -3103,10 +3135,19 @@ def test_nc4_scipy(self) -> None: open_dataset(tmp_file, engine="scipy") +@requires_scipy +class TestScipyInstance(TestScipyFileObject): + @contextlib.contextmanager + def open(self, path: T_XarrayCanOpen, **kwargs) -> Iterator[Dataset]: + engine = ScipyBackendEntrypoint() + with open_dataset(path, engine=engine, **kwargs) as ds: + yield ds + + @requires_netCDF4 class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine: T_NetcdfEngine = "netcdf4" - file_format: T_NetcdfTypes = "NETCDF3_CLASSIC" + file_format: NetcdfFormats = "NETCDF3_CLASSIC" @contextlib.contextmanager def create_store(self): @@ -3127,7 +3168,7 @@ def test_encoding_kwarg_vlen_string(self) -> None: @requires_netCDF4 class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only): engine: T_NetcdfEngine = "netcdf4" - file_format: T_NetcdfTypes = "NETCDF4_CLASSIC" + file_format: NetcdfFormats = "NETCDF4_CLASSIC" @contextlib.contextmanager def create_store(self): @@ -3142,7 +3183,7 @@ def create_store(self): class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed - file_format: T_NetcdfTypes = "NETCDF3_64BIT" + file_format: 
NetcdfFormats = "NETCDF3_64BIT" def test_write_store(self) -> None: # there's no specific store to test here @@ -3442,15 +3483,15 @@ class TestH5NetCDFFileObject(TestH5NetCDFData): def test_open_badbytes(self) -> None: with pytest.raises(ValueError, match=r"HDF5 as bytes"): - with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"): # type: ignore[arg-type] + with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"): pass with pytest.raises( ValueError, match=r"match in any of xarray's currently installed IO" ): - with open_dataset(b"garbage"): # type: ignore[arg-type] + with open_dataset(b"garbage"): pass with pytest.raises(ValueError, match=r"can only read bytes"): - with open_dataset(b"garbage", engine="netcdf4"): # type: ignore[arg-type] + with open_dataset(b"garbage", engine="netcdf4"): pass with pytest.raises( ValueError, match=r"not the signature of a valid netCDF4 file" @@ -3571,6 +3612,15 @@ def test_get_variable_list_empty_driver_kwds(self) -> None: assert "Temperature" in list(actual) +@requires_h5netcdf +class TestH5NetCDFInstance(TestH5NetCDFData): + @contextlib.contextmanager + def open(self, path: T_XarrayCanOpen, **kwargs) -> Iterator[Dataset]: + engine = H5netcdfBackendEntrypoint() + with open_dataset(path, engine=engine, **kwargs) as ds: + yield ds + + @pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "pynio", "zarr"]) def readengine(request): return request.param @@ -4432,7 +4482,7 @@ def num_graph_nodes(obj): @requires_pydap @pytest.mark.filterwarnings("ignore:The binary mode of fromstring is deprecated") class TestPydap: - def convert_to_pydap_dataset(self, original): + def convert_to_pydap_dataset(self, original: Dataset) -> pydap.model.DatasetType: from pydap.model import BaseType, DatasetType, GridType ds = DatasetType("bears", **original.attrs) @@ -4450,10 +4500,10 @@ def convert_to_pydap_dataset(self, original): return ds @contextlib.contextmanager - def create_datasets(self, **kwargs): + def create_datasets(self, **kwargs) -> Iterator[tuple[Dataset, Dataset]]: with open_example_dataset("bears.nc") as expected: pydap_ds = self.convert_to_pydap_dataset(expected) - actual = open_dataset(PydapDataStore(pydap_ds)) + actual = open_dataset(PydapDataStore(pydap_ds), **kwargs) # TODO solve this workaround: # netcdf converts string to byte not unicode expected["bears"] = expected["bears"].astype(str) @@ -4483,14 +4533,14 @@ def test_cmp_local_file(self) -> None: with self.create_datasets() as (actual, expected): indexers = {"i": [1, 0, 0], "j": [1, 2, 0, 1]} - assert_equal(actual.isel(**indexers), expected.isel(**indexers)) + assert_equal(actual.isel(indexers), expected.isel(indexers)) with self.create_datasets() as (actual, expected): indexers2 = { "i": DataArray([0, 1, 0], dims="a"), "j": DataArray([0, 2, 1], dims="a"), } - assert_equal(actual.isel(**indexers2), expected.isel(**indexers2)) + assert_equal(actual.isel(indexers2), expected.isel(indexers2)) def test_compatible_to_netcdf(self) -> None: # make sure it can be saved as a netcdf @@ -4535,6 +4585,21 @@ def test_session(self) -> None: ) +@network +@requires_scipy_or_netCDF4 +@requires_pydap +class TestPydapInstance(TestPydapOnline): + @contextlib.contextmanager + def create_datasets(self, **kwargs): + url = "http://test.opendap.org/opendap/hyrax/data/nc/bears.nc" + engine = PydapBackendEntrypoint() + actual = open_dataset(url, engine=engine, **kwargs) + with open_example_dataset("bears.nc") as expected: + # workaround to restore string which is converted to byte + expected["bears"] = 
expected["bears"].astype(str) + yield actual, expected + + @requires_scipy @requires_pynio class TestPyNio(CFEncodedBase, NetCDF3Only): @@ -5239,7 +5304,7 @@ def test_scipy_entrypoint(tmp_path: Path) -> None: assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc.gz") assert not entrypoint.guess_can_open("not-found-and-no-extension") - assert not entrypoint.guess_can_open(b"not-a-netcdf-file") # type: ignore[arg-type] + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") @requires_h5netcdf