From eaaee7bca123beaa296ece820fa742c6ec2836b2 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:02:37 +0100 Subject: [PATCH 1/3] Reduce graph size through writing indexes directly into graph for map_blocks --- xarray/core/parallel.py | 27 +++++++++++++++++---------- xarray/tests/test_dask.py | 7 +++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 84728007b42..1de1c01ac7b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -25,7 +25,6 @@ class ExpectedDict(TypedDict): shapes: dict[Hashable, int] coords: set[Hashable] data_vars: set[Hashable] - indexes: dict[Hashable, Index] def unzip(iterable): @@ -337,6 +336,7 @@ def _wrapper( kwargs: dict, arg_is_array: Iterable[bool], expected: ExpectedDict, + expected_indexes: dict[Hashable, Index], ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -372,7 +372,7 @@ def _wrapper( # ChainMap wants MutableMapping, but xindexes is Mapping merged_indexes = collections.ChainMap( - expected["indexes"], + expected_indexes, merged_coordinates.xindexes, # type: ignore[arg-type] ) expected_index = merged_indexes.get(name, None) @@ -413,6 +413,7 @@ def _wrapper( import dask import dask.array from dask.highlevelgraph import HighLevelGraph + from dask.base import tokenize except ImportError: pass @@ -551,6 +552,19 @@ def _wrapper( for isxr, arg in zip(is_xarray, npargs, strict=True) ] + indexes = { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + } + + tokenized_indexes: dict[Hashable, str] = {} + for k, v in indexes.items(): + tokenized_v = tokenize(v) + graph[f"{k}-coordinate-{tokenized_v}"] = v + tokenized_indexes[k] = f"{k}-coordinate-{tokenized_v}" + # raise nice error messages in _wrapper expected: ExpectedDict = { # input chunk 0 along a dimension maps to output chunk 0 along the same dimension @@ -562,17 +576,10 @@ def _wrapper( }, "data_vars": set(template.data_vars.keys()), "coords": set(template.coords.keys()), - # only include new or modified indexes to minimize duplication of data, and graph size. 
- "indexes": { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in (new_indexes | modified_indexes) - }, } from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected, (dict, [[k, v] for k, v in tokenized_indexes.items()])) # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index a46a9d43c4c..d331976e367 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -9,11 +9,13 @@ import numpy as np import pandas as pd import pytest +from distributed import LocalCluster import xarray as xr from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops from xarray.core.duck_array_ops import lazy_array_equiv +from xarray.core.indexes import PandasIndex from xarray.testing import assert_chunks_equal from xarray.tests import ( assert_allclose, @@ -1375,6 +1377,11 @@ def test_map_blocks_da_ds_with_template(obj): actual = xr.map_blocks(func, obj, template=template) assert_identical(actual, template) + # Check that indexes are written into the graph directly + dsk = dict(actual.__dask_graph__()) + assert len({k for k in dsk if "x-coordinate" in k}) + assert all(isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k) + with raise_if_dask_computes(): actual = obj.map_blocks(func, template=template) assert_identical(actual, template) From e7ff50bb567a6f2e7e816941702dd917d33b252b Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:06:25 +0100 Subject: [PATCH 2/3] Reduce graph size through writing indexes directly into graph for map_blocks --- xarray/core/parallel.py | 12 ++++++++++-- xarray/tests/test_dask.py | 5 +++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 1de1c01ac7b..4bd68cbd4a7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -412,8 +412,8 @@ def _wrapper( try: import dask import dask.array - from dask.highlevelgraph import HighLevelGraph from dask.base import tokenize + from dask.highlevelgraph import HighLevelGraph except ImportError: pass @@ -579,7 +579,15 @@ def _wrapper( } from_wrapper = (gname,) + chunk_tuple - graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected, (dict, [[k, v] for k, v in tokenized_indexes.items()])) + graph[from_wrapper] = ( + _wrapper, + func, + blocked_args, + kwargs, + is_array, + expected, + (dict, [[k, v] for k, v in tokenized_indexes.items()]), + ) # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d331976e367..cc795b75118 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd import pytest -from distributed import LocalCluster import xarray as xr from xarray import DataArray, Dataset, Variable @@ -1380,7 +1379,9 @@ def test_map_blocks_da_ds_with_template(obj): # Check that indexes are written into the graph directly dsk = dict(actual.__dask_graph__()) assert len({k for k in dsk if "x-coordinate" in k}) - assert all(isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k) + assert all( + isinstance(v, PandasIndex) for k, v in 
dsk.items() if "x-coordinate" in k + ) with raise_if_dask_computes(): actual = obj.map_blocks(func, template=template) From 8548b048bbd0c5d9dfecb4523d4d9c95913d7b84 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 22 Oct 2024 09:55:01 -0600 Subject: [PATCH 3/3] Update xarray/core/parallel.py --- xarray/core/parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 4bd68cbd4a7..a0dfe56807b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -552,6 +552,7 @@ def _wrapper( for isxr, arg in zip(is_xarray, npargs, strict=True) ] + # only include new or modified indexes to minimize duplication of data indexes = { dim: coordinates.xindexes[dim][ _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)