-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Support NumPy array API (experimental) #6804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
dc76862
14ef443
ffbb602
fe33202
0da1f5e
7e91bd5
fa9ea14
3cc3cb4
f6df255
afe3d9f
0f81209
6d7f13e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -679,6 +679,8 @@ def as_indexable(array): | |||||||||||||
return DaskIndexingAdapter(array) | ||||||||||||||
if hasattr(array, "__array_function__"): | ||||||||||||||
return NdArrayLikeIndexingAdapter(array) | ||||||||||||||
if hasattr(array, "__array_namespace__"): | ||||||||||||||
return ArrayApiIndexingAdapter(array) | ||||||||||||||
|
||||||||||||||
raise TypeError(f"Invalid array type: {type(array)}") | ||||||||||||||
|
||||||||||||||
|
@@ -1288,6 +1290,47 @@ def __init__(self, array): | |||||||||||||
self.array = array | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): | ||||||||||||||
"""Wrap an array API array to use explicit indexing.""" | ||||||||||||||
|
||||||||||||||
__slots__ = ("array",) | ||||||||||||||
|
||||||||||||||
def __init__(self, array): | ||||||||||||||
if not hasattr(array, "__array_namespace__"): | ||||||||||||||
raise TypeError( | ||||||||||||||
"ArrayApiIndexingAdapter must wrap an object that " | ||||||||||||||
"implements the __array_namespace__ protocol" | ||||||||||||||
) | ||||||||||||||
self.array = array | ||||||||||||||
|
||||||||||||||
def __getitem__(self, key): | ||||||||||||||
if isinstance(key, BasicIndexer): | ||||||||||||||
return self.array[key.tuple] | ||||||||||||||
elif isinstance(key, OuterIndexer): | ||||||||||||||
# manual orthogonal indexing (implemented like DaskIndexingAdapter) | ||||||||||||||
key = key.tuple | ||||||||||||||
value = self.array | ||||||||||||||
for axis, subkey in reversed(list(enumerate(key))): | ||||||||||||||
value = value[(slice(None),) * axis + (subkey, Ellipsis)] | ||||||||||||||
return value | ||||||||||||||
else: | ||||||||||||||
assert isinstance(key, VectorizedIndexer) | ||||||||||||||
raise TypeError("Vectorized indexing is not supported") | ||||||||||||||
|
||||||||||||||
def __setitem__(self, key, value): | ||||||||||||||
if isinstance(key, BasicIndexer): | ||||||||||||||
self.array[key.tuple] = value | ||||||||||||||
elif isinstance(key, OuterIndexer): | ||||||||||||||
self.array[key.tuple] = value | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unless i'm missing something, we should be able to combine these two branches into one, correct?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct - fixed. |
||||||||||||||
else: | ||||||||||||||
assert isinstance(key, VectorizedIndexer) | ||||||||||||||
raise TypeError("Vectorized indexing is not supported") | ||||||||||||||
|
||||||||||||||
def transpose(self, order): | ||||||||||||||
xp = self.array.__array_namespace__() | ||||||||||||||
return xp.permute_dims(self.array, order) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): | ||||||||||||||
"""Wrap a dask array to support explicit indexing.""" | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,7 +211,11 @@ def as_compatible_data(data, fastpath=False): | |
if isinstance(data, (Variable, DataArray)): | ||
return data.data | ||
|
||
if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): | ||
if ( | ||
isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES) | ||
or hasattr(data, "__array_function__") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's an I suspect we can also delete cupy, dask from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good point - thanks for the suggestion! I've moved the change to down below.
I haven't tried this. Sounds like this would be a separate change? |
||
or hasattr(data, "__array_namespace__") | ||
): | ||
return _maybe_wrap_data(data) | ||
|
||
if isinstance(data, tuple): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy.array_api as xp | ||
Illviljan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import pytest | ||
from numpy.array_api._array_object import Array | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import xarray as xr | ||
from xarray.testing import assert_equal | ||
|
||
np = pytest.importorskip("numpy", minversion="1.22") | ||
|
||
|
||
@pytest.fixture | ||
def arrays(): | ||
tomwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) | ||
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) | ||
assert isinstance(xp_arr.data, Array) | ||
return np_arr, xp_arr | ||
|
||
|
||
def test_arithmetic(arrays): | ||
tomwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np_arr, xp_arr = arrays | ||
expected = np_arr + 7 | ||
actual = xp_arr + 7 | ||
assert isinstance(actual.data, Array) | ||
assert_equal(actual, expected) | ||
|
||
|
||
def test_aggregation(arrays): | ||
tomwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np_arr, xp_arr = arrays | ||
expected = np_arr.sum(skipna=False) | ||
actual = xp_arr.sum(skipna=False) | ||
assert isinstance(actual.data, Array) | ||
assert_equal(actual, expected) | ||
|
||
|
||
def test_indexing(arrays): | ||
tomwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np_arr, xp_arr = arrays | ||
expected = np_arr[:, 0] | ||
actual = xp_arr[:, 0] | ||
assert isinstance(actual.data, Array) | ||
assert_equal(actual, expected) | ||
|
||
|
||
def test_reorganizing_operation(arrays): | ||
tomwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np_arr, xp_arr = arrays | ||
expected = np_arr.transpose() | ||
actual = xp_arr.transpose() | ||
assert isinstance(actual.data, Array) | ||
assert_equal(actual, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit pick: since
assert
statements are removed when python is invoked with-O
and-OO
parameters, could we raise a proper exception?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes we should raise I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replaced with a
TypeError