Skip to content

(fix): disallow NumpyExtensionArray #10334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,26 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(["a", "b", "c"]),
pd.array([1, 2, 3], dtype="int64"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pd.array([1, 2, 3], dtype="int64"),
pd.array([1, 2, 3], dtype="int64"),
pd.array([1, 2, 3], dtype="int64[pyarrow]"),

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this more exhaustive please? with datetime, timedelta etc.

Copy link
Contributor Author

@ilan-gold ilan-gold May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without opening a can of worms, datetime and timedelta are the two things that are currently (on main) whitelisted (without that, they don't roundtrip).

import pandas as pd

df = pd.DataFrame({"arr": pd.date_range("20130101", periods=4, tz="US/Eastern")})
df.to_xarray()["arr"].to_pandas()

loses the timezone, for example, without the whitelist on from_dataframe.

So I don't think we want to add those here, which is also why I whitelisted them elsewhere. I was hoping that before adding tests, we would settle on what exactly this PR should be doing (of course NumpyExtensionArray but should it cover anything else).

I am happy to roll back that whitelist, i.e., leave it where it was on from_dataframe and then allow these types through anyway via other means into a Dataset.

Copy link
Contributor

@dcherian dcherian May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd simply like a test or two that exhaustively records the current behaviour (whether cast to numpy or not), so we can be sure of what is going on. Two tests (one for preserved dtype, and one for numpy casting) would work fine.

There's also two cases: ExtensionArray in a data variable, and ExtensionArray as an indexed coordinate variable.

Copy link
Contributor Author

@ilan-gold ilan-gold May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two tests (one for preserved dtype, and one for numpy casting) would work fine.

Not 100% certain what this meant - could you clarify?

There's also two cases: ExtensionArray in a data variable, and ExtensionArray as an indexed coordinate variable.

I think this is covered here. When I check xr.Dataset.from_dataframe(df)["arr"].variable, it's an IndexVariable, as expected when it's an index on the original object.

Copy link
Contributor Author

@ilan-gold ilan-gold May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check out ilan-gold#2 for a related cleanup. As things stand here, the data types already round-trip, so that PR only cleans up the internals a bit; the tests/behavior are also a bit clearer now (better-preserved date types). Other than that, there seem to be no changes in terms of types or behavior.

pd.array(["a", "b", "c"], dtype="string"),
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
np.array([1, 2, 3], dtype="int64"),
],
)
def test_roundtrip_1d_pandas_extension_array(extension_array) -> None:
df = pd.DataFrame({"arr": extension_array})
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
assert (df["arr"] == roundtripped).all()
# `NumpyExtensionArray` types are not roundtripped, including `StringArray`, which subclasses it.
if isinstance(extension_array, pd.arrays.NumpyExtensionArray):
assert isinstance(arr.data, np.ndarray)
else:
assert df["arr"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
parse_dims_as_set,
)
from xarray.core.variable import (
UNSUPPORTED_EXTENSION_ARRAY_TYPES,
IndexVariable,
Variable,
as_variable,
Expand Down Expand Up @@ -7271,7 +7272,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v) or isinstance(
v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
):
arrays.append((k, np.asarray(v)))
else:
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
def __post_init__(self):
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
# we do support extension arrays from datetime, for example, that need
# duck array support internally via this class.
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
raise TypeError(
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
)

def __array_function__(self, func, types, args, kwargs):
def replace_duck_with_extension_array(args) -> list:
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,8 +1802,11 @@ def __array__(

def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
# We return a PandasExtensionArray wrapper type that satisfies
# duck array protocols. This is what's needed for tests to pass.
if pd.api.types.is_extension_array_dtype(self.array):
# duck array protocols.
# `NumpyExtensionArray` is excluded
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
self.array.array, pd.arrays.NumpyExtensionArray
):
from xarray.core.extension_array import PandasExtensionArray

return PandasExtensionArray(self.array.array)
Expand Down
16 changes: 15 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@
)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
UNSUPPORTED_EXTENSION_ARRAY_TYPES = (
pd.arrays.DatetimeArray,
pd.arrays.TimedeltaArray,
pd.arrays.NumpyExtensionArray,
)

if TYPE_CHECKING:
from xarray.core.types import (
Expand Down Expand Up @@ -191,6 +196,8 @@ def _maybe_wrap_data(data):
"""
if isinstance(data, pd.Index):
return PandasIndexingAdapter(data)
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
return data.to_numpy()
if isinstance(data, pd.api.extensions.ExtensionArray):
return PandasExtensionArray(data)
return data
Expand Down Expand Up @@ -252,7 +259,14 @@ def convert_non_numpy_type(data):

# we don't want nested self-described arrays
if isinstance(data, pd.Series | pd.DataFrame):
pandas_data = data.values
if (
isinstance(data, pd.Series)
and pd.api.types.is_extension_array_dtype(data)
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
):
pandas_data = data.array
else:
pandas_data = data.values
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(pandas_data)
else:
Expand Down
Loading