Description
First, I'd like to thank the xarray devs for adding type hints to this library; not many libraries have this feature!
That said, indexing an xr.Dataset does not currently play well with mypy, because Dataset.__getitem__ returns a Union type (Union[DataArray, Dataset]). This results in a lot of mypy errors like these:
workflows/fine_res_budget/budget/budgets.py:284: error: Argument 6 to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:285: error: Argument 1 to "storage" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:286: error: Argument "unresolved_flux" to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:287: error: Argument "saturation_adjustment" to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
MCVE Code Sample
import xarray as xr
def func(ds: xr.Dataset) -> None:
    pass
ds: xr.Dataset = xr.Dataset({"a": ("x", [1, 2]), "b": ("x", [3, 4])})
# error:
# mypy flags this call because it doesn't know
# whether ds[['a', 'b']] is a Dataset or a DataArray
func(ds[['a', 'b']])
Expected Output
Mypy should be able to infer that ds[['a', 'b']] is a Dataset, and that ds['a'] is a DataArray.
Problem Description
This forces any type-hinted routine that consumes the output of xr.Dataset.__getitem__ to accept a Union[DataArray, Dataset], even if it is only intended to be used with a DataArray or only with a Dataset. Because ds[something] is ubiquitous syntax, this behavior accounts for approximately 50% of the mypy errors in my xarray-heavy code.
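To illustrate the cost on the caller side, here is a minimal sketch using hypothetical stand-in classes (ToyDataArray, ToyDataset and storage are made up for illustration; they are not xarray's API). With a Union-typed __getitem__, each call site has to narrow the result, for example with an isinstance check or typing.cast, before mypy will accept it.

```python
from typing import Dict, Hashable, List, Union


class ToyDataArray:
    """Hypothetical stand-in for xr.DataArray (not xarray's real class)."""

    def __init__(self, name: Hashable) -> None:
        self.name = name


class ToyDataset:
    """Hypothetical stand-in for xr.Dataset whose __getitem__ is Union-typed."""

    def __init__(self, names: List[Hashable]) -> None:
        self._vars: Dict[Hashable, ToyDataArray] = {n: ToyDataArray(n) for n in names}

    def __getitem__(
        self, key: Union[Hashable, List[Hashable]]
    ) -> Union[ToyDataArray, "ToyDataset"]:
        # A list selects a sub-dataset; any other key selects one variable.
        if isinstance(key, list):
            return ToyDataset([self._vars[k].name for k in key])
        return self._vars[key]


def storage(da: ToyDataArray) -> Hashable:
    # Hypothetical consumer that only accepts a single variable.
    return da.name


ds = ToyDataset(["a", "b"])
result = ds["a"]
# Passing `result` straight to storage() would be rejected by mypy
# (Union vs ToyDataArray); the isinstance check narrows the type for
# mypy and also guards at runtime.
assert isinstance(result, ToyDataArray)
print(storage(result))  # prints a
```

Multiplied over every ds[...] call, these narrowing lines are the boilerplate the proposal below would remove.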
Versions
Output of xr.show_versions()
In [1]: import xarray as xr
In [2]: xr.show_versions()
INSTALLED VERSIONS
commit: None
python: 3.7.7 (default, May 7 2020, 21:25:33)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-1020-gcp
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: C.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.7.3
xarray: 0.15.1
pandas: 1.0.1
numpy: 1.18.1
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: 0.8.0
h5py: 2.10.0
Nio: None
zarr: 2.4.0
cftime: 1.1.2
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.17.2
distributed: 2.17.0
matplotlib: 3.1.3
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
setuptools: 46.4.0.post20200518
pip: 20.0.2
conda: 4.8.3
pytest: 5.4.2
IPython: 7.13.0
sphinx: None
Potential solution
I think we can fix this with typing.overload. I am not too familiar with that feature, but I think something like the following might work:
from typing import Hashable, List, overload
class Dataset:
    @overload
    def __getitem__(self, key: Hashable) -> "DataArray": ...
    @overload
    def __getitem__(self, key: List[Hashable]) -> "Dataset": ...
    # actual implementation (handles both cases at runtime)
    def __getitem__(self, key):
        ...
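Below is a self-contained sketch of how the proposed overloads could behave, again using hypothetical stand-in classes rather than xarray itself. It assumes, as typeshed does, that list is not Hashable, so the two overloads should not overlap: mypy should then infer ToyDataArray for a single key and ToyDataset for a list key, while the single runtime implementation dispatches with isinstance.

```python
from typing import Dict, Hashable, List, Union, overload


class ToyDataArray:
    """Hypothetical stand-in for xr.DataArray."""

    def __init__(self, name: Hashable) -> None:
        self.name = name


class ToyDataset:
    """Hypothetical stand-in for xr.Dataset using the proposed overloads."""

    def __init__(self, names: List[Hashable]) -> None:
        self._vars: Dict[Hashable, ToyDataArray] = {n: ToyDataArray(n) for n in names}

    @overload
    def __getitem__(self, key: Hashable) -> ToyDataArray: ...

    @overload
    def __getitem__(self, key: List[Hashable]) -> "ToyDataset": ...

    def __getitem__(
        self, key: Union[Hashable, List[Hashable]]
    ) -> Union[ToyDataArray, "ToyDataset"]:
        # Only this implementation runs; the overloads above exist for mypy.
        if isinstance(key, list):
            return ToyDataset([self._vars[k].name for k in key])
        return self._vars[key]

    @property
    def names(self) -> List[Hashable]:
        return list(self._vars)


ds = ToyDataset(["a", "b", "c"])
sub = ds[["a", "b"]]  # mypy should infer ToyDataset
var = ds["a"]  # mypy should infer ToyDataArray
print(type(sub).__name__, sub.names)  # prints ToyDataset ['a', 'b']
print(type(var).__name__, var.name)  # prints ToyDataArray a
```

The implementation signature still returns the Union, but callers only ever see the narrower overload return types, so no isinstance checks are needed at the call site.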