Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ New Features
``pip install git+https://github.com/andrewgsavage/pint.git@refs/pull/6/head)``.
Even with it, interaction with non-numpy array libraries, e.g. dask or sparse, is broken.

- :py:func:`~xarray.dot`, and :py:func:`~xarray.DataArray.dot` now support the
``dims=xarray.ALL_DIMS`` option to sum over the union of dimensions of all
arrays (:issue:`3423`) by `Mathias Hauser <https://github.com/mathause>`_.

Bug fixes
~~~~~~~~~
- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed
Expand Down
17 changes: 14 additions & 3 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from . import duck_array_ops, utils
from .alignment import deep_align
from .common import ALL_DIMS
from .merge import merge_coordinates_without_align
from .pycompat import dask_array_type
from .utils import is_dict_like
Expand Down Expand Up @@ -1055,7 +1056,7 @@ def dot(*arrays, dims=None, **kwargs):
----------
arrays: DataArray (or Variable) objects
Arrays to compute.
dims: str or tuple of strings, optional
dims: xarray.ALL_DIMS, str or tuple of strings, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dims: xarray.ALL_DIMS, str or tuple of strings, optional
dims: '...', str or tuple of strings, optional

Which dimensions to sum over.
If not speciified, then all the common dimensions are summed over.
**kwargs: dict
Expand All @@ -1070,7 +1071,7 @@ def dot(*arrays, dims=None, **kwargs):
--------

>>> import numpy as np
>>> import xarray as xp
>>> import xarray as xr
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=['a', 'b'])
>>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2),
... dims=['a', 'b', 'c'])
Expand Down Expand Up @@ -1117,6 +1118,14 @@ def dot(*arrays, dims=None, **kwargs):
[273, 446, 619]])
Dimensions without coordinates: a, d

>>> xr.dot(da_a, da_b)
<xarray.DataArray (c: 2)>
array([110, 125])
Dimensions without coordinates: c

>>> xr.dot(da_a, da_b, dims=xr.ALL_DIMS)
<xarray.DataArray ()>
array(235)
"""
from .dataarray import DataArray
from .variable import Variable
Expand All @@ -1141,7 +1150,9 @@ def dot(*arrays, dims=None, **kwargs):
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is None:
if dims is ALL_DIMS:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if dims is ALL_DIMS:
if dims is ...:

dims = all_dims
elif dims is None:
# find dimensions that occur more than one times
dim_counts = Counter()
for arr in arrays:
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2747,7 +2747,7 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims: hashable or sequence of hashables, optional
dims: xarray.ALL_DIMS, hashable or sequence of hashables, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dims: xarray.ALL_DIMS, hashable or sequence of hashables, optional
dims: '...', hashable or sequence of hashables, optional

Along which dimensions to be summed over. Default all the common
dimensions are summed over.

Expand Down
17 changes: 17 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,23 @@ def test_dot(use_dask):
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()

# xr.ALL_DIMS
actual = xr.dot(da_a, da_b, dims=xr.ALL_DIMS)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=xr.ALL_DIMS)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()

actual = xr.dot(da_a, dims=xr.ALL_DIMS)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()

actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=xr.ALL_DIMS)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()

# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3930,6 +3930,16 @@ def test_dot(self):
expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
assert_equal(expected, actual)

# xr.ALL_DIMS: all shared dims
actual = da.dot(da, dims=xr.ALL_DIMS)
expected = da.dot(da)
assert_equal(expected, actual)

# xr.ALL_DIMS: multiple shared dims
actual = da.dot(dm, dims=xr.ALL_DIMS)
expected = da.dot(dm, dims=("j", "x", "y", "z"))
assert_equal(expected, actual)

with pytest.raises(NotImplementedError):
da.dot(dm.to_dataset(name="dm"))
with pytest.raises(TypeError):
Expand Down