diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index ade2869ea3f..ec17f9d9b80 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -3,12 +3,14 @@ """ from functools import partial +from typing import cast import numpy as np import pandas as pd import pytest import xarray as xr +from xarray.core.dataset import Dataset pytest.importorskip("hypothesis") import hypothesis.extra.numpy as npst # isort:skip @@ -88,10 +90,10 @@ def test_roundtrip_dataarray(data, arr) -> None: @given(datasets_1d_vars()) -def test_roundtrip_dataset(dataset) -> None: +def test_roundtrip_dataset(dataset: Dataset) -> None: df = dataset.to_dataframe() assert isinstance(df, pd.DataFrame) - roundtripped = xr.Dataset(df) + roundtripped = xr.Dataset.from_dataframe(df) xr.testing.assert_identical(dataset, roundtripped) @@ -101,7 +103,7 @@ def test_roundtrip_pandas_series(ser, ix_name) -> None: ser.index.name = ix_name arr = xr.DataArray(ser) roundtripped = arr.to_pandas() - pd.testing.assert_series_equal(ser, roundtripped) + pd.testing.assert_series_equal(ser, roundtripped) # type: ignore[arg-type] xr.testing.assert_identical(arr, roundtripped.to_xarray()) @@ -119,7 +121,7 @@ def test_roundtrip_pandas_dataframe(df) -> None: df.columns.name = "cols" arr = xr.DataArray(df) roundtripped = arr.to_pandas() - pd.testing.assert_frame_equal(df, roundtripped) + pd.testing.assert_frame_equal(df, cast(pd.DataFrame, roundtripped)) xr.testing.assert_identical(arr, roundtripped.to_xarray()) @@ -143,8 +145,8 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: pd.arrays.IntervalArray( [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] ), - pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), - pd.arrays.DatetimeArray._from_sequence( + pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), # type: ignore[attr-defined] + pd.arrays.DatetimeArray._from_sequence( # type: ignore[attr-defined] pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D") ), np.array([1, 2, 3], dtype="int64"), diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index c959a7f2536..9f48f0f69e7 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,15 +1,21 @@ from __future__ import annotations import functools -from typing import Any +from collections.abc import Iterable +from typing import TYPE_CHECKING, TypeVar, cast import numpy as np +from pandas.api.extensions import ExtensionDtype from pandas.api.types import is_extension_array_dtype from xarray.compat import array_api_compat, npcompat from xarray.compat.npcompat import HAS_STRING_DTYPE from xarray.core import utils +if TYPE_CHECKING: + from typing import Any + + # Use as a sentinel value to indicate a dtype appropriate NA value. NA = utils.ReprObject("") @@ -47,8 +53,10 @@ def __eq__(self, other): (np.bytes_, np.str_), # numpy promotes to unicode ) +T_dtype = TypeVar("T_dtype", np.dtype, ExtensionDtype) -def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: + +def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]: """Simpler equivalent of pandas.core.common._maybe_promote Parameters @@ -63,7 +71,13 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: # N.B. these casting rules should match pandas dtype_: np.typing.DTypeLike fill_value: Any - if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): + if is_extension_array_dtype(dtype): + return dtype, cast(ExtensionDtype, dtype).na_value # type: ignore[redundant-cast] + if not isinstance(dtype, np.dtype): + raise TypeError( + f"dtype {dtype} must be one of an extension array dtype or numpy dtype" + ) + elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): # for now, we always promote string dtypes to object for consistency with existing behavior # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes dtype_ = object @@ -222,23 +236,66 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: return xp.isdtype(dtype, kind) -def preprocess_types(t): - if isinstance(t, str | bytes): - return type(t) - elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and ( - np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_) - ): +def maybe_promote_to_variable_width( + array_or_dtype: np.typing.ArrayLike + | np.typing.DTypeLike + | ExtensionDtype + | str + | bytes, + *, + should_return_str_or_bytes: bool = False, +) -> np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype: + if isinstance(array_or_dtype, str | bytes): + if should_return_str_or_bytes: + return array_or_dtype + return type(array_or_dtype) + elif isinstance( + dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype + ) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)): # drop the length from numpy's fixed-width string dtypes, it is better to # recalculate # TODO(keewis): remove once the minimum version of `numpy.result_type` does this # for us return dtype.type else: - return t + return array_or_dtype + + +def should_promote_to_object( + arrays_and_dtypes: Iterable[ + np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype + ], + xp, +) -> bool: + """ + Test whether the given arrays_and_dtypes, when evaluated individually, match the + type promotion rules found in PROMOTE_TO_OBJECT. + """ + np_result_types = set() + for arr_or_dtype in arrays_and_dtypes: + try: + result_type = array_api_compat.result_type( + maybe_promote_to_variable_width(arr_or_dtype), xp=xp + ) + if isinstance(result_type, np.dtype): + np_result_types.add(result_type) + except TypeError: + # passing individual objects to xp.result_type means NEP-18 implementations won't have + # a chance to intercept special values (such as NA) that numpy core cannot handle + pass + + if np_result_types: + for left, right in PROMOTE_TO_OBJECT: + if any(np.issubdtype(t, left) for t in np_result_types) and any( + np.issubdtype(t, right) for t in np_result_types + ): + return True + + return False def result_type( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype, xp=None, ) -> np.dtype: """Like np.result_type, but with type promotion rules matching pandas. @@ -263,19 +320,18 @@ def result_type( if xp is None: xp = get_array_namespace(arrays_and_dtypes) - types = { - array_api_compat.result_type(preprocess_types(t), xp=xp) - for t in arrays_and_dtypes - } - if any(isinstance(t, np.dtype) for t in types): - # only check if there's numpy dtypes – the array API does not - # define the types we're checking for - for left, right in PROMOTE_TO_OBJECT: - if any(np.issubdtype(t, left) for t in types) and any( - np.issubdtype(t, right) for t in types - ): - return np.dtype(object) - + if should_promote_to_object(arrays_and_dtypes, xp): + return np.dtype(object) return array_api_compat.result_type( - *map(preprocess_types, arrays_and_dtypes), xp=xp + *map( + functools.partial( + maybe_promote_to_variable_width, + # let extension arrays handle their own str/bytes + should_return_str_or_bytes=any( + map(is_extension_array_dtype, arrays_and_dtypes) # type: ignore[arg-type] + ), + ), + arrays_and_dtypes, + ), + xp=xp, ) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e98ac0f36a1..53e8888b2af 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -28,7 +28,11 @@ from xarray.compat import dask_array_compat, dask_array_ops from xarray.compat.array_api_compat import get_array_namespace from xarray.core import dtypes, nputils -from xarray.core.extension_array import PandasExtensionArray +from xarray.core.extension_array import ( + PandasExtensionArray, + as_extension_array, + is_scalar, +) from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -255,7 +259,14 @@ def astype(data, dtype, *, xp=None, **kwargs): def asarray(data, xp=np, dtype=None): - converted = data if is_duck_array(data) else xp.asarray(data) + if is_duck_array(data): + converted = data + elif is_extension_array_dtype(dtype): + # data may or may not be an ExtensionArray, so we can't rely on + # np.asarray to call our NEP-18 handler; gotta hook it ourselves + converted = PandasExtensionArray(as_extension_array(data, dtype)) + else: + converted = xp.asarray(data, dtype=dtype) if dtype is None or converted.dtype == dtype: return converted @@ -267,27 +278,7 @@ def asarray(data, xp=np, dtype=None): def as_shared_dtype(scalars_or_arrays, xp=None): - """Cast arrays to a shared dtype using xarray's type promotion rules.""" - if any(is_extension_array_dtype(x) for x in scalars_or_arrays): - extension_array_types = [ - x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) - ] - non_nans = [x for x in scalars_or_arrays if not isna(x)] - if len(extension_array_types) == len(non_nans) and all( - isinstance(x, type(extension_array_types[0])) for x in extension_array_types - ): - return [ - x - if not isna(x) - else PandasExtensionArray( - type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype) - ) - for x in scalars_or_arrays - ] - raise ValueError( - f"Cannot cast values to shared type, found values: {scalars_or_arrays}" - ) - + """Cast a arrays to a shared dtype using xarray's type promotion rules.""" # Avoid calling array_type("cupy") repeatidely in the any check array_type_cupy = array_type("cupy") if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): @@ -296,7 +287,12 @@ def as_shared_dtype(scalars_or_arrays, xp=None): xp = cp elif xp is None: xp = get_array_namespace(scalars_or_arrays) - + scalars_or_arrays = [ + PandasExtensionArray(s_or_a) + if isinstance(s_or_a, pd.api.extensions.ExtensionArray) + else s_or_a + for s_or_a in scalars_or_arrays + ] # Pass arrays directly instead of dtypes to result_type so scalars # get handled properly. # Note that result_type() safely gets the dtype from dask arrays without @@ -407,7 +403,12 @@ def where(condition, x, y): else: condition = astype(condition, dtype=dtype, xp=xp) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp) + + # pd.where won't broadcast 0-dim arrays across a series; scalar y's must be preserved + maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y + + return xp.where(condition, promoted_x, maybe_promoted_y) def where_method(data, cond, other=dtypes.NA): diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 7cc9db96d0d..5f45e8b2e84 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -3,12 +3,14 @@ import copy from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Any, Generic, cast +from typing import TYPE_CHECKING, Generic, cast import numpy as np import pandas as pd from packaging.version import Version +from pandas.api.extensions import ExtensionArray, ExtensionDtype from pandas.api.types import is_extension_array_dtype +from pandas.api.types import is_scalar as pd_is_scalar from xarray.core.types import DTypeLikeSave, T_ExtensionArray from xarray.core.utils import NDArrayMixin @@ -16,11 +18,32 @@ HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" +if TYPE_CHECKING: + from typing import Any + + from pandas._typing import DtypeObj, Scalar + + +def is_scalar(value: object) -> bool: + """Workaround: pandas is_scalar doesn't recognize Categorical nulls for some reason.""" + return value is pd.CategoricalDtype.na_value or pd_is_scalar(value) + + +def implements(numpy_function_or_name: Callable | str) -> Callable: + """Register an __array_function__ implementation. + + Pass a function directly if it's guaranteed to exist in all supported numpy versions, or a + string to first check for its existence. + """ def decorator(func): - HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + if isinstance(numpy_function_or_name, str): + numpy_function = getattr(np, numpy_function_or_name, None) + else: + numpy_function = numpy_function_or_name + + if numpy_function: + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func return func return decorator @@ -33,6 +56,110 @@ def __extension_duck_array__issubdtype( return False # never want a function to think a pandas extension dtype is a subtype of numpy +@implements("astype") # np.astype was added in 2.1.0, but we only require >=1.24 +def __extension_duck_array__astype( + array_or_scalar: T_ExtensionArray, + dtype: DTypeLikeSave, + order: str = "K", + casting: str = "unsafe", + subok: bool = True, + copy: bool = True, + device: str | None = None, +) -> ExtensionArray: + if ( + not ( + is_extension_array_dtype(array_or_scalar) or is_extension_array_dtype(dtype) + ) + or casting != "unsafe" + or not subok + or order != "K" + ): + return NotImplemented + + return as_extension_array(array_or_scalar, dtype, copy=copy) + + +@implements(np.asarray) +def __extension_duck_array__asarray( + array_or_scalar: np.typing.ArrayLike | T_ExtensionArray, + dtype: DTypeLikeSave | None = None, +) -> ExtensionArray: + if not is_extension_array_dtype(dtype): + return NotImplemented + + return as_extension_array(array_or_scalar, dtype) + + +def as_extension_array( + array_or_scalar: np.typing.ArrayLike | T_ExtensionArray, + dtype: ExtensionDtype | DTypeLikeSave | None, + copy: bool = False, +) -> ExtensionArray: + if is_scalar(array_or_scalar): + return dtype.construct_array_type()._from_sequence( # type: ignore[union-attr] + [array_or_scalar], dtype=dtype + ) + else: + return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr] + + +@implements(np.result_type) +def __extension_duck_array__result_type( + *arrays_and_dtypes: list[ + np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype | ExtensionArray + ], +) -> DtypeObj: + extension_arrays_and_dtypes: list[ExtensionDtype | ExtensionArray] = [ + cast(ExtensionDtype | ExtensionArray, x) + for x in arrays_and_dtypes + if is_extension_array_dtype(x) + ] + if not extension_arrays_and_dtypes: + return NotImplemented + + ea_dtypes: list[ExtensionDtype] = [ + getattr(x, "dtype", cast(ExtensionDtype, x)) + for x in extension_arrays_and_dtypes + ] + scalars = [ + x for x in arrays_and_dtypes if is_scalar(x) and x not in {pd.NA, np.nan} + ] + # other_stuff could include: + # - arrays such as pd.ABCSeries, np.ndarray, or other array-api duck arrays + # - dtypes such as pd.DtypeObj, np.dtype, or other array-api duck dtypes + other_stuff = [ + x + for x in arrays_and_dtypes + if not is_extension_array_dtype(x) and not is_scalar(x) + ] + # We implement one special case: when possible, preserve Categoricals (avoid promoting + # to object) by merging the categories of all given Categoricals + scalars + NA. + # Ideally this could be upstreamed into pandas find_result_type / find_common_type. + if not other_stuff and all( + isinstance(x, pd.CategoricalDtype) and not x.ordered for x in ea_dtypes + ): + return union_unordered_categorical_and_scalar( + cast(list[pd.CategoricalDtype], ea_dtypes), + scalars, # type: ignore[arg-type] + ) + if not other_stuff and all( + isinstance(x, type(ea_type := ea_dtypes[0])) for x in ea_dtypes + ): + return ea_type + raise ValueError( + f"Cannot cast values to shared type, found values: {arrays_and_dtypes}" + ) + + +def union_unordered_categorical_and_scalar( + categorical_dtypes: list[pd.CategoricalDtype], scalars: list[Scalar] +) -> pd.CategoricalDtype: + scalars = [x for x in scalars if x is not pd.CategoricalDtype.na_value] + all_categories = set().union(*(x.categories for x in categorical_dtypes)) + all_categories = all_categories.union(scalars) + return pd.CategoricalDtype(categories=list(all_categories)) + + @implements(np.broadcast_to) def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): if shape[0] == len(arr) and len(shape) == 1: @@ -54,16 +181,31 @@ def __extension_duck_array__concatenate( @implements(np.where) def __extension_duck_array__where( - condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray + condition: T_ExtensionArray | np.typing.ArrayLike, + x: T_ExtensionArray, + y: T_ExtensionArray | np.typing.ArrayLike, ) -> T_ExtensionArray: - if ( - isinstance(x, pd.Categorical) - and isinstance(y, pd.Categorical) - and x.dtype != y.dtype - ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] - y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] - return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) + return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[arg-type] + + +def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, PandasExtensionArray): + args_as_list[index] = replacer(value) + elif isinstance(value, tuple): # should handle more than just tuple? iterable? + args_as_list[index] = tuple(_replace_duck(value, replacer)) + elif isinstance(value, list): + args_as_list[index] = _replace_duck(value, replacer) + return args_as_list + + +def replace_duck_with_extension_array(args) -> tuple: + return tuple(_replace_duck(args, lambda duck: duck.array)) + + +def replace_duck_with_series(args) -> tuple: + return tuple(_replace_duck(args, lambda duck: pd.Series(duck.array))) @implements(np.ndim) @@ -107,26 +249,11 @@ def __post_init__(self): ) def __array_function__(self, func, types, args, kwargs): - def replace_duck_with_extension_array(args) -> list: - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, PandasExtensionArray): - args_as_list[index] = value.array - elif isinstance( - value, tuple - ): # should handle more than just tuple? iterable? - args_as_list[index] = tuple( - replace_duck_with_extension_array(value) - ) - elif isinstance(value, list): - args_as_list[index] = replace_duck_with_extension_array(value) - return args_as_list - - args = tuple(replace_duck_with_extension_array(args)) + args = replace_duck_with_extension_array(args) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: raise KeyError("Function not registered for pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): + if isinstance(res, ExtensionArray): return PandasExtensionArray(res) return res @@ -134,16 +261,23 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + if ( + isinstance(key, tuple) and len(key) == 1 + ): # pyarrow type arrays can't handle since-length tuples + key = key[0] item = self.array[key] if is_extension_array_dtype(item): return PandasExtensionArray(item) - if np.isscalar(item) or isinstance(key, int): + if is_scalar(item) or isinstance(key, int): return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore] return PandasExtensionArray(item) def __setitem__(self, key, val): self.array[key] = val + def __len__(self): + return len(self.array) + def __eq__(self, other): if isinstance(other, PandasExtensionArray): return self.array == other.array @@ -152,9 +286,6 @@ def __eq__(self, other): def __ne__(self, other): return ~(self == other) - def __len__(self): - return len(self.array) - @property def ndim(self) -> int: return 1 diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b2b9ae314c4..f61cff42226 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -53,6 +53,7 @@ assert_no_warnings, has_dask, has_dask_ge_2025_1_0, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cupy, @@ -61,6 +62,7 @@ requires_iris, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -1847,6 +1849,87 @@ def test_reindex_empty_array_dtype(self) -> None: "Dtype of reindexed DataArray should remain float32" ) + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical(["a", "b", "c"]), id="categorical"), + ] + + [ + pytest.param( + pd.array([1, 2, 3], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], + ) + def test_reindex_extension_array(self, extension_array) -> None: + srs = pd.Series(index=["e", "f", "g"], data=extension_array) + x = srs.to_xarray() + y = x.reindex(index=["f", "g", "z"]) + assert_array_equal(x, extension_array) + # TODO: remove .array once the branch is updated with main + pd.testing.assert_extension_array_equal( + y.data, + extension_array._from_sequence( + [extension_array[1], extension_array[2], pd.NA], + dtype=extension_array.dtype, + ), + ) + assert x.dtype == y.dtype == extension_array.dtype + + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + [ + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) + da = srs.to_xarray() + filled = da.fillna(fill_value) + assert filled.dtype == srs.dtype + assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all() + + @requires_pyarrow + def test_fillna_extension_array_bad_val(self) -> None: + srs: pd.Series = pd.Series( + index=np.array([1, 2, 3]), + data=pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + ) + da = srs.to_xarray() + with pytest.raises(ValueError): + da.fillna("a") + + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"), + ] + + [ + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array) + da = srs.to_xarray() + filled = da.dropna("index") + assert filled.dtype == srs.dtype + assert (filled.values == srs.values[1:]).all() + def test_rename(self) -> None: da = xr.DataArray( [1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])} diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8e527faa5c7..b9573700300 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -70,6 +70,7 @@ requires_dask, requires_numexpr, requires_pint, + requires_pyarrow, requires_scipy, requires_sparse, source_ndarray, @@ -1802,28 +1803,48 @@ def test_categorical_index_reindex(self) -> None: actual = ds.reindex(cat=["foo"])["cat"].values assert (actual == np.array(["foo"])).all() - @pytest.mark.parametrize("fill_value", [np.nan, pd.NA]) - def test_extensionarray_negative_reindex(self, fill_value) -> None: - cat = pd.Categorical( - ["foo", "bar", "baz"], - categories=["foo", "bar", "baz", "qux", "quux", "corge"], - ) + @pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None]) + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param( + pd.Categorical( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux"], + ), + id="categorical", + ), + ] + + [ + pytest.param( + pd.array([1, 1, None], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], + ) + def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( - {"cat": ("index", cat)}, + {"arr": ("index", extension_array)}, coords={"index": ("index", np.arange(3))}, ) + kwargs = {} + if fill_value is not None: + kwargs["fill_value"] = fill_value reindexed_cat = cast( pd.api.extensions.ExtensionArray, - ( - ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"] - .to_pandas() - .values - ), + (ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values), + ) + assert reindexed_cat.equals( # type: ignore[attr-defined] + pd.array( + [pd.NA, extension_array[1], extension_array[1]], + dtype=extension_array.dtype, + ) ) - assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined] + @requires_pyarrow def test_extension_array_reindex_same(self) -> None: - series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) + series = pd.Series([1, 2, pd.NA, 3], dtype="int32[pyarrow]") test = xr.Dataset({"test": series}) res = test.reindex(dim_0=series.index) align(res, test, join="exact") @@ -5486,6 +5507,51 @@ def test_dropna(self) -> None: with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) # type: ignore[arg-type] + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + filled = ds.fillna(fill_value) + assert filled["data"].dtype == extension_array.dtype + assert ( + filled["data"].values + == np.array([fill_value, *srs["data"].values[1:]], dtype="object") + ).all() + + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + dropped = ds.dropna("index") + assert dropped["data"].dtype == extension_array.dtype + assert (dropped["data"].values == srs["data"].values[1:]).all() + def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index eaafe2d4536..ed8f04a87eb 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1108,6 +1108,21 @@ def test_extension_array_repr(int1): assert repr(int1) in repr(int_duck_array) +def test_extension_array_result_type_categorical(categorical1, categorical2): + res = np.result_type( + PandasExtensionArray(categorical1), PandasExtensionArray(categorical2) + ) + assert isinstance(res, pd.CategoricalDtype) + assert set(res.categories) == set(categorical1.categories) | set( + categorical2.categories + ) + assert not res.ordered + + assert categorical1.dtype == np.result_type( + PandasExtensionArray(categorical1), pd.CategoricalDtype.na_value + ) + + def test_extension_array_attr(): array = pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) wrapped = PandasExtensionArray(array)