Skip to content

Fix typing of backends #7114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore
from io import BufferedIOBase

from ..core.types import (
CombineAttrsOptions,
CompatOptions,
Expand Down Expand Up @@ -366,7 +368,7 @@ def _dataset_from_backend_dataset(


def open_dataset(
filename_or_obj: str | os.PathLike | AbstractDataStore,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
Expand Down Expand Up @@ -550,7 +552,7 @@ def open_dataset(


def open_dataarray(
filename_or_obj: str | os.PathLike,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
Expand Down
16 changes: 12 additions & 4 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import time
import traceback
from typing import Any
from typing import TYPE_CHECKING, Any, ClassVar, Iterable

import numpy as np

Expand All @@ -13,6 +13,9 @@
from ..core.pycompat import is_duck_dask_array
from ..core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri

if TYPE_CHECKING:
from io import BufferedIOBase

# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -371,13 +374,15 @@ class BackendEntrypoint:
method is not mandatory.
"""

available: ClassVar[bool] = True

open_dataset_parameters: tuple | None = None
"""list of ``open_dataset`` method parameters"""

def open_dataset(
self,
filename_or_obj: str | os.PathLike,
drop_variables: tuple[str] | None = None,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
drop_variables: str | Iterable[str] | None = None,
**kwargs: Any,
):
"""
Expand All @@ -386,7 +391,10 @@ def open_dataset(

raise NotImplementedError

def guess_can_open(self, filename_or_obj: str | os.PathLike):
def guess_can_open(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
):
"""
Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
"""
Expand Down
17 changes: 13 additions & 4 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
import sys
import warnings
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, Any

from .common import BACKEND_ENTRYPOINTS, BackendEntrypoint

if TYPE_CHECKING:
import os
from io import BufferedIOBase

from .common import AbstractDataStore

STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]


Expand Down Expand Up @@ -83,7 +90,7 @@ def sort_backends(backend_entrypoints):
return ordered_backends_entrypoints


def build_engines(entrypoints):
def build_engines(entrypoints) -> dict[str, BackendEntrypoint]:
backend_entrypoints = {}
for backend_name, backend in BACKEND_ENTRYPOINTS.items():
if backend.available:
Expand All @@ -97,7 +104,7 @@ def build_engines(entrypoints):


@functools.lru_cache(maxsize=1)
def list_engines():
def list_engines() -> dict[str, BackendEntrypoint]:
# New selection mechanism introduced with Python 3.10. See GH6514.
if sys.version_info >= (3, 10):
entrypoints = entry_points(group="xarray.backends")
Expand All @@ -106,7 +113,9 @@ def list_engines():
return build_engines(entrypoints)


def guess_engine(store_spec):
def guess_engine(
store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
):
engines = list_engines()

for engine, backend in engines.items():
Expand Down Expand Up @@ -155,7 +164,7 @@ def guess_engine(store_spec):
raise ValueError(error_msg)


def get_backend(engine):
def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint:
"""Select open_dataset method based on current engine."""
if isinstance(engine, str):
engines = list_engines()
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,7 +2167,7 @@ def chunk(
token: str | None = None,
lock: bool = False,
inline_array: bool = False,
**chunks_kwargs: Any,
**chunks_kwargs: None | int | str | tuple[int, ...],
) -> T_Dataset:
"""Coerce all arrays in this dataset into dask arrays with the given
chunks.
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def assert_allclose(a, b, check_default_indexes=True, **kwargs):
xarray.testing._assert_internal_invariants(b, check_default_indexes)


def create_test_data(seed=None, add_attrs=True):
def create_test_data(seed: int | None = None, add_attrs: bool = True) -> Dataset:
rs = np.random.RandomState(seed)
_vars = {
"var1": ["dim1", "dim2"],
Expand Down
Loading