diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 8fc32e75cbd..ade2869ea3f 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -15,6 +15,7 @@ import hypothesis.extra.pandas as pdst # isort:skip import hypothesis.strategies as st # isort:skip from hypothesis import given # isort:skip +from xarray.tests import has_pyarrow numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(endianness="="), @@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: xr.testing.assert_identical(dataset, roundtripped.to_xarray()) -def test_roundtrip_1d_pandas_extension_array() -> None: - df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])}) - arr = xr.Dataset.from_dataframe(df)["cat"] +@pytest.mark.parametrize( + "extension_array", + [ + pd.Categorical(["a", "b", "c"]), + pd.array(["a", "b", "c"], dtype="string"), + pd.arrays.IntervalArray( + [pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] + ), + pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), + pd.arrays.DatetimeArray._from_sequence( + pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D") + ), + np.array([1, 2, 3], dtype="int64"), + ] + + ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []), + ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"] + + (["pyarrow"] if has_pyarrow else []), +) +@pytest.mark.parametrize("is_index", [True, False]) +def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None: + df = pd.DataFrame({"arr": extension_array}) + if is_index: + df = df.set_index("arr") + arr = xr.Dataset.from_dataframe(df)["arr"] roundtripped = arr.to_pandas() - assert (df["cat"] == roundtripped).all() - assert df["cat"].dtype == roundtripped.dtype - xr.testing.assert_identical(arr, roundtripped.to_xarray()) + df_arr_to_test = df.index if is_index else df["arr"] + assert (df_arr_to_test == roundtripped).all() 
+ # `NumpyExtensionArray` types, including its subclass `StringArray`, are not roundtripped. + if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined] + assert isinstance(arr.data, np.ndarray) + else: + assert ( + df_arr_to_test.dtype + == (roundtripped.index if is_index else roundtripped).dtype + ) + xr.testing.assert_identical(arr, roundtripped.to_xarray()) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 657fb022c09..367da2f60a5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -99,6 +99,7 @@ parse_dims_as_set, ) from xarray.core.variable import ( + UNSUPPORTED_EXTENSION_ARRAY_TYPES, IndexVariable, Variable, as_variable, @@ -7281,7 +7282,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: extension_arrays = [] for k, v in dataframe.items(): if not is_extension_array_dtype(v) or isinstance( - v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray + v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES ): arrays.append((k, np.asarray(v))) else: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 0e77bc5f409..9052f5ae0a0 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -93,6 +93,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): def __post_init__(self): if not isinstance(self.array, pd.api.extensions.ExtensionArray): raise TypeError(f"{self.array} is not an pandas ExtensionArray.") + # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES denylist because + # some types on it (datetime extension arrays, for example) still need + # duck array support internally via this class. + if isinstance(self.array, pd.arrays.NumpyExtensionArray): + raise TypeError( + "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
+ ) def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 1bccc51bd43..1813a25d7af 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1802,8 +1802,12 @@ def __array__( def get_duck_array(self) -> np.ndarray | PandasExtensionArray: # We return an PandasExtensionArray wrapper type that satisfies - # duck array protocols. This is what's needed for tests to pass. - if pd.api.types.is_extension_array_dtype(self.array): + # duck array protocols. + # `NumpyExtensionArray` is excluded because `PandasExtensionArray` rejects it. + if pd.api.types.is_extension_array_dtype(self.array) and not isinstance( + self.array.array, + pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] + ): from xarray.core.extension_array import PandasExtensionArray return PandasExtensionArray(self.array.array) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 190da4b2f06..9c753a2ffa7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -63,6 +63,11 @@ ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) +UNSUPPORTED_EXTENSION_ARRAY_TYPES = ( + pd.arrays.DatetimeArray, + pd.arrays.TimedeltaArray, + pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] +) if TYPE_CHECKING: from xarray.core.types import ( @@ -190,6 +195,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) + if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES): + return data.to_numpy() if isinstance(data, pd.api.extensions.ExtensionArray): return PandasExtensionArray(data) return data @@ -251,7 +258,14 @@ def convert_non_numpy_type(data): # we don't want nested self-described arrays if isinstance(data, pd.Series | pd.DataFrame): - pandas_data = data.values + if ( + isinstance(data, pd.Series) + and pd.api.types.is_extension_array_dtype(data) + and not isinstance(data.array,
UNSUPPORTED_EXTENSION_ARRAY_TYPES) + ): + pandas_data = data.array + else: + pandas_data = data.values # type: ignore[assignment] if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): return convert_non_numpy_type(pandas_data) else: