From 34e3ce03e15bcea383fecb8fcd0dce0a496614d3 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Wed, 10 Jun 2020 16:29:34 -0700 Subject: [PATCH 1/6] Improve typehints of xr.Dataset.__getitem__ Resolves #4125 --- xarray/core/dataset.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a8011afd3e3..2dd4104aab7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -27,6 +27,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np @@ -1240,8 +1241,15 @@ def loc(self) -> _LocIndexer: and only when the key is a dict of the form {dim: labels}. """ return _LocIndexer(self) + - def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": + @overload + def __getitem__(self, key: Hashable) -> DataArray: ... + + @overload + def __getitem__(self, key: Any) -> Dataset: ... + + def __getitem__(self, key): """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. From 414ee30762465b7a09059a5ec042b75b99d1b03e Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Wed, 10 Jun 2020 16:44:16 -0700 Subject: [PATCH 2/6] Add overload for Mapping behavior Sadly this is not working with my version of mypy. See https://github.com/python/mypy/issues/7328 --- xarray/core/dataset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2dd4104aab7..57fdc726b68 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1241,13 +1241,18 @@ def loc(self) -> _LocIndexer: and only when the key is a dict of the form {dim: labels}. """ return _LocIndexer(self) - @overload - def __getitem__(self, key: Hashable) -> DataArray: ... + def __getitem__(self, key: Hashable) -> DataArray: + ... @overload - def __getitem__(self, key: Any) -> Dataset: ... + def __getitem__(self, key: Mapping) -> "Dataset": + ... + + @overload + def __getitem__(self, key: List) -> "Dataset": + ... def __getitem__(self, key): """Access variables or coordinates this dataset as a From dfb3d2a5eabe4ed5e6cf5296246960e971b484f3 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Thu, 11 Jun 2020 22:18:34 -0700 Subject: [PATCH 3/6] Overload only Hashable inputs Given mypy's use of overloads, I think this is all we can do. If the argument is not Hashable, then return the Union type as before. --- xarray/core/dataset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 57fdc726b68..32647ec46cc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1247,11 +1247,7 @@ def __getitem__(self, key: Hashable) -> DataArray: ... @overload - def __getitem__(self, key: Mapping) -> "Dataset": - ... - - @overload - def __getitem__(self, key: List) -> "Dataset": + def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": ... def __getitem__(self, key): From 1f26517aafbec3ecd5fd5d98d77c011538434cef Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Thu, 11 Jun 2020 22:30:50 -0700 Subject: [PATCH 4/6] Lint --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 32647ec46cc..ef3d8c0574e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1242,15 +1242,15 @@ def loc(self) -> _LocIndexer: """ return _LocIndexer(self) - @overload + @overload # noqa: F811 def __getitem__(self, key: Hashable) -> DataArray: ... - @overload + @overload # noqa: F811 def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": ... - def __getitem__(self, key): + def __getitem__(self, key): # noqa: F811 """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. From 2f6ca7eee0d29a58d4fca5cdb8097dd7ae839bd2 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Thu, 11 Jun 2020 22:41:41 -0700 Subject: [PATCH 5/6] Quote the DataArray to avoid error in py3.6 --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ef3d8c0574e..dc2d4091ee4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1243,7 +1243,7 @@ def loc(self) -> _LocIndexer: return _LocIndexer(self) @overload # noqa: F811 - def __getitem__(self, key: Hashable) -> DataArray: + def __getitem__(self, key: Hashable) -> "DataArray": ... @overload # noqa: F811 From 400dc62b0627b2deae4422d4579c1762cfcf4274 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 13 Jun 2020 18:43:31 +0100 Subject: [PATCH 6/6] Code review --- .pre-commit-config.yaml | 2 +- ci/requirements/py38.yml | 2 +- xarray/core/dataset.py | 16 ++++++++++------ xarray/core/weighted.py | 6 +++--- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26bf4803ef6..1d384e58a3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.761 # Must match ci/requirements/*.yml + rev: v0.780 # Must match ci/requirements/*.yml hooks: - id: mypy # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 diff --git a/ci/requirements/py38.yml b/ci/requirements/py38.yml index 24602f884e9..7dff3a1bd97 100644 --- a/ci/requirements/py38.yml +++ b/ci/requirements/py38.yml @@ -22,7 +22,7 @@ dependencies: - isort - lxml # Optional dep of pydap - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml + - mypy=0.780 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dc2d4091ee4..17f1f670b09 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1242,21 +1242,25 @@ def loc(self) -> _LocIndexer: """ return _LocIndexer(self) - @overload # noqa: F811 - def __getitem__(self, key: Hashable) -> "DataArray": + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore + ... + + @overload + def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore ... - @overload # noqa: F811 - def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": + @overload + def __getitem__(self, key: Any) -> "Dataset": ... - def __getitem__(self, key): # noqa: F811 + def __getitem__(self, key): """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. Indexing with a list of names will return a new ``Dataset`` object. """ - # TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328 if utils.is_dict_like(key): return self.isel(**cast(Mapping, key)) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 21ed06ea85f..fa143342c06 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -72,11 +72,11 @@ class Weighted: def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... - @overload # noqa: F811 - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811 + @overload + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: ... - def __init__(self, obj, weights): # noqa: F811 + def __init__(self, obj, weights): """ Create a Weighted object