Skip to content

Support NumPy array API (experimental) #6804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 20, 2022
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
if name in ["sum", "prod"]:
kwargs.pop("min_count", None)

func = getattr(np, name)
if hasattr(values, "__array_namespace__"):
xp = values.__array_namespace__()
func = getattr(xp, name)
else:
func = getattr(np, name)

try:
with warnings.catch_warnings():
Expand Down
43 changes: 43 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ def as_indexable(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down Expand Up @@ -1288,6 +1290,47 @@ def __init__(self, array):
self.array = array


class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""

__slots__ = ("array",)

def __init__(self, array):
    """Store ``array`` after checking it exposes ``__array_namespace__``.

    Raises
    ------
    TypeError
        If ``array`` has no ``__array_namespace__`` attribute.
    """
    if hasattr(array, "__array_namespace__"):
        self.array = array
        return
    raise TypeError(
        "ArrayApiIndexingAdapter must wrap an object that "
        "implements the __array_namespace__ protocol"
    )

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
else:
assert isinstance(key, VectorizedIndexer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: since assert statements are removed when Python is invoked with the -O or -OO flags, could we raise a proper exception instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think we should raise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with a TypeError

raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, OuterIndexer):
self.array[key.tuple] = value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm missing something, we should be able to combine these two branches into one, correct?

Suggested change
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, OuterIndexer):
self.array[key.tuple] = value
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct - fixed.

else:
assert isinstance(key, VectorizedIndexer)
raise TypeError("Vectorized indexing is not supported")

def transpose(self, order):
    """Return the wrapped array with its axes permuted according to ``order``.

    The work is delegated to ``permute_dims`` from the namespace that the
    wrapped array reports through ``__array_namespace__()``.
    """
    wrapped = self.array
    namespace = wrapped.__array_namespace__()
    return namespace.permute_dims(wrapped, order)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
7 changes: 5 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


Expand Down Expand Up @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d):
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
or hasattr(value, "__array_namespace__")
)
)

Expand Down
6 changes: 5 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def as_compatible_data(data, fastpath=False):
if isinstance(data, (Variable, DataArray)):
return data.data

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
if (
isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(data, "__array_function__")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an __array_function__ check down below (line 245), so it seems like this change should go there.

I suspect we can also delete cupy, dask from NON_NUMPY_SUPPORTED_ARRAY_TYPES

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an __array_function__ check down below (lines 245, so it seems like this change should go there).

Good point - thanks for the suggestion! I've moved the change down below.

I suspect we can also delete cupy, dask from NON_NUMPY_SUPPORTED_ARRAY_TYPES

I haven't tried this. Sounds like this would be a separate change?

or hasattr(data, "__array_namespace__")
):
return _maybe_wrap_data(data)

if isinstance(data, tuple):
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

import xarray as xr
from xarray.testing import assert_equal

# Run the version guard *before* importing ``numpy.array_api``: that
# subpackage is only available from NumPy 1.22, so importing it first would
# make collection fail with ImportError on older NumPy instead of skipping
# this module.
np = pytest.importorskip("numpy", minversion="1.22")

import numpy.array_api as xp  # noqa: E402  isort:skip
from numpy.array_api._array_object import Array  # noqa: E402  isort:skip


@pytest.fixture
def arrays():
    """Provide two equal DataArrays: one NumPy-backed, one array-API-backed.

    Returns a ``(np_arr, xp_arr)`` tuple; both hold ones of shape (2, 3) with
    dims ``("x", "y")`` and an ``x`` coordinate of ``[10, 20]``.
    """
    shape = (2, 3)
    dims = ("x", "y")
    np_arr = xr.DataArray(np.ones(shape), dims=dims, coords={"x": [10, 20]})
    xp_arr = xr.DataArray(xp.ones(shape), dims=dims, coords={"x": [10, 20]})
    # Guard against xarray silently converting the array API object.
    assert isinstance(xp_arr.data, Array)
    return np_arr, xp_arr


def test_arithmetic(arrays):
    """Adding a scalar keeps the array API type and equals the NumPy result."""
    numpy_da, array_api_da = arrays
    result = array_api_da + 7
    assert isinstance(result.data, Array)
    assert_equal(result, numpy_da + 7)


def test_aggregation(arrays):
    """``sum(skipna=False)`` keeps the array API type and equals the NumPy result."""
    numpy_da, array_api_da = arrays
    reduced = array_api_da.sum(skipna=False)
    assert isinstance(reduced.data, Array)
    assert_equal(reduced, numpy_da.sum(skipna=False))


def test_indexing(arrays):
    """Selecting ``[:, 0]`` keeps the array API type and equals the NumPy result."""
    numpy_da, array_api_da = arrays
    selected = array_api_da[:, 0]
    assert isinstance(selected.data, Array)
    assert_equal(selected, numpy_da[:, 0])


def test_reorganizing_operation(arrays):
    """``transpose()`` keeps the array API type and equals the NumPy result."""
    numpy_da, array_api_da = arrays
    transposed = array_api_da.transpose()
    assert isinstance(transposed.data, Array)
    assert_equal(transposed, numpy_da.transpose())