diff --git a/doc/api.rst b/doc/api.rst index ae4803e5e62..1814b874b3e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,6 +24,7 @@ Top-level functions full_like zeros_like ones_like + dot Dataset ======= diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 963a0454f88..e60d98340c9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,10 @@ Documentation Enhancements ~~~~~~~~~~~~ +- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``. + Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, + which specifies the dimensions to sum over. + (:issue:`1951`) - Support for writing xarray datasets to netCDF files (netcdf4 backend only) when using the `dask.distributed `_ scheduler (:issue:`1464`). @@ -49,7 +53,6 @@ Enhancements as orthogonal/vectorized indexing, becomes possible for all the backend arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`) By `Keisuke Fujii `_. - - Improve :py:func:`~xarray.DataArray.rolling` logic. :py:func:`~xarray.DataArrayRolling` object now supports :py:func:`~xarray.DataArrayRolling.construct` method that returns a view diff --git a/xarray/__init__.py b/xarray/__init__.py index 3e80acd1572..1a2bf3fe283 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -6,7 +6,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where +from .core.computation import apply_ufunc, where, dot from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 858936aad6c..685a3c66c54 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -6,13 +6,14 @@ import functools import itertools import operator +from collections import Counter import 
def dot(*arrays, **kwargs):
    """ dot(*arrays, dims=None)

    Generalized dot product for xarray objects. Like np.einsum, but
    provides a simpler interface based on array dimensions.

    Parameters
    ----------
    arrays: DataArray (or Variable) objects
        Arrays to compute.
    dims: str or tuple of strings, optional
        Which dimensions to sum over.
        If not specified, every dimension that appears in more than one of
        the arrays is summed over.

    Returns
    -------
    dot: DataArray

    Raises
    ------
    TypeError
        If an unknown keyword argument is given, if no array is given, or
        if any argument is neither a DataArray nor a Variable.

    Examples
    --------

    >>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b'])
    >>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5),
    ...                     dims=['a', 'b', 'c'])
    >>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd'])
    >>>
    >>> xr.dot(da_a, da_b, dims=['a', 'b']).dims
    ('c',)
    >>> xr.dot(da_a, da_b, dims=['a']).dims
    ('b', 'c')
    >>> xr.dot(da_a, da_b, da_c, dims=['b', 'c']).dims
    ('a', 'd')
    """
    from .dataarray import DataArray
    from .variable import Variable

    # ``dims`` is taken from **kwargs so that it is keyword-only while the
    # signature still accepts an arbitrary number of positional arrays.
    dims = kwargs.pop('dims', None)
    if len(kwargs) > 0:
        raise TypeError('Invalid keyword arguments {} are given'.format(
            list(kwargs.keys())))

    if len(arrays) == 0:
        raise TypeError('At least one array should be given.')

    if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays):
        raise TypeError(
            'Only xr.DataArray and xr.Variable are supported. '
            'Given {}.'.format([type(arr) for arr in arrays]))

    if isinstance(dims, basestring):
        dims = (dims, )

    # dimensions shared by *every* array
    common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
    # every dimension, in order of first appearance
    all_dims = []
    for arr in arrays:
        all_dims += [d for d in arr.dims if d not in all_dims]

    # one einsum subscript letter per distinct dimension
    einsum_axes = 'abcdefghijklmnopqrstuvwxyz'
    dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

    if dims is None:
        # find dimensions that occur more than one times
        dim_counts = Counter()
        for arr in arrays:
            dim_counts.update(arr.dims)
        dims = tuple(d for d, c in dim_counts.items() if c > 1)

    dims = tuple(dims)  # make dims a tuple

    # dimensions to be parallelized
    broadcast_dims = tuple(d for d in all_dims
                           if d in common_dims and d not in dims)
    input_core_dims = [[d for d in arr.dims if d not in broadcast_dims]
                       for arr in arrays]
    output_core_dims = [tuple(d for d in all_dims if d not in
                              dims + broadcast_dims)]

    # we use tensordot if possible, because it is more efficient for dask.
    # tensordot pairs the i-th listed axis of the first array with the i-th
    # listed axis of the second one, so this path is only valid when every
    # summed dimension exists in both arrays, and the axes of both arrays
    # must be listed in one shared dimension order (not in each array's own
    # order, which may differ).
    sum_dims = [d for d in all_dims if d in dims]
    if (len(arrays) == 2 and len(broadcast_dims) == 0 and
            all(d in common_dims for d in sum_dims)):
        axes = [[arr.get_axis_num(d) for d in sum_dims] for arr in arrays]
        return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed',
                           input_core_dims=input_core_dims,
                           output_core_dims=output_core_dims,
                           kwargs={'axes': axes})

    # construct einsum subscripts, such as '...abc,...ab->...c'
    # Note: input_core_dims are always moved to the last position
    subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds
                       in input_core_dims]
    subscripts = ','.join(subscripts_list)
    subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]])

    # dtype estimation is necessary for dask='parallelized'
    out_dtype = dtypes.result_type(*arrays)

    # subscripts should be passed to np.einsum as arg, not as kwargs. We need
    # to construct a partial function for parallelized computation.
    func = functools.partial(np.einsum, subscripts)
    result = apply_ufunc(func, *arrays,
                         input_core_dims=input_core_dims,
                         output_core_dims=output_core_dims,
                         dask='parallelized', output_dtypes=[out_dtype])
    # restore the order of first appearance for the remaining dimensions
    return result.transpose(*[d for d in all_dims if d in result.dims])
Returns ------- @@ -1943,6 +1946,7 @@ def dot(self, other): See also -------- + dot numpy.tensordot Examples @@ -1968,23 +1972,7 @@ def dot(self, other): if not isinstance(other, DataArray): raise TypeError('dot only operates on DataArrays.') - # sum over the common dims - dims = set(self.dims) & set(other.dims) - if len(dims) == 0: - raise ValueError('DataArrays have no shared dimensions over which ' - 'to perform dot.') - - self, other = align(self, other, join='inner', copy=False) - - axes = (self.get_axis_num(dims), other.get_axis_num(dims)) - new_data = duck_array_ops.tensordot(self.data, other.data, axes=axes) - - new_coords = self.coords.merge(other.coords) - new_coords = new_coords.drop([d for d in dims if d in new_coords]) - new_dims = ([d for d in self.dims if d not in dims] + - [d for d in other.dims if d not in dims]) - - return type(self)(new_data, new_coords.variables, new_dims) + return computation.dot(self, other, dims=dims) def sortby(self, variables, ascending=True): """ diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ebd51d04857..88710e55091 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -14,7 +14,7 @@ join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) -from . import raises_regex, requires_dask +from . 
@pytest.mark.parametrize('dask', [True, False])
def test_dot(dask):
    """Check ``xr.dot`` against ``np.einsum`` for numpy- and dask-backed data.

    Covers explicit and default ``dims``, ``dims`` given as str / list /
    tuple, Variable inputs, one and three input arrays, and the TypeError
    cases for invalid arguments.
    """
    # Only the dask-backed parametrization needs dask; the numpy-backed one
    # must still run when dask is not installed.
    if dask and not has_dask:
        pytest.skip('test for dask.')

    a = np.arange(30 * 4).reshape(30, 4)
    b = np.arange(30 * 4 * 5).reshape(30, 4, 5)
    c = np.arange(5 * 60).reshape(5, 60)
    da_a = xr.DataArray(a, dims=['a', 'b'],
                        coords={'a': np.linspace(0, 1, 30)})
    da_b = xr.DataArray(b, dims=['a', 'b', 'c'],
                        coords={'a': np.linspace(0, 1, 30)})
    da_c = xr.DataArray(c, dims=['c', 'e'])
    if dask:
        da_a = da_a.chunk({'a': 3})
        da_b = da_b.chunk({'a': 3})
        da_c = da_c.chunk({'c': 3})

    actual = xr.dot(da_a, da_b, dims=['a', 'b'])
    assert actual.dims == ('c', )
    assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
    assert isinstance(actual.variable.data, type(da_a.variable.data))

    actual = xr.dot(da_a, da_b)
    assert actual.dims == ('c', )
    assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
    assert isinstance(actual.variable.data, type(da_a.variable.data))

    # when only a single array is passed without the dims argument, it is
    # returned as is
    actual = xr.dot(da_a)
    assert da_a.identical(actual)

    # test for variable
    actual = xr.dot(da_a.variable, da_b.variable)
    assert actual.dims == ('c', )
    assert (actual.data == np.einsum('ij,ijk->k', a, b)).all()
    assert isinstance(actual.data, type(da_a.variable.data))

    if dask:
        da_a = da_a.chunk({'a': 3})
        da_b = da_b.chunk({'a': 3})
        actual = xr.dot(da_a, da_b, dims=['b'])
        assert actual.dims == ('a', 'c')
        assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()
        assert isinstance(actual.variable.data, type(da_a.variable.data))

        # everything below is exercised for numpy-backed arrays only
        pytest.skip('dot for dask array requires rechunking for core '
                    'dimensions.')

    # following requires rechunking
    actual = xr.dot(da_a, da_b, dims=['b'])
    assert actual.dims == ('a', 'c')
    assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

    actual = xr.dot(da_a, da_b, dims='b')
    assert actual.dims == ('a', 'c')
    assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all()

    actual = xr.dot(da_a, da_b, dims='a')
    assert actual.dims == ('b', 'c')
    assert (actual.data == np.einsum('ij,ijk->jk', a, b)).all()

    actual = xr.dot(da_a, da_b, dims='c')
    assert actual.dims == ('a', 'b')
    assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all()

    actual = xr.dot(da_a, da_b, da_c, dims=['a', 'b'])
    assert actual.dims == ('c', 'e')
    assert (actual.data == np.einsum('ij,ijk,kl->kl ', a, b, c)).all()

    # should work with tuple
    actual = xr.dot(da_a, da_b, dims=('c', ))
    assert actual.dims == ('a', 'b')
    assert (actual.data == np.einsum('ij,ijk->ij', a, b)).all()

    # default dims
    actual = xr.dot(da_a, da_b, da_c)
    assert actual.dims == ('e', )
    assert (actual.data == np.einsum('ij,ijk,kl->l ', a, b, c)).all()

    # 1 array summation
    actual = xr.dot(da_a, dims='a')
    assert actual.dims == ('b', )
    assert (actual.data == np.einsum('ij->j ', a)).all()

    # invalid keyword argument
    with pytest.raises(TypeError):
        actual = xr.dot(da_a, dims='a', invalid=None)
    # unsupported input type (Dataset)
    with pytest.raises(TypeError):
        actual = xr.dot(da_a.to_dataset(name='da'), dims='a')
    # no arrays at all
    with pytest.raises(TypeError):
        actual = xr.dot(dims='a')