Improving typing of xr.Dataset.__getitem__ #4125

Closed
@nbren12

Description

First, I'd like to thank the xarray devs for adding type hints to this library; not many libraries have this feature!

That said, the indexing notation of xr.Dataset does not currently play well with mypy, since it returns a Union type. This results in a lot of mypy errors like this:

workflows/fine_res_budget/budget/budgets.py:284: error: Argument 6 to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:285: error: Argument 1 to "storage" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:286: error: Argument "unresolved_flux" to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"
workflows/fine_res_budget/budget/budgets.py:287: error: Argument "saturation_adjustment" to "compute_recoarsened_budget_field" has incompatible type "Union[DataArray, Dataset]"; expected "DataArray"

MCVE Code Sample

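import xarray as xr

def func(ds: xr.Dataset):
    pass

# small example dataset with two variables, 'a' and 'b'
ds = xr.Dataset({'a': ('x', [1, 2]), 'b': ('x', [3, 4])})

# error:
# this line will give a type error because mypy doesn't know
# if ds[['a', 'b']] is a Dataset or a DataArray
func(ds[['a', 'b']])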

Expected Output

Mypy should be able to infer that ds[['a', 'b']] is a Dataset, and that ds['a'] is a DataArray.

Problem Description

As a result, any type-hinted routine that consumes the output of xr.Dataset.__getitem__ must accept a Union[DataArray, Dataset], even if it is really intended to be used with only a Dataset or only a DataArray. Because ds[something] is a ubiquitous syntax, this behavior accounts for approximately 50% of mypy errors in my xarray-heavy code.
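To make the problem concrete without depending on xarray itself, here is a minimal sketch using toy classes. ToyArray, ToyDataset, and total are hypothetical stand-ins introduced only for illustration; they are not xarray's actual implementation. The sketch shows a Union-typed __getitem__ and the manual narrowing a caller needs to satisfy the type checker.

```python
from typing import Dict, List, Union


class ToyArray:
    """Hypothetical stand-in for xr.DataArray."""

    def __init__(self, values: List[float]) -> None:
        self.values = values


class ToyDataset:
    """Hypothetical stand-in for xr.Dataset with a Union-typed __getitem__."""

    def __init__(self, variables: Dict[str, ToyArray]) -> None:
        self.variables = variables

    def __getitem__(
        self, key: Union[str, List[str]]
    ) -> Union[ToyArray, "ToyDataset"]:
        # A list of names selects a sub-dataset; a single name selects one array.
        if isinstance(key, list):
            return ToyDataset({k: self.variables[k] for k in key})
        return self.variables[key]


def total(arr: ToyArray) -> float:
    """A routine that only makes sense for a single array."""
    return sum(arr.values)


ds = ToyDataset({"a": ToyArray([1.0, 2.0]), "b": ToyArray([3.0])})

# Passing ds["a"] straight to total() works at runtime, but a type checker
# sees Union[ToyArray, ToyDataset] and should reject the call.

# Narrowing by hand satisfies the checker, at the cost of boilerplate per call.
a = ds["a"]
assert isinstance(a, ToyArray)
result = total(a)
```

Every call site that indexes the dataset needs this kind of narrowing (or a cast), which is where the bulk of the errors comes from.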

Versions

Output of xr.show_versions()

In [1]: import xarray as xr
In [2]: xr.show_versions()

INSTALLED VERSIONS

commit: None
python: 3.7.7 (default, May 7 2020, 21:25:33)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-1020-gcp
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: C.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.7.3

xarray: 0.15.1
pandas: 1.0.1
numpy: 1.18.1
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: 0.8.0
h5py: 2.10.0
Nio: None
zarr: 2.4.0
cftime: 1.1.2
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.17.2
distributed: 2.17.0
matplotlib: 3.1.3
cartopy: 0.17.0
seaborn: 0.10.1
numbagg: None
setuptools: 46.4.0.post20200518
pip: 20.0.2
conda: 4.8.3
pytest: 5.4.2
IPython: 7.13.0
sphinx: None

Potential solution

I think we can fix this with typing.overload. I am not too familiar with that decorator, but I think something like the following might work:

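from typing import Hashable, List, overload

from xarray import DataArray

class Dataset:
    # a single hashable key selects one variable
    @overload
    def __getitem__(self, key: Hashable) -> DataArray: ...

    # a list of keys selects a subset of variables
    @overload
    def __getitem__(self, key: List[Hashable]) -> "Dataset": ...

    # actual implementation (body omitted here)
    def __getitem__(self, key):
        ...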
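The idea above can be tried out in isolation. The following is a minimal, self-contained sketch of the overload pattern on toy classes; ToyArray and ToyDataset are hypothetical stand-ins, and keys are restricted to str rather than Hashable purely to keep the example simple. It is not the signature xarray would necessarily adopt.

```python
from typing import Dict, List, Union, overload


class ToyArray:
    """Hypothetical stand-in for xr.DataArray."""

    def __init__(self, values: List[float]) -> None:
        self.values = values


class ToyDataset:
    """Hypothetical stand-in for xr.Dataset."""

    def __init__(self, variables: Dict[str, ToyArray]) -> None:
        self.variables = variables

    # The two overloads are signatures for the type checker only.
    @overload
    def __getitem__(self, key: str) -> ToyArray: ...

    @overload
    def __getitem__(self, key: List[str]) -> "ToyDataset": ...

    # Single runtime implementation; its annotations cover both overloads.
    def __getitem__(
        self, key: Union[str, List[str]]
    ) -> Union[ToyArray, "ToyDataset"]:
        if isinstance(key, list):
            return ToyDataset({k: self.variables[k] for k in key})
        return self.variables[key]


ds = ToyDataset({"a": ToyArray([1.0]), "b": ToyArray([2.0, 3.0])})

single = ds["a"]         # a checker should pick the first overload: ToyArray
subset = ds[["a", "b"]]  # a checker should pick the second overload: ToyDataset
```

With the overloads in place, callers no longer need isinstance checks or casts, because the return type is selected from the static type of the key. Runtime behavior is unchanged, since only the final undecorated definition is executed.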
