Skip to content

dask.optimize on xarray objects #3698

@dcherian

Description

@dcherian

I am trying to call dask.optimize on an xarray object before the graph gets too big, but I get weird errors. Simple examples are below. All of the examples work if I remove the dask.optimize step.

cc @mrocklin @shoyer

This works with dask arrays:

a = dask.array.ones((10,5), chunks=(1,3))
a = dask.optimize(a)[0]
a.compute()

It works when a DataArray is constructed using a dask array:

da = xr.DataArray(a)
da = dask.optimize(da)[0]
da.compute()

But it fails when creating a DataArray with a numpy array and then chunking it:

🤷‍♂️

da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()

This fails with the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-50-1f16efa19800> in <module>
      1 da = xr.DataArray(a.compute()).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841 
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659 
    660             # evaluate all the dask arrays simultaneously
--> 661             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662 
    663             for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

TypeError: string indices must be integers

A different error occurs when rechunking a dask-backed DataArray:

da = xr.DataArray(a).chunk({"dim_0": 5})
da = dask.optimize(da)[0]
da.compute()
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-55-d978bbb9e38d> in <module>
      1 da = xr.DataArray(a).chunk({"dim_0": 5})
      2 da = dask.optimize(da)[0]
----> 3 da.compute()

~/python/xarray/xarray/core/dataarray.py in compute(self, **kwargs)
    838         """
    839         new = self.copy(deep=False)
--> 840         return new.load(**kwargs)
    841 
    842     def persist(self, **kwargs) -> "DataArray":

~/python/xarray/xarray/core/dataarray.py in load(self, **kwargs)
    812         dask.array.compute
    813         """
--> 814         ds = self._to_temp_dataset().load(**kwargs)
    815         new = self._from_temp_dataset(ds)
    816         self._variable = new._variable

~/python/xarray/xarray/core/dataset.py in load(self, **kwargs)
    659 
    660             # evaluate all the dask arrays simultaneously
--> 661             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    662 
    663             for k, data in zip(lazy_data, evaluated_data):

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     79         get_id=_thread_get_id,
     80         pack_exception=pack_exception,
---> 81         **kwargs
     82     )
     83 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in concatenate3(arrays)
   4305     if not ndim:
   4306         return arrays
-> 4307     chunks = chunks_from_arrays(arrays)
   4308     shape = tuple(map(sum, chunks))
   4309 

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in chunks_from_arrays(arrays)
   4085 
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1

~/miniconda3/envs/dcpy_updated/lib/python3.7/site-packages/dask/array/core.py in <listcomp>(.0)
   4085 
   4086     while isinstance(arrays, (list, tuple)):
-> 4087         result.append(tuple([shape(deepfirst(a))[dim] for a in arrays]))
   4088         arrays = arrays[0]
   4089         dim += 1

IndexError: tuple index out of range

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions