Skip to content

Commit 72ea77f

Browse files
committed
fix some mypy errors
1 parent 7d598fb commit 72ea77f

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

xarray/core/_reductions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence
6+
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar
77

88
from . import duck_array_ops
99
from .options import OPTIONS
@@ -12,7 +12,9 @@
1212
if TYPE_CHECKING:
1313
from .dataarray import DataArray
1414
from .dataset import Dataset
15-
from .types import T_DataArray, T_Dataset
15+
16+
T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions")
17+
T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions")
1618

1719
try:
1820
import flox
@@ -24,15 +26,15 @@ class DatasetReductions:
2426
__slots__ = ()
2527

2628
def reduce(
27-
self,
29+
self: T_Dataset,
2830
func: Callable[..., Any],
2931
dim: None | Hashable | Sequence[Hashable] = None,
3032
*,
3133
axis: None | int | Sequence[int] = None,
3234
keep_attrs: bool | None = None,
3335
keepdims: bool = False,
3436
**kwargs: Any,
35-
) -> Dataset:
37+
) -> T_Dataset:
3638
raise NotImplementedError()
3739

3840
def count(
@@ -1034,15 +1036,15 @@ class DataArrayReductions:
10341036
__slots__ = ()
10351037

10361038
def reduce(
1037-
self,
1039+
self: T_DataArray,
10381040
func: Callable[..., Any],
10391041
dim: None | Hashable | Sequence[Hashable] = None,
10401042
*,
10411043
axis: None | int | Sequence[int] = None,
10421044
keep_attrs: bool | None = None,
10431045
keepdims: bool = False,
10441046
**kwargs: Any,
1045-
) -> DataArray:
1047+
) -> T_DataArray:
10461048
raise NotImplementedError()
10471049

10481050
def count(

xarray/core/rolling.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Hashable,
1313
Iterator,
1414
Mapping,
15+
Sequence,
1516
TypeVar,
1617
)
1718

@@ -468,9 +469,8 @@ def reduce(
468469
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
469470
)
470471

471-
result = windows.reduce(
472-
func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
473-
)
472+
dim: Sequence[Hashable] = list(rolling_dim.values())
473+
result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs)
474474

475475
# Find valid windows based on count.
476476
counts = self._counts(keep_attrs=False)
@@ -487,14 +487,15 @@ def _counts(self, keep_attrs: bool | None) -> DataArray:
487487
# array is faster to be reduced than object array.
488488
# The use of skipna==False is also faster since it does not need to
489489
# copy the strided array.
490+
dim: Sequence[Hashable] = list(rolling_dim.values())
490491
counts = (
491492
self.obj.notnull(keep_attrs=keep_attrs)
492493
.rolling(
493494
{d: w for d, w in zip(self.dim, self.window)},
494495
center={d: self.center[i] for i, d in enumerate(self.dim)},
495496
)
496497
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
497-
.sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
498+
.sum(dim=dim, skipna=False, keep_attrs=keep_attrs)
498499
)
499500
return counts
500501

xarray/util/generate_reductions.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
2323
from __future__ import annotations
2424
25-
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence
25+
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar
2626
2727
from . import duck_array_ops
2828
from .options import OPTIONS
@@ -31,7 +31,9 @@
3131
if TYPE_CHECKING:
3232
from .dataarray import DataArray
3333
from .dataset import Dataset
34-
from .types import T_DataArray, T_Dataset
34+
35+
T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions")
36+
T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions")
3537
3638
try:
3739
import flox
@@ -44,15 +46,15 @@ class {obj}{cls}Reductions:
4446
__slots__ = ()
4547
4648
def reduce(
47-
self,
49+
self{self_type_snippet},
4850
func: Callable[..., Any],
4951
dim: None | Hashable | Sequence[Hashable] = None,
5052
*,
5153
axis: None | int | Sequence[int] = None,
5254
keep_attrs: bool | None = None,
5355
keepdims: bool = False,
5456
**kwargs: Any,
55-
) -> {obj}:
57+
) -> {return_type}:
5658
raise NotImplementedError()"""
5759

5860
GROUPBY_PREAMBLE = """
@@ -246,7 +248,13 @@ def __init__(
246248
self.docref = docref
247249
self.docref_description = docref_description
248250
self.example_call_preamble = example_call_preamble
249-
self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls)
251+
self.common_kwargs = dict(
252+
obj=self.datastructure.name,
253+
cls=cls,
254+
self_type_snippet=": " + self.self_type if self.self_type else "",
255+
return_type=self.self_type if self.self_type else self.datastructure.name,
256+
)
257+
self.preamble = definition_preamble.format(**self.common_kwargs)
250258
if not see_also_obj:
251259
self.see_also_obj = self.datastructure.name
252260
else:
@@ -258,12 +266,7 @@ def generate_methods(self):
258266
yield self.generate_method(method)
259267

260268
def generate_method(self, method):
261-
template_kwargs = dict(
262-
obj=self.datastructure.name,
263-
method=method.name,
264-
self_type_snippet=": " + self.self_type if self.self_type else "",
265-
return_type=self.self_type if self.self_type else self.datastructure.name,
266-
)
269+
template_kwargs = dict(self.common_kwargs, method=method.name)
267270

268271
if method.extra_kwargs:
269272
extra_kwargs = "\n " + "\n ".join(

0 commit comments

Comments
 (0)