1 change: 1 addition & 0 deletions doc/api.rst
@@ -24,6 +24,7 @@ Top-level functions
full_like
zeros_like
ones_like
dot

Dataset
=======
5 changes: 4 additions & 1 deletion doc/whats-new.rst
@@ -38,11 +38,14 @@ Documentation
Enhancements
~~~~~~~~~~~~

- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``.
Also, :py:func:`~xarray.DataArray.dot` now supports the ``dims`` option,
which specifies the dimensions to sum over.
(:issue:`1951`)
- Support lazy vectorized-indexing. After this change, flexible indexing such
as orthogonal/vectorized indexing becomes possible for all the backend
arrays. Lazy ``transpose`` is now also supported. (:issue:`1897`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Improve :py:func:`~xarray.DataArray.rolling` logic.
:py:func:`~xarray.DataArrayRolling` object now supports
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view
2 changes: 1 addition & 1 deletion xarray/__init__.py
@@ -6,7 +6,7 @@
from .core.alignment import align, broadcast, broadcast_arrays
from .core.common import full_like, zeros_like, ones_like
from .core.combine import concat, auto_combine
from .core.computation import apply_ufunc, where
from .core.computation import apply_ufunc, where, dot
from .core.extensions import (register_dataarray_accessor,
register_dataset_accessor)
from .core.variable import as_variable, Variable, IndexVariable, Coordinate
104 changes: 102 additions & 2 deletions xarray/core/computation.py
@@ -6,13 +6,14 @@
import functools
import itertools
import operator
from collections import Counter

import numpy as np

from . import duck_array_ops, utils
from . import duck_array_ops, utils, dtypes
from .alignment import deep_align
from .merge import expand_and_merge_variables
from .pycompat import OrderedDict, dask_array_type
from .pycompat import OrderedDict, dask_array_type, basestring
from .utils import is_dict_like

_DEFAULT_FROZEN_SET = frozenset()
@@ -926,6 +927,105 @@ def earth_mover_distance(first_samples,
return apply_array_ufunc(func, *args, dask=dask)


def dot(*arrays, **kwargs):
""" dot(*arrays, dims=None)
Member:
dot(*arrays, *, dims=None) is the way to write this with Python 3's keyword-only arguments.

Member Author:
Maybe we should keep this as dot(*arrays, **kwargs), as we have not yet dropped Python 2 support?

Member:
I was confused. def dot(*arrays, *, dims=None) is not valid syntax in Python 3, either. (There can only be a single *.)

Member Author:
PEP 3102 says Python 3 supports the form def dot(*arrays, dims=None).
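A minimal sketch of the two signatures discussed in this thread (illustrative only, not the PR's code; the helper names are made up):

# Python 2/3-compatible pattern: accept **kwargs and pop 'dims' by hand.
def dot_compat(*arrays, **kwargs):
    dims = kwargs.pop('dims', None)
    if kwargs:
        raise TypeError('unexpected keyword arguments %r' % list(kwargs))
    return arrays, dims

# Python 3 only (PEP 3102): dims is keyword-only because it follows *arrays.
def dot_py3(*arrays, dims=None):
    return arrays, dims

print(dot_compat(1, 2, dims='a'))  # ((1, 2), 'a')
print(dot_py3(1, 2, dims='a'))     # ((1, 2), 'a')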


Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.

Parameters
----------
arrays: DataArray objects
Arrays to compute the dot product of.
dims: str or tuple of strings, optional
Which dimensions to sum over.
If not specified, then all the common dimensions are summed over.

Returns
-------
dot: DataArray

Examples
--------

>>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b'])
>>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5),
...                     dims=['a', 'b', 'c'])
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd'])
>>>
>>> xr.dot(da_a, da_b, dims=['a', 'b']).dims
('c', )
>>> xr.dot(da_a, da_b, dims=['a']).dims
('b', 'c')
>>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims
('a', 'd')
"""
from .dataarray import DataArray

dims = kwargs.pop('dims', None)
if len(kwargs) > 0:
raise TypeError('Invalid keyword arguments {} are given'.format(
list(kwargs.keys())))

Member:
What happens if you write xr.dot()? I suppose we still need to raise an error for 0 arguments.
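A possible zero-argument guard in the spirit of this suggestion (a sketch only; the placement and message text are assumptions, not the PR's code):

def dot(*arrays, **kwargs):
    # Toy skeleton showing only the argument validation discussed here.
    dims = kwargs.pop('dims', None)
    if kwargs:
        raise TypeError('Invalid keyword arguments {} are given'.format(
            list(kwargs.keys())))
    if len(arrays) == 0:
        # xr.dot() with no positional arguments has nothing to contract.
        raise TypeError('At least one array must be given.')
    return arrays, dims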

if any(not isinstance(arr, DataArray) for arr in arrays):
raise TypeError('Only xr.DataArray and xr.Variable are supported.')

if isinstance(dims, basestring):
dims = [dims]

common_dims = set(arrays[0].dims)
all_dims = []
Member:
Would it work to make all_dims a set instead of a list? I think that would be slightly more efficient.

Member Author (@fujiisoup, Mar 8, 2018):
I want to keep the occurrence order in all_dims, so that the input_core_dims positions can be moved back to their original positions.

Member:
OK, sounds good.

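A standalone illustration of why a list is kept instead of a set (a sketch, not PR code): the list accumulation preserves first-occurrence order of the dimensions, which a set would discard.

arrays_dims = [('a', 'b'), ('a', 'b', 'c'), ('c', 'd')]

# Order-preserving accumulation, as in the PR code below:
all_dims = []
for dims in arrays_dims:
    all_dims += [d for d in dims if d not in all_dims]
print(all_dims)                    # ['a', 'b', 'c', 'd'] -- first-occurrence order

# A plain set loses that ordering:
print(set().union(*arrays_dims))   # e.g. {'c', 'a', 'd', 'b'} -- arbitrary order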
for arr in arrays[1:]:
common_dims = common_dims.intersection(set(arr.dims))
Member:
This is a slightly different choice of default dimensions than np.einsum:

  • np.einsum sums over any dimensions that are defined in two or more inputs.
  • This sums only over dimensions that are defined on all inputs.

Should we switch this behavior to match einsum?
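A small standalone example of the difference described above (not PR code; the dimension names are arbitrary): with dims ('a', 'b'), ('a', 'b', 'c') and ('c', 'd') on three inputs, the einsum-style rule sums over 'a', 'b' and 'c', while an all-inputs intersection finds no shared dimension at all.

from collections import Counter

dims_per_array = [('a', 'b'), ('a', 'b', 'c'), ('c', 'd')]

# einsum-like default: dimensions defined in two or more inputs
counts = Counter(d for dims in dims_per_array for d in dims)
print([d for d, c in counts.items() if c > 1])   # ['a', 'b', 'c'] (Python 3.7+ dict order)

# dimensions defined on *all* inputs
print(set.intersection(*[set(dims) for dims in dims_per_array]))   # set()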

Member:
It might be slightly more efficient to construct common_dims with a single call to intersection, e.g.,

common_dims = set.intersection(*[set(arr.dims) for arr in arrays])

for arr in arrays:
all_dims += [d for d in arr.dims if d not in all_dims]

einsum_axes = 'abcdefghijklmnopqrstuvwxyz'
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is None:
# find dimensions that occur more than once
dim_counts = Counter()
for arr in arrays:
dim_counts.update(arr.dims)
dims = [d for d, c in dim_counts.items() if c > 1]

broadcast_dims = [d for d in common_dims if d not in dims]
input_core_dims = []
output_core_dims = [[]]
for arr in arrays:
input_core_dims.append([d for d in arr.dims if d not in
broadcast_dims])
output_core_dims[0] += [d for d in arr.dims if d not in
output_core_dims[0] + dims + broadcast_dims]

subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds
in input_core_dims]
subscripts = ','.join(subscripts_list)
subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]])

# dtype estimation is necessary for dask='parallelized'
out_dtype = dtypes.result_type(*arrays)

# we use tensordot if possible, because it is more efficient for dask
if len(broadcast_dims) == 0 and len(arrays) == 2:
axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims]
for arr in arrays]
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed',
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
kwargs={'axes': axes})
Member Author:
Thanks. I added a path for tensordot, which dask can compute more efficiently.
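A small usage sketch of the fast path mentioned here (assumes dask is installed; the chunk sizes and values are arbitrary): with two arrays and no broadcast dimensions, the call goes through tensordot with dask='allowed', so dask plans the contraction itself instead of relying on the parallelized einsum path.

import numpy as np
import xarray as xr

da_a = xr.DataArray(np.arange(12).reshape(3, 4), dims=['a', 'b']).chunk({'a': 1})
da_b = xr.DataArray(np.arange(12).reshape(3, 4), dims=['a', 'b']).chunk({'a': 1})

# Shared dims 'a' and 'b' are summed over; the result is a 0-d DataArray.
result = xr.dot(da_a, da_b)
print(result.compute())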


# subscripts should be passed as an arg, not as a kwarg. We need
# to construct a partial function for parallelized computation.
func = functools.partial(np.einsum, subscripts)
result = apply_ufunc(func, *arrays,
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
dask='parallelized', output_dtypes=[out_dtype])
return result.transpose(*[d for d in all_dims if d in result.dims])


def where(cond, x, y):
"""Return elements from `x` or `y` depending on `cond`.

26 changes: 7 additions & 19 deletions xarray/core/dataarray.py
@@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from . import duck_array_ops, groupby, indexing, ops, resample, rolling, utils
from . import computation, groupby, indexing, ops, resample, rolling, utils
from ..plot.plot import _PlotMethods
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
@@ -1926,7 +1926,7 @@ def real(self):
def imag(self):
return self._replace(self.variable.imag)

def dot(self, other):
def dot(self, other, dims=None):
"""Perform dot product of two DataArrays along their shared dims.

Equivalent to taking tensordot over all shared dims.
@@ -1935,6 +1935,9 @@ def dot(self, other):
----------
other : DataArray
The other array with which the dot product is performed.
dims: str or list of strings, optional
Which dimensions to sum over. By default, all the common
dimensions are summed over.

Returns
-------
@@ -1943,6 +1946,7 @@ def dot(self, other):

See also
--------
dot
numpy.tensordot

Examples
@@ -1968,23 +1972,7 @@ def dot(self, other):
if not isinstance(other, DataArray):
raise TypeError('dot only operates on DataArrays.')

# sum over the common dims
dims = set(self.dims) & set(other.dims)
if len(dims) == 0:
raise ValueError('DataArrays have no shared dimensions over which '
'to perform dot.')

self, other = align(self, other, join='inner', copy=False)

axes = (self.get_axis_num(dims), other.get_axis_num(dims))
new_data = duck_array_ops.tensordot(self.data, other.data, axes=axes)

new_coords = self.coords.merge(other.coords)
new_coords = new_coords.drop([d for d in dims if d in new_coords])
new_dims = ([d for d in self.dims if d not in dims] +
[d for d in other.dims if d not in dims])

return type(self)(new_data, new_coords.variables, new_dims)
return computation.dot(self, other, dims=dims)

def sortby(self, variables, ascending=True):
"""
83 changes: 82 additions & 1 deletion xarray/tests/test_computation.py
@@ -14,7 +14,7 @@
join_dict_keys, ordered_set_intersection, ordered_set_union, result_name,
unified_dim_sizes)

from . import raises_regex, requires_dask
from . import raises_regex, requires_dask, has_dask


def assert_identical(a, b):
@@ -744,6 +744,87 @@ def test_vectorize_dask():
assert_identical(expected, actual)


@pytest.mark.parametrize('dask', [True, False])
def test_dot(dask):
if not has_dask:
pytest.skip('test for dask.')

a = np.arange(30 * 4).reshape(30, 4)
b = np.arange(30 * 4 * 5).reshape(30, 4, 5)
c = np.arange(5 * 60).reshape(5, 60)
da_a = xr.DataArray(a, dims=['a', 'b'],
coords={'a': np.linspace(0, 1, 30)})
da_b = xr.DataArray(b, dims=['a', 'b', 'c'],
coords={'a': np.linspace(0, 1, 30)})
da_c = xr.DataArray(c, dims=['c', 'e'])
if dask:
da_a = da_a.chunk({'a': 3})
da_b = da_b.chunk({'a': 3})
da_c = da_c.chunk({'c': 3})

actual = xr.dot(da_a, da_b, dims=['a', 'b'])
assert actual.dims == ('c', )
assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

actual = xr.dot(da_a, da_b)
assert actual.dims == ('c', )
assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

# if only a single array is passed without the dims argument, just return
# it as is
actual = xr.dot(da_a)
assert da_a.identical(actual)

if dask:
da_a = da_a.chunk({'a': 3})
da_b = da_b.chunk({'a': 3})
actual = xr.dot(da_a, da_b, dims=['b'])
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()
assert isinstance(actual.variable.data, type(da_a.variable.data))

pytest.skip('dot for dask array requires rechunking for core '
'dimensions.')

# following requires rechunking
actual = xr.dot(da_a, da_b, dims=['b'])
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

actual = xr.dot(da_a, da_b, dims='b')
assert actual.dims == ('a', 'c')
assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

actual = xr.dot(da_a, da_b, dims='a')
assert actual.dims == ('b', 'c')
assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all()

actual = xr.dot(da_a, da_b, dims='c')
assert actual.dims == ('a', 'b')
assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b'])
assert actual.dims == ('c', 'e')
assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all()

# default dims
actual = xr.dot(da_a, da_b, da_c)
assert actual.dims == ('e', )
assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all()

# 1 array summation
actual = xr.dot(da_a, dims='a')
assert actual.dims == ('b', )
assert (actual.data == np.einsum('ij->j ', a)).all()

with pytest.raises(TypeError):
actual = xr.dot(da_a, dims='a', invalid=None)
with pytest.raises(TypeError):
actual = xr.dot(da_a.to_dataset(name='da'), dims='a')


def test_where():
cond = xr.DataArray([True, False], dims='x')
actual = xr.where(cond, 1, 0)
2 changes: 0 additions & 2 deletions xarray/tests/test_dataarray.py
@@ -3200,8 +3200,6 @@ def test_dot(self):
da.dot(dm.to_dataset(name='dm'))
with pytest.raises(TypeError):
da.dot(dm.values)
with raises_regex(ValueError, 'no shared dimensions'):
da.dot(DataArray(1))

def test_binary_op_join_setting(self):
dim = 'x'