-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
einsum for xarray #1968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
einsum for xarray #1968
Changes from 4 commits
220ebcc
4239ac6
0f472a2
c83d442
1c732a4
b8d93b0
3278bf3
1ec5683
789cb96
a57907c
693b242
88be319
b3d4768
2bd06ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ Top-level functions | |
full_like | ||
zeros_like | ||
ones_like | ||
dot | ||
|
||
Dataset | ||
======= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,10 +9,10 @@ | |
|
||
import numpy as np | ||
|
||
from . import duck_array_ops, utils | ||
from . import duck_array_ops, utils, dtypes | ||
from .alignment import deep_align | ||
from .merge import expand_and_merge_variables | ||
from .pycompat import OrderedDict, dask_array_type | ||
from .pycompat import OrderedDict, dask_array_type, basestring | ||
from .utils import is_dict_like | ||
|
||
_DEFAULT_FROZEN_SET = frozenset() | ||
|
@@ -926,6 +926,103 @@ def earth_mover_distance(first_samples, | |
return apply_array_ufunc(func, *args, dask=dask) | ||
|
||
|
||
def dot(*arrays, **kwargs): | ||
""" dot(*arrays, *, dims=None) | ||
|
||
einsum for xarray object, but providing simpler interface based on | ||
the array dimensions. | ||
|
||
Parameters | ||
---------- | ||
arrays: multiple DataArrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
arrays to compute. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Arrays |
||
dims: tuple of strings, optional | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. str or tuple of strings |
||
Along which dimensions to be summed over. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which dimensions to sum over. |
||
If not speciified, then all the common dimensions are summed over. | ||
|
||
Returns | ||
------- | ||
dot: same type to input. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably should just "DataArray"? |
||
|
||
Examples | ||
-------- | ||
|
||
>>> da_a = xr.DataArray(np.arange(3 * 4).reshape(3, 4), dims=['a', 'b']) | ||
>>> da_b = xr.DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5), | ||
dims=['a', 'b', 'c']) | ||
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) | ||
|
||
>>> dot(da_a, da_b, dims=['a', 'b']).dims | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should use the full name |
||
('c', ) | ||
>>> dot(da_a, da_b, dims=['a']).dims | ||
('b', 'c') | ||
>>> dot(da_a, da_b, da_c, dims=['b', 'c']).dims | ||
('a', 'd') | ||
""" | ||
from .dataarray import DataArray | ||
from .variable import Variable | ||
|
||
dims = kwargs.pop('dims', None) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if you write |
||
if len(arrays) < 2: | ||
raise TypeError('More than one arrays must be provided') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this special case? If not, let's remove this. For consistency, it is nice to use the same logic even for edge cases when possible. This makes it easier to think about the function. In this case, I think a dot product of 1 array would consistently defined by summing over dimensions listed explicitly in |
||
|
||
if any(not isinstance(arr, (DataArray, Variable)) for arr in arrays): | ||
raise TypeError('Only xr.DataArray and xr.Variable are supported.') | ||
|
||
if isinstance(dims, basestring): | ||
dims = [dims] | ||
|
||
common_dims = set(arrays[0].dims) | ||
for arr in arrays[1:]: | ||
common_dims = common_dims.intersection(set(arr.dims)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a slightly different choice of default dimensions than
Should we switch this behavior to match There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be slightly more efficient to construct e.g., |
||
|
||
if dims is None: | ||
dims = list(common_dims) | ||
|
||
broadcast_dims = [d for d in common_dims if d not in dims] | ||
|
||
input_core_dims = [] | ||
output_core_dims = [[]] | ||
all_dims = [] | ||
|
||
for arr in arrays: | ||
input_core_dims.append([d for d in arr.dims if d not in | ||
broadcast_dims]) | ||
output_core_dims[0] += [d for d in arr.dims if d not in | ||
output_core_dims[0] + dims + broadcast_dims] | ||
all_dims += [d for d in arr.dims if d not in all_dims] | ||
|
||
einsum_axes = 'abcdefghijklmnopqrstuvwxyz' | ||
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} | ||
|
||
subscripts_list = ['...' + ''.join([dim_map[d] for d in ds]) for ds | ||
in input_core_dims] | ||
subscripts = ','.join(subscripts_list) | ||
subscripts += '->...' + ''.join([dim_map[d] for d in output_core_dims[0]]) | ||
|
||
# dtype estimation is necessary for dask='parallelized' | ||
out_dtype = dtypes.result_type(*arrays) | ||
|
||
# we use tensordot if available, because it is more efficient for dask | ||
if len(broadcast_dims) == 0 and len(arrays) == 2: | ||
axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] | ||
for arr in arrays] | ||
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
kwargs={'axes': axes}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I added a path for tensordot, which dask can compute more efficiently. |
||
|
||
# subscripts should be passed as arg, instead of kwargs. We need | ||
# to pass a partial function especially for parallelized computation. | ||
func = functools.partial(np.einsum, subscripts) | ||
result = apply_ufunc(func, *arrays, | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
dask='parallelized', output_dtypes=[out_dtype]) | ||
return result.transpose(*[d for d in all_dims if d in result.dims]) | ||
|
||
|
||
def where(cond, x, y): | ||
"""Return elements from `x` or `y` depending on `cond`. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should lead with a more general description. Maybe: