Skip to content

Commit 2ba58f2

Browse files
fix: uproot was exposed in one place to dask's _task_spec overhaul (#1352)
* fix: uproot was exposed in one plat to dask's _task_spec overhaul * style: pre-commit fixes * require fixed dask-awkward --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8a71b73 commit 2ba58f2

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ requires-python = ">=3.9"
5757
[project.optional-dependencies]
5858
dev = [
5959
"boost_histogram>=0.13",
60-
"dask-awkward>=2023.12.1",
60+
"dask-awkward>=2024.12.1",
6161
"dask[array,distributed]",
6262
"hist>=1.2",
6363
"pandas",

src/uproot/_dask.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ def _dask_array_from_map(
383383
**kwargs,
384384
):
385385
dask = uproot.extras.dask()
386+
_dask_uses_tasks = hasattr(dask, "_task_spec")
387+
386388
da = uproot.extras.dask_array()
387389
if not callable(func):
388390
raise ValueError("`func` argument must be `callable`")
@@ -446,14 +448,22 @@ def _dask_array_from_map(
446448
produces_tasks=produces_tasks,
447449
)
448450

449-
dsk = dask.blockwise.Blockwise(
450-
output=name,
451-
output_indices="i",
452-
dsk={name: (io_func, dask.blockwise.blockwise_token(0))},
453-
indices=[(io_arg_map, "i")],
454-
numblocks={},
455-
annotations=None,
456-
)
451+
blockwise_kwargs = {
452+
"output": name,
453+
"output_indices": "i",
454+
"indices": [(io_arg_map, "i")],
455+
"numblocks": {},
456+
"annotations": None,
457+
}
458+
459+
if _dask_uses_tasks:
460+
blockwise_kwargs["task"] = dask._task_spec.Task(
461+
name, io_func, dask._task_spec.TaskRef(dask.blockwise.blockwise_token(0))
462+
)
463+
else:
464+
blockwise_kwargs["dsk"] = {name: (io_func, dask.blockwise.blockwise_token(0))}
465+
466+
dsk = dask.blockwise.Blockwise(**blockwise_kwargs)
457467

458468
hlg = dask.highlevelgraph.HighLevelGraph.from_collections(name, dsk)
459469
return da.core.Array(hlg, name, chunks, dtype=dtype)

0 commit comments

Comments
 (0)