Description
Code Sample, a copy-pastable example if possible
I am using xarray in combination with dask distributed on a cluster, so a minimal code sample demonstrating my problem is not easy to come up with.
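For reference, the cluster side is managed with dask_jobqueue on PBS (as the warnings further down show). A minimal sketch of that kind of setup, with placeholder resource values rather than my actual configuration, looks like this:

from dask.distributed import Client
from dask_jobqueue import PBSCluster

# Placeholder resources: the real queue, cores and memory settings differ.
cluster = PBSCluster(cores=4, memory='16GB', processes=1)
cluster.adapt(minimum=1, maximum=20)  # auto-scaling of worker jobs
client = Client(cluster)              # computations now run on the PBS workers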
Problem description
Here is what I observe:
- In my environment, dask distributed is correctly set up with auto-scaling. I can verify this by loading data into xarray and using aggregation functions like mean(): this triggers auto-scaling, and the dask dashboard shows that the processing is spread across the worker nodes.
- I have the following xarray dataset, called geoms_ds:
<xarray.Dataset>
Dimensions: (x: 10980, y: 10980)
Coordinates:
* y (y) float64 4.9e+06 4.9e+06 4.9e+06 ... 4.79e+06 4.79e+06 4.79e+06
* x (x) float64 3e+05 3e+05 3e+05 ... 4.098e+05 4.098e+05 4.098e+05
Data variables:
label (y, x) uint16 dask.array<shape=(10980, 10980), chunksize=(200, 10980)>
Which I load with the following code sample:
import xarray as xr

# Read the single-band GeoTIFF as a chunked, dask-backed DataArray
geoms = xr.open_rasterio('test_rasterization_T31TCJ_uint16.tif', chunks={'band': 1, 'x': 10980, 'y': 200})
geoms_squeez = geoms.isel(band=0).squeeze().drop(labels='band')
geoms_ds = geoms_squeez.to_dataset(name='label')
This array holds a finite number of integer values denoting groups (or classes, if you like). I would like to perform statistics on those groups (together with additional variables), such as the mean value of a given variable for each group.
- I can do this perfectly well for a single group using .where(label=xxx).mean('variable'): this behaves as expected, triggering auto-scaling and a dask task graph.
- The problem is that I have a lot of groups (or classes), and looping through all of them to apply where() is not very efficient. From my reading of the xarray documentation, groupby is what I need to perform stats on all groups at once.
- When I try to use geoms_ds.groupby('label').size() for instance (see the sketch after the traceback below), here is what I observe:
- Grouping is not lazy; it is evaluated immediately,
- Grouping is not performed through dask distributed: only the master node is working, on a single thread,
- The grouping operation takes a long time and eats a large amount of memory (nearly 30 GB, which is much more than what is required to store the full dataset in memory),
- Most of the time, the grouping fails with the following errors and warnings:
distributed.utils_perf - WARNING - full garbage collections took 52% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 47% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 48% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 50% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 53% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 56% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 57% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 58% CPU time recently (threshold: 10%)
distributed.utils_perf - WARNING - full garbage collections took 59% CPU time recently (threshold: 10%)
WARNING:dask_jobqueue.core:Worker tcp://10.135.39.92:51747 restart in Job 2758934. This can be due to memory issue.
distributed.utils - ERROR - 'tcp://10.135.39.92:51747'
Traceback (most recent call last):
File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/utils.py", line 648, in log_errors
yield
File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 1360, in add_worker
yield self.handle_worker(comm=comm, worker=address)
File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
value = future.result()
File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
yielded = next(result)
File "/work/logiciels/projets/eolab/conda/eolab/lib/python3.6/site-packages/distributed/scheduler.py", line 2220, in handle_worker
worker_comm = self.stream_comms[worker]
KeyError: ...
I assume this comes from the process being killed by PBS for excessive memory usage.
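Put together, a stripped-down version of the two code paths above looks roughly like this (the explicit comparison with a hard-coded label value of 42 is only my reading of the .where(label=xxx) shorthand; the real code loops over many label values and aggregates additional variables):

# Single group: lazy, and executed on the distributed workers as expected.
single_group_mean = geoms_ds.where(geoms_ds.label == 42).mean()
single_group_mean.compute()

# All groups at once: this is the call that is evaluated eagerly,
# runs only on the master node, and eventually blows up the memory.
sizes = geoms_ds.groupby('label').size()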
Expected Output
I would expect the following:
- A single call to groupby to be evaluated lazily (as sketched below),
- Evaluation of the aggregation function to be performed through dask distributed,
- The dataset is not that large, so even on a single master thread the computation should finish in a reasonable time.
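Concretely, this is the kind of behaviour I would expect (a sketch of intended usage, with mean() standing in for whatever aggregation is applied):

grouped = geoms_ds.groupby('label')   # expected to be cheap, no data loaded yet
stats = grouped.mean()                # expected to build a lazy dask graph
result = stats.compute()              # expected to run on the distributed workers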
Output of xr.show_versions()
xarray: 0.11.3
pandas: 0.24.1
numpy: 1.16.1
scipy: 1.2.0
netCDF4: 1.4.2
pydap: None
h5netcdf: None
h5py: None
Nio: None
zarr: None
cftime: 1.0.3.4
PseudonetCDF: None
rasterio: 1.0.15
cfgrib: None
iris: None
bottleneck: None
cyordereddict: None
dask: 1.1.1
distributed: 1.25.3
matplotlib: 3.0.2
cartopy: 0.17.0
seaborn: 0.9.0
setuptools: 40.7.1
pip: 19.0.1
conda: None
pytest: None
IPython: 7.1.1
sphinx: None