map_blocks: Allow passing dask-backed objects in args #3818

Changes from 17 commits
@@ -16,6 +16,8 @@
     DefaultDict,
     Dict,
     Hashable,
+    Iterable,
+    List,
     Mapping,
     Sequence,
     Tuple,
@@ -25,12 +27,29 @@

 import numpy as np

+from .alignment import align
 from .dataarray import DataArray
 from .dataset import Dataset

 T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)


+def to_object_array(iterable):
+    npargs = np.empty((len(iterable),), dtype=np.object)
+    for idx, item in enumerate(iterable):
+        npargs[idx] = item
+    return npargs
+
+
+def assert_chunks_compatible(a: Dataset, b: Dataset):
+    a = a.unify_chunks()
+    b = b.unify_chunks()
+
+    for dim in set(a.chunks).intersection(set(b.chunks)):
+        if a.chunks[dim] != b.chunks[dim]:
+            raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.")
+
+
 def check_result_variables(
     result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str
 ):
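A quick illustration of the mismatch assert_chunks_compatible guards against (datasets and sizes invented for the sketch): Dataset.chunks maps each dimension name to its tuple of chunk sizes, and the helper raises whenever two arguments disagree along a shared dimension.

    import numpy as np
    import xarray as xr

    a = xr.Dataset({"u": ("x", np.ones(8))}).chunk({"x": 4})
    b = xr.Dataset({"v": ("x", np.ones(8))}).chunk({"x": 2})

    print(dict(a.chunks))  # {'x': (4, 4)}
    print(dict(b.chunks))  # {'x': (2, 2, 2, 2)}
    # assert_chunks_compatible(a, b) would raise:
    # ValueError: Chunk sizes along dimension 'x' are not equal.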
@@ -67,6 +86,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray:
     return next(iter(obj.data_vars.values()))


+def dataarray_to_dataset(obj: DataArray) -> Dataset:
+    # only using _to_temp_dataset would break
+    # func = lambda x: x.to_dataset()
+    # since that relies on preserving name.
+    if obj.name is None:
+        dataset = obj._to_temp_dataset()
+    else:
+        dataset = obj.to_dataset()
+    return dataset
+
+
 def make_meta(obj):
     """If obj is a DataArray or Dataset, return a new object of the same type and with
     the same variables and dtypes, but where all variables have size 0 and numpy
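The comment in the helper is the crux; a minimal demonstration of the naming asymmetry (toy data) that makes both conversion paths necessary:

    import numpy as np
    import xarray as xr

    named = xr.DataArray(np.arange(3), dims="x", name="temperature")
    unnamed = xr.DataArray(np.arange(3), dims="x")

    # a named array keeps its name through a public Dataset round trip
    print(list(named.to_dataset().data_vars))  # ['temperature']

    # an unnamed array has nothing to call the variable, so to_dataset()
    # refuses, and the private _to_temp_dataset() placeholder is needed
    try:
        unnamed.to_dataset()
    except ValueError as e:
        print(e)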
@@ -161,8 +191,8 @@ def map_blocks(
     obj: DataArray, Dataset
         Passed to the function as its first argument, one dask chunk at a time.
     args: Sequence
-        Passed verbatim to func after unpacking, after the sliced obj. xarray objects,
-        if any, will not be split by chunks. Passing dask collections is not allowed.
+        Passed verbatim to func after unpacking, after the sliced obj.
+        Any xarray objects will also be split by blocks and then passed on.
     kwargs: Mapping
         Passed verbatim to func after unpacking. xarray objects, if any, will not be
         split by chunks. Passing dask collections is not allowed.
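A sketch of the call pattern this docstring change describes (toy arrays; behavior as introduced by this PR): a dask-backed xarray object in args is now blocked alongside obj instead of being rejected.

    import numpy as np
    import xarray as xr

    obj = xr.DataArray(np.ones(10), dims="x").chunk({"x": 5})
    other = xr.DataArray(np.arange(10.0), dims="x").chunk({"x": 5})

    # func receives one matching block of each xarray argument at a time;
    # previously a dask-backed `other` raised a TypeError
    result = xr.map_blocks(lambda a, b: a + b, obj, args=[other])
    print(result.compute().values)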
@@ -241,14 +271,27 @@
         * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
     """

-    def _wrapper(func, obj, to_array, args, kwargs, expected):
-        check_shapes = dict(obj.dims)
+    def _wrapper(
+        func: Callable,
+        args: List,
+        kwargs: dict,
+        arg_is_array: Iterable[bool],
+        expected: dict,
+    ):
+        """
+        Wrapper function that receives datasets in args; converts to dataarrays when necessary;
+        passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc.
+        """
+
+        check_shapes = dict(args[0].dims)
         check_shapes.update(expected["shapes"])

-        if to_array:
-            obj = dataset_to_dataarray(obj)
+        converted_args = [
+            dataset_to_dataarray(arg) if is_array else arg
+            for is_array, arg in zip(arg_is_array, args)
+        ]

-        result = func(obj, *args, **kwargs)
+        result = func(*converted_args, **kwargs)

         # check all dims are present
         missing_dimensions = set(expected["shapes"]) - set(result.sizes)
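Outside its dask context, the conversion step in the new _wrapper is just a masked map over the argument list; a standalone sketch with a simplified dataset_to_dataarray (the real helper carries extra assertions):

    import numpy as np
    import xarray as xr

    def dataset_to_dataarray(ds):
        # simplified stand-in: unpack the dataset's single variable
        return next(iter(ds.data_vars.values()))

    args = [xr.DataArray(np.ones(3), dims="x", name="a").to_dataset(), 10]
    arg_is_array = [True, False]  # first arg started life as a DataArray

    converted_args = [
        dataset_to_dataarray(arg) if is_array else arg
        for is_array, arg in zip(arg_is_array, args)
    ]
    print(type(converted_args[0]).__name__, converted_args[1])  # DataArray 10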
@@ -289,52 +332,57 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
     elif not isinstance(kwargs, Mapping):
         raise TypeError("kwargs must be a mapping (for example, a dict)")

-    for value in list(args) + list(kwargs.values()):
+    for value in kwargs.values():
         if dask.is_dask_collection(value):
             raise TypeError(
-                "Cannot pass dask collections in args or kwargs yet. Please compute or "
+                "Cannot pass dask collections in kwargs yet. Please compute or "
                 "load values before passing to map_blocks."
             )

     if not dask.is_dask_collection(obj):
         return func(obj, *args, **kwargs)

-    if isinstance(obj, DataArray):
-        # only using _to_temp_dataset would break
-        # func = lambda x: x.to_dataset()
-        # since that relies on preserving name.
-        if obj.name is None:
-            dataset = obj._to_temp_dataset()
-        else:
-            dataset = obj.to_dataset()
-        input_is_array = True
-    else:
-        dataset = obj
-        input_is_array = False
+    npargs = to_object_array([obj] + list(args))
+    is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
+    is_array = [isinstance(arg, DataArray) for arg in npargs]
+
+    # align all xarray objects
+    # TODO: should we allow join as a kwarg or force everything to be aligned to the first object?
+    aligned = align(*npargs[is_xarray], join="left")
+    # assigning to object arrays works better when RHS is object array
+    # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
+    npargs[is_xarray] = to_object_array(aligned)
+    npargs[is_array] = to_object_array(
+        [dataarray_to_dataset(da) for da in npargs[is_array]]
+    )
+
+    # check that chunk sizes are compatible
+    input_chunks = dict(npargs[0].chunks)
+    input_indexes = dict(npargs[0].indexes)
+    for arg in npargs[1:][is_xarray[1:]]:
+        assert_chunks_compatible(npargs[0], arg)
+        input_chunks.update(arg.chunks)
+        input_indexes.update(arg.indexes)

-    input_chunks = dataset.chunks
-    dataset_indexes = set(dataset.indexes)
     if template is None:
         # infer template by providing zero-shaped arrays
-        template = infer_template(func, obj, *args, **kwargs)
+        template = infer_template(func, aligned[0], *args, **kwargs)
         template_indexes = set(template.indexes)
-        preserved_indexes = template_indexes & dataset_indexes
-        new_indexes = template_indexes - dataset_indexes
-        indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
+        preserved_indexes = template_indexes & set(input_indexes)
+        new_indexes = template_indexes - set(input_indexes)
+        indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
         indexes.update({k: template.indexes[k] for k in new_indexes})
         output_chunks = {
             dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
         }

     else:
         # template xarray object has been provided with proper sizes and chunk shapes
-        template_indexes = set(template.indexes)
-        indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes}
-        indexes.update({k: template.indexes[k] for k in template_indexes})
+        indexes = dict(template.indexes)
         if isinstance(template, DataArray):
             output_chunks = dict(zip(template.dims, template.chunks))  # type: ignore
         else:
-            output_chunks = template.chunks  # type: ignore
+            output_chunks = dict(template.chunks)

     for dim in output_chunks:
         if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):

Review comment (resolved) on the to_object_array call: "converting to object array so that we can use boolean indexing to pull out xarray objects."
Review comment on the npargs[is_xarray] assignment: "Is there a better way to do this assignment?"
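A condensed sketch of the argument-preparation dance above (toy objects): join="left" reindexes every xarray argument onto the first one's coordinates, and the aligned results are written back through an object-array RHS so numpy treats each DataArray as an opaque element instead of trying to broadcast it.

    import numpy as np
    import xarray as xr

    def to_object_array(iterable):
        # dtype=object (np.object is a deprecated alias for the same thing)
        out = np.empty((len(iterable),), dtype=object)
        for i, item in enumerate(iterable):
            out[i] = item
        return out

    obj = xr.DataArray(np.arange(4.0), dims="x", coords={"x": [0, 1, 2, 3]})
    other = xr.DataArray(np.arange(3.0), dims="x", coords={"x": [1, 2, 3]})

    npargs = to_object_array([obj, "plain-arg", other])
    is_xarray = [isinstance(a, xr.DataArray) for a in npargs]

    # everything is reindexed to the first object's coordinates
    aligned = xr.align(*npargs[is_xarray], join="left")
    npargs[is_xarray] = to_object_array(aligned)
    print(npargs[2].x.values)  # [0 1 2 3]; data NaN-padded where `other` had no x=0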
@@ -363,7 +411,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
     graph: Dict[Any, Any] = {}
     new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
     gname = "{}-{}".format(
-        dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
+        dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs)
     )

     # map dims to list of chunk indexes
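gname is deterministic because dask.base.tokenize hashes its inputs; a tiny check:

    import numpy as np
    import dask

    a = np.arange(4)
    # identical inputs hash identically, so the layer name is reproducible...
    print(dask.base.tokenize(a, [1], {}) == dask.base.tokenize(a, [1], {}))  # True
    # ...and different inputs get a different name, avoiding graph collisions
    print(dask.base.tokenize(a, [1], {}) == dask.base.tokenize(a, [2], {}))  # False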
@@ -376,17 +424,23 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
         dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
     }

-    # iterate over all possible chunk combinations
-    for v in itertools.product(*ichunk.values()):
-        chunk_index = dict(zip(dataset.dims, v))
+    def subset_dataset_to_block(
+        graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
+    ):
+        """
+        Creates a task that subsets an xarray dataset to a block determined by chunk_index,
+        whose extents are determined by input_chunk_bounds.
+        There are subtasks that create subsets of constituent variables.
+        """

         # this will become [[name1, variable1],
-        #                   [name2, variable2],
-        #                   ...]
+        #                  [name2, variable2],
+        #                  ...]
         # which is passed to dict and then to Dataset
         data_vars = []
         coords = []

+        chunk_tuple = tuple(chunk_index.values())
         for name, variable in dataset.variables.items():
             # make a task that creates tuple of (dims, chunk)
             if dask.is_dask_collection(variable.data):
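The input_chunk_bounds machinery is plain cumulative-sum arithmetic; a sketch with an invented chunking:

    import numpy as np

    output_chunks = {"x": (4, 4, 2)}
    # chunk i along a dimension spans [bounds[i], bounds[i + 1])
    chunk_bounds = {
        dim: np.cumsum((0,) + chunks) for dim, chunks in output_chunks.items()
    }
    print(chunk_bounds["x"])  # [ 0  4  8 10]

    chunk_index = {"x": 2}
    i = chunk_index["x"]
    # the slice a non-dask variable is isel()'d with for this block
    print(slice(int(chunk_bounds["x"][i]), int(chunk_bounds["x"][i + 1])))
    # slice(8, 10, None)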
@@ -395,13 +449,13 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
                 for dim in variable.dims:
                     chunk = chunk[chunk_index[dim]]

-                chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v
+                chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple
                 graph[chunk_variable_task] = (
                     tuple,
                     [variable.dims, chunk, variable.attrs],
                 )
             else:
-                # non-dask array with possibly chunked dimensions
+                # non-dask array possibly with dimensions chunked on other variables
                 # index into variable appropriately
                 subsetter = {
                     dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
@@ -410,7 +464,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
                 subset = variable.isel(subsetter)
                 chunk_variable_task = (
                     "{}-{}".format(gname, dask.base.tokenize(subset)),
-                ) + v
+                ) + chunk_tuple
                 graph[chunk_variable_task] = (
                     tuple,
                     [subset.dims, subset, subset.attrs],
@@ -422,7 +476,22 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
             else:
                 data_vars.append([name, chunk_variable_task])

-        # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper
+        return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
+
+    # iterate over all possible chunk combinations
+    for chunk_tuple in itertools.product(*ichunk.values()):
+        # mapping from dimension name to chunk index
+        chunk_index = dict(zip(ichunk.keys(), chunk_tuple))
+
+        blocked_args = [
+            subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index)
+            if isxr
+            else arg
+            for isxr, arg in zip(is_xarray, npargs)
+        ]
+
+        # expected["shapes", "coords", "data_vars", "indexes"] are used to
+        # raise nice error messages in _wrapper
         expected = {}
         # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
         # even if length of dimension is changed by the applied function
@@ -436,16 +505,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
             for dim in indexes
         }

-        from_wrapper = (gname,) + v
-        graph[from_wrapper] = (
-            _wrapper,
-            func,
-            (Dataset, (dict, data_vars), (dict, coords), dataset.attrs),
-            input_is_array,
-            args,
-            kwargs,
-            expected,
-        )
+        from_wrapper = (gname,) + chunk_tuple
+        graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)

         # mapping from variable name to dask graph key
         var_key_map: Dict[Hashable, str] = {}
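Each graph entry follows dask's task convention: a key mapping to a tuple of (callable, *args), with keys of other tasks in the args substituted by their computed results at execution time. A toy graph (names invented) with the same shape as the from_wrapper entry:

    import dask

    graph = {
        ("block", 0): (list, [1, 2, 3]),      # like a subsetted input block
        ("wrapped", 0): (sum, ("block", 0)),  # like (_wrapper, func, ...)
    }
    print(dask.get(graph, ("wrapped", 0)))  # 6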
@@ -472,14 +533,22 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
         # layer.
         new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)

-    hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
+    hlg = HighLevelGraph.from_collections(
+        gname,
+        graph,
+        dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
+    )

     for gname_l, layer in new_layers.items():
         # This adds in the getitems for each variable in the dataset.
         hlg.dependencies[gname_l] = {gname}
         hlg.layers[gname_l] = layer

     result = Dataset(coords=indexes, attrs=template.attrs)
+    for index in result.indexes:
+        result[index].attrs = template[index].attrs
+        result[index].encoding = template[index].encoding
+
     for name, gname_l in var_key_map.items():
         dims = template[name].dims
         var_chunks = []
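Listing every dask-backed argument in dependencies is what lets the new layer's tasks reference blocks from all of them; a minimal sketch of from_collections (toy arrays):

    import dask.array as da
    from dask.highlevelgraph import HighLevelGraph

    x = da.ones(4, chunks=2)
    y = da.zeros(4, chunks=2)

    # tasks in the new layer may refer to blocks of both x and y by key
    graph = {("pair", i): (tuple, [(x.name, i), (y.name, i)]) for i in range(2)}
    hlg = HighLevelGraph.from_collections("pair", graph, dependencies=[x, y])
    print(len(hlg.layers))  # 3: one layer each for x, y, and "pair"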
@@ -496,6 +565,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
             hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
         )
         result[name] = (dims, data, template[name].attrs)
+        result[name].encoding = template[name].encoding

     result = result.set_coords(template._coord_names)
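The added line mirrors the index handling above: without copying encoding, writing the result to disk would drop on-disk settings carried by the template. A small illustration (made-up encoding):

    import numpy as np
    import xarray as xr

    template = xr.Dataset({"u": ("x", np.ones(3))})
    template["u"].encoding = {"dtype": "int16", "scale_factor": 0.01}

    result = xr.Dataset({"u": ("x", np.zeros(3))})
    result["u"].encoding = template["u"].encoding
    print(result["u"].encoding)  # {'dtype': 'int16', 'scale_factor': 0.01}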