diff --git a/xarray/core/common.py b/xarray/core/common.py index 6ec07156160..b518e8431fd 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,6 +1,8 @@ from collections import OrderedDict from contextlib import suppress from textwrap import dedent +from typing import (Any, Callable, Hashable, Iterable, Iterator, List, Mapping, + MutableMapping, Optional, Tuple, TypeVar, Union) import numpy as np import pandas as pd @@ -11,13 +13,18 @@ from .pycompat import dask_array_type from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs + # Used as a sentinel value to indicate a all dimensions ALL_DIMS = ReprObject('') -class ImplementsArrayReduce(object): +T = TypeVar('T') + + +class ImplementsArrayReduce: @classmethod - def _reduce_method(cls, func, include_skipna, numeric_only): + def _reduce_method(cls, func: Callable, include_skipna: bool, + numeric_only: bool): if include_skipna: def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs): @@ -46,9 +53,10 @@ def wrapped_func(self, dim=None, axis=None, # type: ignore and 'axis' arguments can be supplied.""") -class ImplementsDatasetReduce(object): +class ImplementsDatasetReduce: @classmethod - def _reduce_method(cls, func, include_skipna, numeric_only): + def _reduce_method(cls, func: Callable, include_skipna: bool, + numeric_only: bool): if include_skipna: def wrapped_func(self, dim=None, skipna=None, **kwargs): @@ -76,46 +84,38 @@ def wrapped_func(self, dim=None, **kwargs): # type: ignore class AbstractArray(ImplementsArrayReduce): - """Shared base class for DataArray and Variable.""" - - def __bool__(self): + """Shared base class for DataArray and Variable. 
+ """ + def __bool__(self: Any) -> bool: return bool(self.values) - # Python 3 uses __bool__, Python 2 uses __nonzero__ - __nonzero__ = __bool__ - - def __float__(self): + def __float__(self: Any) -> float: return float(self.values) - def __int__(self): + def __int__(self: Any) -> int: return int(self.values) - def __complex__(self): + def __complex__(self: Any) -> complex: return complex(self.values) - def __long__(self): - return long(self.values) # noqa - - def __array__(self, dtype=None): + def __array__(self: Any, dtype: Union[str, np.dtype, None] = None + ) -> np.ndarray: return np.asarray(self.values, dtype=dtype) - def __repr__(self): + def __repr__(self) -> str: return formatting.array_repr(self) - def _iter(self): + def _iter(self: Any) -> Iterator[Any]: for n in range(len(self)): yield self[n] - def __iter__(self): + def __iter__(self: Any) -> Iterator[Any]: if self.ndim == 0: raise TypeError('iteration over a 0-d array') return self._iter() - @property - def T(self): - return self.transpose() - - def get_axis_num(self, dim): + def get_axis_num(self, dim: Union[Hashable, Iterable[Hashable]] + ) -> Union[int, Tuple[int, ...]]: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -128,12 +128,12 @@ def get_axis_num(self, dim): int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if isinstance(dim, str): - return self._get_axis_num(dim) - else: + if isinstance(dim, Iterable) and not isinstance(dim, str): return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) - def _get_axis_num(self, dim): + def _get_axis_num(self: Any, dim: Hashable) -> int: try: return self.dims.index(dim) except ValueError: @@ -141,7 +141,7 @@ def _get_axis_num(self, dim): (dim, self.dims)) @property - def sizes(self): + def sizes(self: Any) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. 
@@ -153,7 +153,7 @@ def sizes(self): return Frozen(OrderedDict(zip(self.dims, self.shape))) -class AttrAccessMixin(object): +class AttrAccessMixin: """Mixin class that allows getting keys with attribute access """ _initialized = False @@ -168,7 +168,7 @@ def _item_sources(self): """List of places to look-up items for key-autocompletion """ return [] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name != '__setstate__': # this avoids an infinite loop when pickle looks for the # __setstate__ attribute before the xarray object is initialized @@ -178,7 +178,7 @@ def __getattr__(self, name): raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name)) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: if self._initialized: try: # Allow setting instance variables if they already exist @@ -192,7 +192,7 @@ def __setattr__(self, name, value): "assign variables." % (name, type(self).__name__)) object.__setattr__(self, name, value) - def __dir__(self): + def __dir__(self) -> List[str]: """Provide method name lookup and completion. Only provide 'public' methods. """ @@ -202,7 +202,7 @@ def __dir__(self): if isinstance(item, str)] return sorted(set(dir(type(self)) + extra_attrs)) - def _ipython_key_completions_(self): + def _ipython_key_completions_(self) -> List[str]: """Provide method for the key-autocompletions in IPython. See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. @@ -214,49 +214,57 @@ def _ipython_key_completions_(self): return list(set(item_lists)) -def get_squeeze_dims(xarray_obj, dim, axis=None): +def get_squeeze_dims(xarray_obj, + dim: Union[Hashable, Iterable[Hashable], None] = None, + axis: Union[int, Iterable[int], None] = None + ) -> List[Hashable]: """Get a list of dimensions to squeeze out. 
""" if dim is not None and axis is not None: raise ValueError('cannot use both parameters `axis` and `dim`') - if dim is None and axis is None: - dim = [d for d, s in xarray_obj.sizes.items() if s == 1] + return [d for d, s in xarray_obj.sizes.items() if s == 1] + + if isinstance(dim, Iterable) and not isinstance(dim, str): + dim = list(dim) + elif dim is not None: + dim = [dim] else: - if isinstance(dim, str): - dim = [dim] + assert axis is not None if isinstance(axis, int): - axis = (axis, ) - if isinstance(axis, tuple): - for a in axis: - if not isinstance(a, int): - raise ValueError( - 'parameter `axis` must be int or tuple of int.') - alldims = list(xarray_obj.sizes.keys()) - dim = [alldims[a] for a in axis] - if any(xarray_obj.sizes[k] > 1 for k in dim): - raise ValueError('cannot select a dimension to squeeze out ' - 'which has length greater than one') + axis = [axis] + axis = list(axis) + if any(not isinstance(a, int) for a in axis): + raise TypeError( + 'parameter `axis` must be int or iterable of int.') + alldims = list(xarray_obj.sizes.keys()) + dim = [alldims[a] for a in axis] + + if any(xarray_obj.sizes[k] > 1 for k in dim): + raise ValueError('cannot select a dimension to squeeze out ' + 'which has length greater than one') return dim class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" - def squeeze(self, dim=None, drop=False, axis=None): + def squeeze(self, dim: Union[Hashable, Iterable[Hashable], None] = None, + drop: bool = False, + axis: Union[int, Iterable[int], None] = None): """Return a new object with squeezed data. Parameters ---------- - dim : None or str or tuple of str, optional + dim : None or Hashable or iterable of Hashable, optional Selects a subset of the length one dimensions. If a dimension is selected with length greater than one, an error is raised. If None, all length one dimensions are squeezed. 
drop : bool, optional If ``drop=True``, drop squeezed coordinates instead of making them scalar. - axis : int, optional - Select the dimension to squeeze. Added for compatibility reasons. + axis : None or int or iterable of int, optional + Like dim, but positional. Returns ------- @@ -271,7 +279,7 @@ def squeeze(self, dim=None, drop=False, axis=None): dims = get_squeeze_dims(self, dim, axis) return self.isel(drop=drop, **{d: 0 for d in dims}) - def get_index(self, key): + def get_index(self, key: Hashable) -> pd.Index: """Get an index for a dimension, with fall-back to a default RangeIndex """ if key not in self.dims: @@ -283,8 +291,9 @@ def get_index(self, key): # need to ensure dtype=int64 in case range is empty on Python 2 return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64) - def _calc_assign_results(self, kwargs): - results = SortedKeysDict() + def _calc_assign_results(self, kwargs: Mapping[str, T] + ) -> MutableMapping[str, T]: + results = SortedKeysDict() # type: SortedKeysDict[str, T] for k, v in kwargs.items(): if callable(v): results[k] = v(self) @@ -372,7 +381,8 @@ def assign_attrs(self, *args, **kwargs): out.attrs.update(*args, **kwargs) return out - def pipe(self, func, *args, **kwargs): + def pipe(self, func: Union[Callable[..., T], Tuple[Callable[..., T], str]], + *args, **kwargs) -> T: """ Apply func(self, *args, **kwargs) @@ -424,15 +434,14 @@ def pipe(self, func, *args, **kwargs): if isinstance(func, tuple): func, target = func if target in kwargs: - msg = ('%s is both the pipe target and a keyword argument' - % target) - raise ValueError(msg) + raise ValueError('%s is both the pipe target and a keyword ' + 'argument' % target) kwargs[target] = self return func(*args, **kwargs) else: return func(self, *args, **kwargs) - def groupby(self, group, squeeze=True): + def groupby(self, group, squeeze: bool = True): """Returns a GroupBy object for performing grouped operations. 
Parameters @@ -478,8 +487,9 @@ def groupby(self, group, squeeze=True): """ # noqa return self._groupby_cls(self, group, squeeze=squeeze) - def groupby_bins(self, group, bins, right=True, labels=None, precision=3, - include_lowest=False, squeeze=True): + def groupby_bins(self, group, bins, right: bool = True, labels=None, + precision: int = 3, include_lowest: bool = False, + squeeze: bool = True): """Returns a GroupBy object for performing grouped operations. Rather than using all unique values of `group`, the values are discretized @@ -530,7 +540,9 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3, 'precision': precision, 'include_lowest': include_lowest}) - def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): + def rolling(self, dim: Optional[Mapping[Hashable, int]] = None, + min_periods: Optional[int] = None, center: bool = False, + **dim_kwargs: int): """ Rolling window object. @@ -590,8 +602,11 @@ def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): return self._rolling_cls(self, dim, min_periods=min_periods, center=center) - def coarsen(self, dim=None, boundary='exact', side='left', - coord_func='mean', **dim_kwargs): + def coarsen(self, dim: Optional[Mapping[Hashable, int]] = None, + boundary: str = 'exact', + side: Union[str, Mapping[Hashable, str]] = 'left', + coord_func: str = 'mean', + **dim_kwargs: int): """ Coarsen object. 
@@ -650,8 +665,12 @@ def coarsen(self, dim=None, boundary='exact', side='left', self, dim, boundary=boundary, side=side, coord_func=coord_func) - def resample(self, indexer=None, skipna=None, closed=None, label=None, - base=0, keep_attrs=None, loffset=None, **indexer_kwargs): + def resample(self, indexer: Optional[Mapping[Hashable, str]] = None, + skipna=None, closed: Optional[str] = None, + label: Optional[str] = None, + base: int = 0, keep_attrs: Optional[bool] = None, + loffset=None, + **indexer_kwargs: str): """Returns a Resample object for performing resampling operations. Handles both downsampling and upsampling. If any intervals contain no @@ -745,12 +764,11 @@ def resample(self, indexer=None, skipna=None, closed=None, label=None, "objects, e.g., data.resample(time='1D').mean()") indexer = either_dict_or_kwargs(indexer, indexer_kwargs, 'resample') - if len(indexer) != 1: raise ValueError( "Resampling only supported along single dimensions." ) - dim, freq = indexer.popitem() + dim, freq = next(iter(indexer.items())) dim_name = dim dim_coord = self[dim] @@ -772,7 +790,7 @@ def resample(self, indexer=None, skipna=None, closed=None, label=None, return resampler - def where(self, cond, other=dtypes.NA, drop=False): + def where(self, cond, other=dtypes.NA, drop: bool = False): """Filter elements from this object according to a condition. 
This operation follows the normal broadcasting and alignment rules that @@ -858,7 +876,7 @@ def where(self, cond, other=dtypes.NA, drop=False): return ops.where_method(self, cond, other) - def close(self): + def close(self: Any) -> None: """Close any files linked to this object """ if self._file_obj is not None: @@ -914,14 +932,14 @@ def isin(self, test_elements): dask='allowed', ) - def __enter__(self): + def __enter__(self: T) -> T: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.close() -def full_like(other, fill_value, dtype=None): +def full_like(other, fill_value, dtype: Union[str, np.dtype, None] = None): """Return a new object with the same shape and type as a given object. Parameters @@ -961,7 +979,8 @@ def full_like(other, fill_value, dtype=None): raise TypeError("Expected DataArray, Dataset, or Variable") -def _full_like_variable(other, fill_value, dtype=None): +def _full_like_variable(other, fill_value, + dtype: Union[str, np.dtype, None] = None): """Inner function of full_like, where other must be a variable """ from .variable import Variable @@ -978,27 +997,28 @@ def _full_like_variable(other, fill_value, dtype=None): return Variable(dims=other.dims, data=data, attrs=other.attrs) -def zeros_like(other, dtype=None): +def zeros_like(other, dtype: Union[str, np.dtype, None] = None): """Shorthand for full_like(other, 0, dtype) """ return full_like(other, 0, dtype) -def ones_like(other, dtype=None): +def ones_like(other, dtype: Union[str, np.dtype, None] = None): """Shorthand for full_like(other, 1, dtype) """ return full_like(other, 1, dtype) -def is_np_datetime_like(dtype): +def is_np_datetime_like(dtype: Union[str, np.dtype]) -> bool: """Check if a dtype is a subclass of the numpy datetime types """ return (np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)) -def _contains_cftime_datetimes(array): - """Check if an array contains cftime.datetime 
objects""" +def _contains_cftime_datetimes(array) -> bool: + """Check if an array contains cftime.datetime objects + """ try: from cftime import datetime as cftime_datetime except ImportError: @@ -1015,12 +1035,14 @@ def _contains_cftime_datetimes(array): return False -def contains_cftime_datetimes(var): - """Check if an xarray.Variable contains cftime.datetime objects""" +def contains_cftime_datetimes(var) -> bool: + """Check if an xarray.Variable contains cftime.datetime objects + """ return _contains_cftime_datetimes(var.data) -def _contains_datetime_like_objects(var): +def _contains_datetime_like_objects(var) -> bool: """Check if a variable contains datetime like objects (either - np.datetime64, np.timedelta64, or cftime.datetime)""" + np.datetime64, np.timedelta64, or cftime.datetime) + """ return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a9e55159f57..82e20d123dd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1399,7 +1399,7 @@ def unstack(self, dim=None): ds = self._to_temp_dataset().unstack(dim) return self._from_temp_dataset(ds) - def transpose(self, *dims): + def transpose(self, *dims) -> 'DataArray': """Return a new DataArray object with transposed dimensions. Parameters @@ -1427,6 +1427,10 @@ def transpose(self, *dims): variable = self.variable.transpose(*dims) return self._replace(variable) + @property + def T(self) -> 'DataArray': + return self.transpose() + def drop(self, labels, dim=None): """Drop coordinates or index labels from this DataArray. 
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 0df0e727303..aa2cc5a0f03 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -1,6 +1,5 @@
-# flake8: noqa
 import sys
-import typing
+from collections import abc
 
 import numpy as np
 
@@ -13,6 +12,40 @@
 except ImportError:  # pragma: no cover
     dask_array_type = ()
 
-# Ensure we have some more recent additions to the typing module.
-# Note that TYPE_CHECKING itself is not available on Python 3.5.1.
-TYPE_CHECKING = sys.version >= '3.5.3' and typing.TYPE_CHECKING
+
+# Compare version tuples, not strings: lexicographically '3.10' < '3.5.3',
+# so a string comparison would misclassify Python 3.10+ as pre-3.5.3.
+if sys.version_info < (3, 5, 3):
+    TYPE_CHECKING = False
+
+    class _ABCDummyBrackets(type(abc.Mapping)):  # abc.ABCMeta
+        def __getitem__(cls, name):
+            return cls
+
+    class Mapping(abc.Mapping, metaclass=_ABCDummyBrackets):
+        pass
+
+    class MutableMapping(abc.MutableMapping, metaclass=_ABCDummyBrackets):
+        pass
+
+    class MutableSet(abc.MutableSet, metaclass=_ABCDummyBrackets):
+        pass
+
+else:
+    from typing import TYPE_CHECKING  # noqa: F401
+
+    # from typing import Mapping, MutableMapping, MutableSet
+
+    # The above confuses mypy 0.700;
+    # see: https://github.com/python/mypy/issues/6652
+    # As a workaround, use:
+    #
+    #     from typing import Mapping, MutableMapping, MutableSet
+    #     try:
+    #         from .pycompat import Mapping, MutableMapping, MutableSet
+    #     except ImportError:
+    #         pass
+    #
+    # This is only necessary in modules that define subclasses of the
+    # abstract collections; when only type inference is needed, one can just
+    # use typing also in Python 3.5.0~3.5.2 (although mypy will misbehave).
diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 349c8f98dc5..94787dd35e2 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -7,15 +7,28 @@ import re import warnings from collections import OrderedDict -from collections.abc import Iterable, Mapping, MutableMapping, MutableSet +from typing import (AbstractSet, Any, Callable, Container, Dict, Hashable, + Iterable, Iterator, Optional, Sequence, + Tuple, TypeVar, cast) import numpy as np import pandas as pd from .pycompat import dask_array_type +from typing import Mapping, MutableMapping, MutableSet +try: # Fix typed collections in Python 3.5.0~3.5.2 + from .pycompat import Mapping, MutableMapping, MutableSet # noqa: F811 +except ImportError: + pass -def _check_inplace(inplace, default=False): + +K = TypeVar('K') +V = TypeVar('V') +T = TypeVar('T') + + +def _check_inplace(inplace: bool, default: bool = False) -> bool: if inplace is None: inplace = default else: @@ -26,16 +39,16 @@ def _check_inplace(inplace, default=False): return inplace -def alias_message(old_name, new_name): +def alias_message(old_name: str, new_name: str) -> str: return '%s has been deprecated. Use %s instead.' 
% (old_name, new_name) -def alias_warning(old_name, new_name, stacklevel=3): +def alias_warning(old_name: str, new_name: str, stacklevel: int = 3) -> None: warnings.warn(alias_message(old_name, new_name), FutureWarning, stacklevel=stacklevel) -def alias(obj, old_name): +def alias(obj: Callable[..., T], old_name: str) -> Callable[..., T]: assert isinstance(old_name, str) @functools.wraps(obj) @@ -46,7 +59,7 @@ def wrapper(*args, **kwargs): return wrapper -def _maybe_cast_to_cftimeindex(index): +def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: from ..coding.cftimeindex import CFTimeIndex if index.dtype == 'O': @@ -58,7 +71,7 @@ def _maybe_cast_to_cftimeindex(index): return index -def safe_cast_to_index(array): +def safe_cast_to_index(array: Any) -> pd.Index: """Given an array, safely cast it to a pandas.Index. If it is already a pandas.Index, return it unchanged. @@ -79,7 +92,9 @@ def safe_cast_to_index(array): return _maybe_cast_to_cftimeindex(index) -def multiindex_from_product_levels(levels, names=None): +def multiindex_from_product_levels(levels: Sequence[pd.Index], + names: Optional[Sequence[str]] = None + ) -> pd.MultiIndex: """Creating a MultiIndex from a product without refactorizing levels. Keeping levels the same gives back the original labels when we unstack. 
@@ -117,7 +132,7 @@ def maybe_wrap_array(original, new_array): return new_array -def equivalent(first, second): +def equivalent(first: T, second: T) -> bool: """Compare two objects for equivalence (identity or equality), using array_equiv if either object is an ndarray """ @@ -131,8 +146,8 @@ def equivalent(first, second): (pd.isnull(first) and pd.isnull(second))) -def peek_at(iterable): - """Returns the first value from iterable, as well as a new iterable with +def peek_at(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]: + """Returns the first value from iterable, as well as a new iterator with the same content as the original iterable """ gen = iter(iterable) @@ -140,7 +155,9 @@ def peek_at(iterable): return peek, itertools.chain([peek], gen) -def update_safety_check(first_dict, second_dict, compat=equivalent): +def update_safety_check(first_dict: MutableMapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent) -> None: """Check the safety of updating one dictionary with another. Raises ValueError if dictionaries have non-compatible values for any key, @@ -162,7 +179,10 @@ def update_safety_check(first_dict, second_dict, compat=equivalent): 'overriding values; conflicting key %r' % k) -def remove_incompatible_items(first_dict, second_dict, compat=equivalent): +def remove_incompatible_items(first_dict: MutableMapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent + ) -> None: """Remove incompatible items from the first dictionary in-place. Items are retained if their keys are found in both dictionaries and the @@ -177,21 +197,22 @@ def remove_incompatible_items(first_dict, second_dict, compat=equivalent): checks for equivalence. 
""" for k in list(first_dict): - if (k not in second_dict or - (k in second_dict and - not compat(first_dict[k], second_dict[k]))): + if k not in second_dict or not compat(first_dict[k], second_dict[k]): del first_dict[k] -def is_dict_like(value): +def is_dict_like(value: Any) -> bool: return hasattr(value, 'keys') and hasattr(value, '__getitem__') -def is_full_slice(value): +def is_full_slice(value: Any) -> bool: return isinstance(value, slice) and value == slice(None) -def either_dict_or_kwargs(pos_kwargs, kw_kwargs, func_name): +def either_dict_or_kwargs(pos_kwargs: Optional[Mapping[Hashable, T]], + kw_kwargs: Mapping[str, T], + func_name: str + ) -> Mapping[Hashable, T]: if pos_kwargs is not None: if not is_dict_like(pos_kwargs): raise ValueError('the first argument to .%s must be a dictionary' @@ -201,10 +222,12 @@ def either_dict_or_kwargs(pos_kwargs, kw_kwargs, func_name): 'arguments to .%s' % func_name) return pos_kwargs else: - return kw_kwargs + # Need an explicit cast to appease mypy due to invariance; see + # https://github.com/python/mypy/issues/6228 + return cast(Mapping[Hashable, T], kw_kwargs) -def is_scalar(value): +def is_scalar(value: Any) -> bool: """Whether to treat a value as a scalar. Any non-iterable, string, or 0-D array @@ -215,7 +238,7 @@ def is_scalar(value): isinstance(value, (Iterable, ) + dask_array_type)) -def is_valid_numpy_dtype(dtype): +def is_valid_numpy_dtype(dtype: Any) -> bool: try: np.dtype(dtype) except (TypeError, ValueError): @@ -224,15 +247,17 @@ def is_valid_numpy_dtype(dtype): return True -def to_0d_object_array(value): - """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" +def to_0d_object_array(value: Any) -> np.ndarray: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object. 
+ """ result = np.empty((), dtype=object) result[()] = value return result -def to_0d_array(value): - """Given a value, wrap it in a 0-D numpy.ndarray.""" +def to_0d_array(value: Any) -> np.ndarray: + """Given a value, wrap it in a 0-D numpy.ndarray. + """ if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0): return np.array(value) @@ -240,7 +265,8 @@ def to_0d_array(value): return to_0d_object_array(value) -def dict_equiv(first, second, compat=equivalent): +def dict_equiv(first: Mapping[K, V], second: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent) -> bool: """Test equivalence of two dict-like objects. If any of the values are numpy arrays, compare them correctly. @@ -266,7 +292,10 @@ def dict_equiv(first, second, compat=equivalent): return True -def ordered_dict_intersection(first_dict, second_dict, compat=equivalent): +def ordered_dict_intersection(first_dict: Mapping[K, V], + second_dict: Mapping[K, V], + compat: Callable[[V, V], bool] = equivalent + ) -> MutableMapping[K, V]: """Return the intersection of two dictionaries as a new OrderedDict. Items are retained if their keys are found in both dictionaries and the @@ -290,11 +319,10 @@ def ordered_dict_intersection(first_dict, second_dict, compat=equivalent): return new_dict -class SingleSlotPickleMixin(object): +class SingleSlotPickleMixin: """Mixin class to add the ability to pickle objects whose state is defined by a single __slots__ attribute. Only necessary under Python 2. """ - def __getstate__(self): return getattr(self, self.__slots__[0]) @@ -302,160 +330,123 @@ def __setstate__(self, state): setattr(self, self.__slots__[0], state) -class Frozen(Mapping, SingleSlotPickleMixin): +class Frozen(Mapping[K, V], SingleSlotPickleMixin): """Wrapper around an object implementing the mapping interface to make it immutable. If you really want to modify the mapping, the mutable version is saved under the `mapping` attribute. 
""" __slots__ = ['mapping'] - def __init__(self, mapping): + def __init__(self, mapping: Mapping[K, V]): self.mapping = mapping - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: return self.mapping[key] - def __iter__(self): + def __iter__(self) -> Iterator[K]: return iter(self.mapping) - def __len__(self): + def __len__(self) -> int: return len(self.mapping) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self.mapping - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (type(self).__name__, self.mapping) -def FrozenOrderedDict(*args, **kwargs): +def FrozenOrderedDict(*args, **kwargs) -> Frozen: return Frozen(OrderedDict(*args, **kwargs)) -class SortedKeysDict(MutableMapping, SingleSlotPickleMixin): +class SortedKeysDict(MutableMapping[K, V], SingleSlotPickleMixin): """An wrapper for dictionary-like objects that always iterates over its items in sorted order by key but is otherwise equivalent to the underlying mapping. 
""" __slots__ = ['mapping'] - def __init__(self, mapping=None): + def __init__(self, mapping: Optional[MutableMapping[K, V]] = None): self.mapping = {} if mapping is None else mapping - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: return self.mapping[key] - def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V) -> None: self.mapping[key] = value - def __delitem__(self, key): + def __delitem__(self, key: K) -> None: del self.mapping[key] - def __iter__(self): + def __iter__(self) -> Iterator[K]: return iter(sorted(self.mapping)) - def __len__(self): + def __len__(self) -> int: return len(self.mapping) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self.mapping - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (type(self).__name__, self.mapping) - def copy(self): - return type(self)(self.mapping.copy()) - - -class ChainMap(MutableMapping, SingleSlotPickleMixin): - """Partial backport of collections.ChainMap from Python>=3.3 - - Don't return this from any public APIs, since some of the public methods - for a MutableMapping are missing (they will raise a NotImplementedError) - """ - __slots__ = ['maps'] - - def __init__(self, *maps): - self.maps = maps - - def __getitem__(self, key): - for mapping in self.maps: - try: - return mapping[key] - except KeyError: - pass - raise KeyError(key) - def __setitem__(self, key, value): - self.maps[0][key] = value - - def __delitem__(self, value): # pragma: no cover - raise NotImplementedError - - def __iter__(self): - seen = set() - for mapping in self.maps: - for item in mapping: - if item not in seen: - yield item - seen.add(item) - - def __len__(self): - raise len(iter(self)) - - -class OrderedSet(MutableSet): +class OrderedSet(MutableSet[T]): """A simple ordered set. The API matches the builtin set, but it preserves insertion order of elements, like an OrderedDict. 
""" - - def __init__(self, values=None): - self._ordered_dict = OrderedDict() + def __init__(self, values: Optional[AbstractSet[T]] = None): + self._ordered_dict = OrderedDict() # type: MutableMapping[T, None] if values is not None: - self |= values + # Disable type checking - both mypy and PyCharm believes that + # we're altering the type of self in place (see signature of + # MutableSet.__ior__) + self |= values # type: ignore # Required methods for MutableSet - def __contains__(self, value): + def __contains__(self, value: object) -> bool: return value in self._ordered_dict - def __iter__(self): + def __iter__(self) -> Iterator[T]: return iter(self._ordered_dict) - def __len__(self): + def __len__(self) -> int: return len(self._ordered_dict) - def add(self, value): + def add(self, value: T) -> None: self._ordered_dict[value] = None - def discard(self, value): + def discard(self, value: T) -> None: del self._ordered_dict[value] # Additional methods - def update(self, values): - self |= values + def update(self, values: AbstractSet[T]) -> None: + # See comment on __init__ re. type checking + self |= values # type: ignore - def __repr__(self): + def __repr__(self) -> str: return '%s(%r)' % (type(self).__name__, list(self)) -class NdimSizeLenMixin(object): +class NdimSizeLenMixin: """Mixin class that extends a class that defines a ``shape`` property to one that also defines ``ndim``, ``size`` and ``__len__``. """ @property - def ndim(self): + def ndim(self: Any) -> int: return len(self.shape) @property - def size(self): + def size(self: Any) -> int: # cast to int so that shape = () gives size = 1 return int(np.prod(self.shape)) - def __len__(self): + def __len__(self: Any) -> int: try: return self.shape[0] except IndexError: @@ -470,27 +461,27 @@ class NDArrayMixin(NdimSizeLenMixin): `dtype`, `shape` and `__getitem__`. 
""" @property - def dtype(self): + def dtype(self: Any) -> np.dtype: return self.array.dtype @property - def shape(self): + def shape(self: Any) -> Tuple[int]: return self.array.shape - def __getitem__(self, key): + def __getitem__(self: Any, key): return self.array[key] - def __repr__(self): + def __repr__(self: Any) -> str: return '%s(array=%r)' % (type(self).__name__, self.array) -class ReprObject(object): - """Object that prints as the given value, for use with sentinel values.""" - +class ReprObject: + """Object that prints as the given value, for use with sentinel values. + """ def __init__(self, value: str): self._value = value - def __repr__(self): + def __repr__(self) -> str: return self._value @@ -506,16 +497,16 @@ def close_on_error(f): raise -def is_remote_uri(path): +def is_remote_uri(path: str) -> bool: return bool(re.search(r'^https?\://', path)) -def is_grib_path(path): +def is_grib_path(path: str) -> bool: _, ext = os.path.splitext(path) return ext in ['.grib', '.grb', '.grib2', '.grb2'] -def is_uniform_spaced(arr, **kwargs): +def is_uniform_spaced(arr, **kwargs) -> bool: """Return True if values of an array are uniformly spaced and sorted. >>> is_uniform_spaced(range(5)) @@ -527,11 +518,12 @@ def is_uniform_spaced(arr, **kwargs): """ arr = np.array(arr, dtype=float) diffs = np.diff(arr) - return np.isclose(diffs.min(), diffs.max(), **kwargs) + return bool(np.isclose(diffs.min(), diffs.max(), **kwargs)) -def hashable(v): - """Determine whether `v` can be hashed.""" +def hashable(v: Any) -> bool: + """Determine whether `v` can be hashed. 
+ """ try: hash(v) except TypeError: @@ -543,9 +535,10 @@ def not_implemented(*args, **kwargs): return NotImplemented -def decode_numpy_dict_values(attrs): +def decode_numpy_dict_values(attrs: Mapping[K, V]) -> Dict[K, V]: """Convert attribute values from numpy objects to native Python objects, - for use in to_dict""" + for use in to_dict + """ attrs = dict(attrs) for k, v in attrs.items(): if isinstance(v, np.ndarray): @@ -565,46 +558,43 @@ def ensure_us_time_resolution(val): return val -class HiddenKeyDict(MutableMapping): - ''' - Acts like a normal dictionary, but hides certain keys. - ''' +class HiddenKeyDict(MutableMapping[K, V]): + """Acts like a normal dictionary, but hides certain keys. + """ # ``__init__`` method required to create instance from class. - def __init__(self, data, hidden_keys): + def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): self._data = data - if type(hidden_keys) not in (list, tuple): - raise TypeError("hidden_keys must be a list or tuple") - self._hidden_keys = hidden_keys + self._hidden_keys = frozenset(hidden_keys) - def _raise_if_hidden(self, key): + def _raise_if_hidden(self, key: K) -> None: if key in self._hidden_keys: raise KeyError('Key `%r` is hidden.' % key) # The next five methods are requirements of the ABC. 
- def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V) -> None: self._raise_if_hidden(key) self._data[key] = value - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: self._raise_if_hidden(key) return self._data[key] - def __delitem__(self, key): + def __delitem__(self, key: K) -> None: self._raise_if_hidden(key) del self._data[key] - def __iter__(self): + def __iter__(self) -> Iterator[K]: for k in self._data: if k not in self._hidden_keys: yield k - def __len__(self): - num_hidden = sum([k in self._hidden_keys for k in self._data]) + def __len__(self) -> int: + num_hidden = len(self._hidden_keys & self._data.keys()) return len(self._data) - num_hidden -def get_temp_dimname(dims, new_dim): +def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: """ Get an new dimension name based on new_dim, that is not used in dims. If the same name exists, we add an underscore(s) in the head. @@ -618,5 +608,5 @@ def get_temp_dimname(dims, new_dim): -> ['__rolling'] """ while new_dim in dims: - new_dim = '_' + new_dim + new_dim = '_' + str(new_dim) return new_dim diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d6b64e7d458..96c6b7bd59b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1121,7 +1121,7 @@ def roll(self, shifts=None, **shifts_kwargs): result = result._roll_one_dim(dim, count) return result - def transpose(self, *dims): + def transpose(self, *dims) -> 'Variable': """Return a new Variable object with transposed dimensions. 
Parameters @@ -1155,6 +1155,10 @@ def transpose(self, *dims): return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) + @property + def T(self) -> 'Variable': + return self.transpose() + def expand_dims(self, *args): import warnings warnings.warn('Variable.expand_dims is deprecated: use ' diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index e98ab5cde4c..caff78b4fb1 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -5,15 +5,12 @@ import pandas as pd import pytest -import xarray as xr from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.utils import either_dict_or_kwargs -from xarray.testing import assert_identical from . import ( - assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime, - requires_dask) + assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_dask) from .test_coding_times import _all_cftime_date_types @@ -178,19 +175,6 @@ def test_sorted_keys_dict(self): assert repr(utils.SortedKeysDict()) == \ "SortedKeysDict({})" - def test_chain_map(self): - m = utils.ChainMap({'x': 0, 'y': 1}, {'x': -100, 'z': 2}) - assert 'x' in m - assert 'y' in m - assert 'z' in m - assert m['x'] == 0 - assert m['y'] == 1 - assert m['z'] == 2 - m['x'] = 100 - assert m['x'] == 100 - assert m.maps[0]['x'] == 100 - assert set(m) == {'x', 'y', 'z'} - def test_repr_object(): obj = utils.ReprObject('foo')