Fix some mypy issues #6531


Merged · 9 commits · Apr 28, 2022
41 changes: 37 additions & 4 deletions .github/workflows/ci-additional.yaml
@@ -42,8 +42,7 @@ jobs:
       fail-fast: false
       matrix:
         os: ["ubuntu-latest"]
-        env:
-          [
+        env: [
             # Minimum python version:
             "py38-bare-minimum",
             "py38-min-all-deps",
@@ -71,8 +70,7 @@ jobs:
         uses: actions/cache@v3
         with:
           path: ~/conda_pkgs_dir
-          key:
-            ${{ runner.os }}-conda-${{ matrix.env }}-${{
+          key: ${{ runner.os }}-conda-${{ matrix.env }}-${{
             hashFiles('ci/requirements/**.yml') }}

       - uses: conda-incubator/setup-miniconda@v2
@@ -152,6 +150,41 @@ jobs:
         run: |
           python -m pytest --doctest-modules xarray --ignore xarray/tests

+  mypy:
+    name: Mypy
+    runs-on: "ubuntu-latest"
+    if: needs.detect-ci-trigger.outputs.triggered == 'false'
+    defaults:
+      run:
+        shell: bash -l {0}
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Fetch all history for all branches and tags.
+      - uses: conda-incubator/setup-miniconda@v2
+        with:
+          channels: conda-forge
+          channel-priority: strict
+          mamba-version: "*"
+          activate-environment: xarray-tests
+          auto-update-conda: false
+          python-version: "3.9"
+
+      - name: Install conda dependencies
+        run: |
+          mamba env update -f ci/requirements/environment.yml
+      - name: Install xarray
+        run: |
+          python -m pip install --no-deps -e .
+      - name: Version info
+        run: |
+          conda info -a
+          conda list
+          python xarray/util/print_versions.py
Comment on lines +156 to +184

Collaborator (author): This seems like a lot of boilerplate. Am I doing it correctly?

Collaborator (author): Potentially we could turn this setup into an Action; we did this in prql quite successfully: https://github.com/prql/prql/blob/main/.github/actions/cargo-test/action.yaml

+      - name: Run mypy
+        run: mypy

   min-version-policy:
     name: Minimum Version Policy
     runs-on: "ubuntu-latest"
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
@@ -1154,7 +1154,8 @@ def chunk(
             chunks = {}

         if isinstance(chunks, (float, str, int)):
-            chunks = dict.fromkeys(self.dims, chunks)
+            # ignoring type; unclear why it won't accept a Literal into the value.
+            chunks = dict.fromkeys(self.dims, chunks)  # type: ignore
         elif isinstance(chunks, (tuple, list)):
             chunks = dict(zip(self.dims, chunks))
         else:
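For context, a minimal sketch of the pattern being silenced here. `T_Chunks` and `normalize_chunks` are hypothetical stand-ins, not xarray names, and whether mypy flags the call can depend on the mypy version; the diff's own comment notes the cause is unclear.

    from typing import Literal, Union

    T_Chunks = Union[int, Literal["auto"]]  # hypothetical stand-in for the real alias

    def normalize_chunks(dims: tuple, chunks: T_Chunks) -> dict:
        # dict.fromkeys maps every key to the same value; mypy infers the
        # value type from `chunks` and, per the comment above, may reject a
        # Literal-bearing union. That is what the `# type: ignore` silences.
        return dict.fromkeys(dims, chunks)

    print(normalize_chunks(("x", "y"), "auto"))  # {'x': 'auto', 'y': 'auto'}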
@@ -4735,7 +4736,7 @@ def curvefit(

     def drop_duplicates(
         self,
-        dim: Hashable | Iterable[Hashable] | ...,
+        dim: Hashable | Iterable[Hashable],
         keep: Literal["first", "last"] | Literal[False] = "first",
     ):
         """Returns a new DataArray with duplicate dimension values removed.
6 changes: 4 additions & 2 deletions xarray/core/dataset.py
@@ -7981,7 +7981,7 @@ def _wrapper(Y, *coords_, **kwargs):

     def drop_duplicates(
         self,
-        dim: Hashable | Iterable[Hashable] | ...,
+        dim: Hashable | Iterable[Hashable],
         keep: Literal["first", "last"] | Literal[False] = "first",
     ):
         """Returns a new Dataset with duplicate dimension values removed.
@@ -8005,9 +8005,11 @@ def drop_duplicates(
         DataArray.drop_duplicates
         """
         if isinstance(dim, str):
-            dims = (dim,)
+            dims: Iterable = (dim,)
         elif dim is ...:
             dims = self.dims
+        elif not isinstance(dim, Iterable):
+            dims = [dim]
         else:
             dims = dim
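A usage sketch of the normalization above: a plain string, a non-iterable Hashable, an iterable, and Ellipsis are all accepted, each handled by one of the branches.

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"x": [0, 0, 1, 2]})
    ds.drop_duplicates("x")  # a str is normalized to the single-dim tuple ("x",)
    ds.drop_duplicates(...)  # Ellipsis de-duplicates along every dimension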

76 changes: 33 additions & 43 deletions xarray/core/indexing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 import functools
 import operator
@@ -6,19 +8,7 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
 from html import escape
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Hashable,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Mapping

 import numpy as np
 import pandas as pd
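The new future import is what makes the lowercase generics used throughout this file legal as annotations even on Python 3.8: annotations stay unevaluated strings at runtime. A minimal sketch of the idiom:

    from __future__ import annotations

    # PEP 585 builtin generics and PEP 604 unions as annotations, without
    # importing Dict/List/Optional/Union from typing:
    def first_key(table: dict[str, int | None]) -> str | None:
        return next(iter(table), None)

    print(first_key({"a": 1}))  # a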
@@ -59,12 +49,12 @@ class IndexSelResult:

     """

-    dim_indexers: Dict[Any, Any]
-    indexes: Dict[Any, "Index"] = field(default_factory=dict)
-    variables: Dict[Any, "Variable"] = field(default_factory=dict)
-    drop_coords: List[Hashable] = field(default_factory=list)
-    drop_indexes: List[Hashable] = field(default_factory=list)
-    rename_dims: Dict[Any, Hashable] = field(default_factory=dict)
+    dim_indexers: dict[Any, Any]
+    indexes: dict[Any, Index] = field(default_factory=dict)
+    variables: dict[Any, Variable] = field(default_factory=dict)
+    drop_coords: list[Hashable] = field(default_factory=list)
+    drop_indexes: list[Hashable] = field(default_factory=list)
+    rename_dims: dict[Any, Hashable] = field(default_factory=dict)

     def as_tuple(self):
         """Unlike ``dataclasses.astuple``, return a shallow copy.
@@ -82,7 +72,7 @@ def as_tuple(self):
         )


-def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult:
+def merge_sel_results(results: list[IndexSelResult]) -> IndexSelResult:
     all_dims_count = Counter([dim for res in results for dim in res.dim_indexers])
     duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1}
@@ -124,13 +114,13 @@ def group_indexers_by_index(
     obj: T_Xarray,
     indexers: Mapping[Any, Any],
     options: Mapping[str, Any],
-) -> List[Tuple["Index", Dict[Any, Any]]]:
+) -> list[tuple[Index, dict[Any, Any]]]:
     """Returns a list of unique indexes and their corresponding indexers."""
     unique_indexes = {}
-    grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict)
+    grouped_indexers: Mapping[int | None, dict] = defaultdict(dict)

     for key, label in indexers.items():
-        index: "Index" = obj.xindexes.get(key, None)
+        index: Index = obj.xindexes.get(key, None)

         if index is not None:
             index_id = id(index)
@@ -787,7 +777,7 @@ class IndexingSupport(enum.Enum):

 def explicit_indexing_adapter(
     key: ExplicitIndexer,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
     raw_indexing_method: Callable,
 ) -> Any:
@@ -821,8 +811,8 @@ def explicit_indexing_adapter(


 def decompose_indexer(
-    indexer: ExplicitIndexer, shape: Tuple[int, ...], indexing_support: IndexingSupport
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+    indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     if isinstance(indexer, VectorizedIndexer):
         return _decompose_vectorized_indexer(indexer, shape, indexing_support)
     if isinstance(indexer, (BasicIndexer, OuterIndexer)):
@@ -848,9 +838,9 @@ def _decompose_slice(key, size):

 def _decompose_vectorized_indexer(
     indexer: VectorizedIndexer,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     """
     Decompose vectorized indexer to the successive two indexers, where the
     first indexer will be used to index backend arrays, while the second one
@@ -929,10 +919,10 @@ def _decompose_vectorized_indexer(


 def _decompose_outer_indexer(
-    indexer: Union[BasicIndexer, OuterIndexer],
-    shape: Tuple[int, ...],
+    indexer: BasicIndexer | OuterIndexer,
+    shape: tuple[int, ...],
     indexing_support: IndexingSupport,
-) -> Tuple[ExplicitIndexer, ExplicitIndexer]:
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
     """
     Decompose outer indexer to the successive two indexers, where the
     first indexer will be used to index backend arrays, while the second one
@@ -973,7 +963,7 @@ def _decompose_outer_indexer(
         return indexer, BasicIndexer(())
     assert isinstance(indexer, (OuterIndexer, BasicIndexer))

-    backend_indexer: List[Any] = []
+    backend_indexer: list[Any] = []
     np_indexer = []
     # make indexer positive
     pos_indexer: list[np.ndarray | int | np.number] = []
@@ -1395,7 +1385,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
         return np.asarray(array.values, dtype=dtype)

     @property
-    def shape(self) -> Tuple[int]:
+    def shape(self) -> tuple[int]:
         return (len(self.array),)

     def _convert_scalar(self, item):
@@ -1420,13 +1410,13 @@ def _convert_scalar(self, item):

     def __getitem__(
         self, indexer
-    ) -> Union[
-        "PandasIndexingAdapter",
-        NumpyIndexingAdapter,
-        np.ndarray,
-        np.datetime64,
-        np.timedelta64,
-    ]:
+    ) -> (
+        PandasIndexingAdapter
+        | NumpyIndexingAdapter
+        | np.ndarray
+        | np.datetime64
+        | np.timedelta64
+    ):
         key = indexer.tuple
         if isinstance(key, tuple) and len(key) == 1:
             # unpack key so it can index a pandas.Index object (pandas.Index
@@ -1449,7 +1439,7 @@ def transpose(self, order) -> pd.Index:
     def __repr__(self) -> str:
         return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"

-    def copy(self, deep: bool = True) -> "PandasIndexingAdapter":
+    def copy(self, deep: bool = True) -> PandasIndexingAdapter:
         # Not the same as just writing `self.array.copy(deep=deep)`, as
         # shallow copies of the underlying numpy.ndarrays become deep ones
         # upon pickling
Expand All @@ -1476,7 +1466,7 @@ def __init__(
self,
array: pd.MultiIndex,
dtype: DTypeLike = None,
level: Optional[str] = None,
level: str | None = None,
):
super().__init__(array, dtype)
self.level = level
@@ -1535,7 +1525,7 @@ def _repr_html_(self) -> str:
         array_repr = short_numpy_repr(self._get_array_subset())
         return f"<pre>{escape(array_repr)}</pre>"

-    def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter":
+    def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter:
         # see PandasIndexingAdapter.copy
         array = self.array.copy(deep=True) if deep else self.array
         return type(self)(array, self._dtype, self.level)
3 changes: 2 additions & 1 deletion xarray/core/utils.py
@@ -237,7 +237,8 @@ def remove_incompatible_items(
             del first_dict[k]


-def is_dict_like(value: Any) -> bool:
+# It's probably OK to give this as a TypeGuard; though it's not perfectly robust.
+def is_dict_like(value: Any) -> TypeGuard[dict]:
     return hasattr(value, "keys") and hasattr(value, "__getitem__")
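For illustration, a sketch of what the TypeGuard buys: a true result narrows the checked value for the type checker. TypeGuard lives in typing from Python 3.10 and in typing_extensions before that; `describe` is a hypothetical caller, not xarray code.

    from typing import Any

    from typing_extensions import TypeGuard  # typing.TypeGuard on Python 3.10+

    def is_dict_like(value: Any) -> TypeGuard[dict]:
        return hasattr(value, "keys") and hasattr(value, "__getitem__")

    def describe(obj: Any) -> None:
        if is_dict_like(obj):
            # mypy narrows obj to dict in this branch, so key access
            # type-checks without a cast
            for key in obj:
                print(key, obj[key])

As the new comment says, this isn't perfectly robust: any object duck-typed with keys and __getitem__ gets narrowed to dict even if it isn't one.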


10 changes: 4 additions & 6 deletions xarray/tests/test_coding_times.py
@@ -1007,20 +1007,18 @@ def test_decode_ambiguous_time_warns(calendar) -> None:
     units = "days since 1-1-1"
     expected = num2date(dates, units, calendar=calendar, only_use_cftime_datetimes=True)

-    exp_warn_type = SerializationWarning if is_standard_calendar else None
-
-    with pytest.warns(exp_warn_type) as record:
-        result = decode_cf_datetime(dates, units, calendar=calendar)
-
     if is_standard_calendar:
+        with pytest.warns(SerializationWarning) as record:
+            result = decode_cf_datetime(dates, units, calendar=calendar)
         relevant_warnings = [
             r
             for r in record.list
             if str(r.message).startswith("Ambiguous reference date string: 1-1-1")
         ]
         assert len(relevant_warnings) == 1
     else:
-        assert not record
+        with assert_no_warnings():
+            result = decode_cf_datetime(dates, units, calendar=calendar)

     np.testing.assert_array_equal(result, expected)
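Calling pytest.warns(None) was deprecated in pytest 7, which is why the two cases now get explicit context managers. assert_no_warnings here is imported from xarray's test utilities; a minimal sketch of what such a helper can look like, as an assumption rather than the actual implementation:

    import warnings
    from contextlib import contextmanager

    @contextmanager
    def assert_no_warnings():
        # Record every warning raised inside the block, then fail if any appeared.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            yield
        assert not caught, f"unexpected warnings: {[str(w.message) for w in caught]}"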

6 changes: 3 additions & 3 deletions xarray/tests/test_formatting.py
@@ -480,10 +480,10 @@ def test_short_numpy_repr() -> None:
     assert num_lines < 30

     # threshold option (default: 200)
-    array = np.arange(100)
-    assert "..." not in formatting.short_numpy_repr(array)
+    array2 = np.arange(100)
+    assert "..." not in formatting.short_numpy_repr(array2)
     with xr.set_options(display_values_threshold=10):
-        assert "..." in formatting.short_numpy_repr(array)
+        assert "..." in formatting.short_numpy_repr(array2)


 def test_large_array_repr_length() -> None: