Skip to content

Commit 08213ea

Browse files
feat[dace][next] More Robust Calling (#2353)
This PR changes how calls to the underlying `CompiledSDFG` objects are carried out. Previously, the implementation relied heavily on the internal data format of the DaCe class. However, a recent [change in DaCe](spcl/dace#2185) altered this internal data, leading to errors that had to [be patched](GridTools/dace#9). This PR introduces a more stable fix and builds upon a [refactoring in DaCe](spcl/dace#2206) that, among other things, exposes the tools GT4Py needs to work independently of those internals. For that reason, this PR also updates the DaCe dependency to `2025.11.04`. The main change is that the argument vector, i.e. the C representation of the arguments used for the call, is no longer managed by `CompiledSDFG` but by GT4Py's `CompiledDaceProgram`. Co-authored-by: edopao <[email protected]>
1 parent a75b58e commit 08213ea

File tree

5 files changed

+96
-48
lines changed

5 files changed

+96
-48
lines changed

pyproject.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dace-cartesian = [
1111
'dace>=1.0.2,<2' # refined in [tool.uv.sources]
1212
]
1313
dace-next = [
14-
'dace==2025.10.30' # refined in [tool.uv.sources]
14+
'dace==2025.11.04' # refined in [tool.uv.sources]
1515
]
1616
dev = [
1717
{include-group = 'build'},
@@ -240,6 +240,11 @@ disallow_incomplete_defs = false
240240
disallow_untyped_defs = false
241241
module = 'gt4py.next.iterator.*'
242242

243+
[[tool.mypy.overrides]]
244+
disallow_incomplete_defs = false
245+
disallow_untyped_defs = false
246+
module = 'gt4py.next.program_processors.runners.dace_iterator.*'
247+
243248
[[tool.mypy.overrides]]
244249
ignore_errors = true
245250
module = 'gt4py.next.iterator.runtime'
@@ -448,7 +453,7 @@ url = 'https://test.pypi.org/simple'
448453
atlas4py = {index = "test.pypi"}
449454
dace = [
450455
{git = "https://github.com/GridTools/dace", branch = "romanc/stree-roundtrip", group = "dace-cartesian"},
451-
{git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_10_30", group = "dace-next"}
456+
{git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_04", group = "dace-next"}
452457
]
453458

454459
# -- versioningit --

src/gt4py/next/program_processors/runners/dace/workflow/compilation.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import dataclasses
1212
import os
13+
import warnings
1314
from collections.abc import Callable, MutableSequence, Sequence
1415
from typing import Any
1516

@@ -36,6 +37,14 @@ class CompiledDaceProgram(stages.CompiledProgram):
3637
None,
3738
]
3839

40+
# Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None`
41+
# means that it has not been initialized, i.e. no call was ever performed.
42+
# - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated.
43+
# - csdfg_init_argv: Arguments used for initialization; used only the first time and
44+
# never updated.
45+
csdfg_argv: MutableSequence[Any] | None
46+
csdfg_init_argv: Sequence[Any] | None
47+
3948
def __init__(
4049
self,
4150
program: dace.CompiledSDFG,
@@ -56,12 +65,35 @@ def __init__(
5665
# For debug purpose, we set a unique module name on the compiled function.
5766
self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder)
5867

59-
def __call__(self, **kwargs: Any) -> None:
60-
result = self.sdfg_program(**kwargs)
61-
assert result is None
68+
# Since the SDFG hasn't been called yet.
69+
self.csdfg_argv = None
70+
self.csdfg_init_argv = None
71+
72+
def construct_arguments(self, **kwargs: Any) -> None:
73+
"""
74+
This function will process the arguments and store the processed values in `self.csdfg_argv`,
75+
to call them use `self.fast_call()`.
76+
"""
77+
with dace.config.set_temporary("compiler", "allow_view_arguments", value=True):
78+
csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs)
79+
# Note we only care about `csdfg_argv` (normal call), since we have to update it,
80+
# we ensure that it is a `list`.
81+
self.csdfg_argv = [*csdfg_argv]
82+
self.csdfg_init_argv = csdfg_init_argv
6283

6384
def fast_call(self) -> None:
64-
result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs)
85+
"""Perform a call to the compiled SDFG using the processed arguments, see `self.construct_arguments()`."""
86+
assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, (
87+
"Argument vector was not set properly."
88+
)
89+
self.sdfg_program.fast_call(self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=False)
90+
91+
def __call__(self, **kwargs: Any) -> None:
92+
warnings.warn(
93+
"Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.",
94+
stacklevel=1,
95+
)
96+
result = self.sdfg_program(**kwargs)
6597
assert result is None
6698

6799

src/gt4py/next/program_processors/runners/dace/workflow/decoration.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import functools
1212
from typing import Any, Sequence
1313

14-
import dace
1514
import numpy as np
1615

1716
from gt4py._core import definitions as core_defs
@@ -45,15 +44,17 @@ def decorated_program(
4544
args = (*args, out)
4645

4746
try:
48-
# Initialization of `_lastargs` was done by the `CompiledSDFG` object,
49-
# so we just update it with the current call arguments.
50-
update_sdfg_call_args(args, fun.sdfg_program._lastargs[0])
51-
fun.fast_call()
52-
except IndexError:
53-
# First call, the SDFG is not initialized, so forward the call to `CompiledSDFG`
54-
# to properly initialize it. Later calls to this SDFG will be handled through
55-
# the `fast_call()` API.
56-
assert len(fun.sdfg_program._lastargs) == 0 # `fun.sdfg_program._lastargs` is empty
47+
# Not the first call.
48+
# We will only update the argument vector for the normal call.
49+
# NOTE: If this is the first time then we will generate an exception because
50+
# `fun.csdfg_argv` is `None`
51+
# TODO(phimuell, edopao): Think about refactoring the code such that the update
52+
# of the argument vector is a method of the `CompiledDaceProgram`.
53+
update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call.
54+
55+
except TypeError:
56+
# First call. Construct the initial argument vector of the `CompiledDaceProgram`.
57+
assert fun.csdfg_argv is None and fun.csdfg_init_argv is None
5758
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(args)
5859
this_call_args = sdfg_callable.get_sdfg_args(
5960
fun.sdfg_program.sdfg,
@@ -65,8 +66,10 @@ def decorated_program(
6566
gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL,
6667
gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg,
6768
}
68-
with dace.config.set_temporary("compiler", "allow_view_arguments", value=True):
69-
fun(**this_call_args)
69+
fun.construct_arguments(**this_call_args)
70+
71+
# Perform the call to the SDFG.
72+
fun.fast_call()
7073

7174
if collect_time:
7275
metric_source = metrics.get_current_source()

tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,36 +81,44 @@ def unstructured_case(request, exec_alloc_descriptor, mesh_descriptor):
8181
def make_mocks(monkeypatch):
8282
# Wrap `CompiledDaceProgram.fast_call` with mock object
8383
mock_fast_call = unittest.mock.MagicMock()
84-
dace_fast_call = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call
84+
gt4py_fast_call = (
85+
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call
86+
)
8587

86-
def mocked_fast_call(self, *args, **kwargs):
87-
mock_fast_call.__call__(*args, **kwargs)
88-
fast_call_result = dace_fast_call(self, *args, **kwargs)
88+
def mocked_fast_call(self):
89+
mock_fast_call.__call__()
90+
fast_call_result = gt4py_fast_call(self)
8991
# invalidate all scalar positional arguments to ensure that they are properly set
9092
# next time the SDFG is executed before fast_call
91-
positional_args = set(self.sdfg.arg_names)
92-
sdfg_arglist = self.sdfg.arglist()
93+
positional_args = set(self.sdfg_program.sdfg.arg_names)
94+
sdfg_arglist = self.sdfg_program.sdfg.arglist()
9395
for i, (arg_name, arg_type) in enumerate(sdfg_arglist.items()):
9496
if arg_name in positional_args and isinstance(arg_type, dace.data.Scalar):
95-
assert isinstance(self._lastargs[0][i], ctypes.c_int)
96-
self._lastargs[0][i].value = -1
97+
assert isinstance(self.csdfg_argv[i], ctypes.c_int)
98+
self.csdfg_argv[i].value = -1
9799
return fast_call_result
98100

99-
monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call)
101+
monkeypatch.setattr(
102+
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram,
103+
"fast_call",
104+
mocked_fast_call,
105+
)
100106

101-
# Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object
102-
mock_construct_args = unittest.mock.MagicMock()
103-
dace_construct_args = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args
107+
# Wrap `CompiledDaceProgram.construct_arguments` with mock object
108+
mock_construct_arguments = unittest.mock.MagicMock()
109+
gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments
104110

105-
def mocked_construct_args(self, *args, **kwargs):
106-
mock_construct_args.__call__(*args, **kwargs)
107-
return dace_construct_args(self, *args, **kwargs)
111+
def mocked_construct_arguments(self, *args, **kwargs):
112+
mock_construct_arguments.__call__(*args, **kwargs)
113+
return gt4py_construct_arguments(self, *args, **kwargs)
108114

109115
monkeypatch.setattr(
110-
dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args
116+
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram,
117+
"construct_arguments",
118+
mocked_construct_arguments,
111119
)
112120

113-
return mock_fast_call, mock_construct_args
121+
return mock_fast_call, mock_construct_arguments
114122

115123

116124
def test_dace_fastcall(cartesian_case, monkeypatch):
@@ -139,11 +147,11 @@ def testee(
139147
unused_field = cases.allocate(cartesian_case, testee, "unused_field")()
140148
out = cases.allocate(cartesian_case, testee, cases.RETURN)()
141149

142-
mock_fast_call, mock_construct_args = make_mocks(monkeypatch)
150+
mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch)
143151

144152
# Reset mock objects and run/verify GT4Py program
145153
def verify_testee():
146-
mock_construct_args.reset_mock()
154+
mock_construct_arguments.reset_mock()
147155
mock_fast_call.reset_mock()
148156
cases.verify(
149157
cartesian_case,
@@ -159,24 +167,24 @@ def verify_testee():
159167

160168
# On first run, the SDFG arguments will have to be constructed
161169
verify_testee()
162-
mock_construct_args.assert_called_once()
170+
mock_construct_arguments.assert_called_once()
163171

164172
# Now modify the scalar arguments, used and unused ones: reuse previous SDFG arguments
165173
for i in range(4):
166174
a_offset[i] += 1
167175
verify_testee()
168-
mock_construct_args.assert_not_called()
176+
mock_construct_arguments.assert_not_called()
169177

170178
# Modify content of current buffer: reuse previous SDFG arguments
171179
for buff in (a, unused_field):
172180
buff[0] += 1
173181
verify_testee()
174-
mock_construct_args.assert_not_called()
182+
mock_construct_arguments.assert_not_called()
175183

176184
# Pass a new buffer, fastcall API should still be used
177185
a = cases.allocate(cartesian_case, testee, "a")()
178186
verify_testee()
179-
mock_construct_args.assert_not_called()
187+
mock_construct_arguments.assert_not_called()
180188

181189

182190
def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch):
@@ -191,11 +199,11 @@ def testee(a: cases.VField) -> cases.EField:
191199
(a,), kwfields = cases.get_default_data(unstructured_case, testee)
192200
numpy_ref = lambda a: a[connectivity_E2V[:, 0]]
193201

194-
mock_fast_call, mock_construct_args = make_mocks(monkeypatch)
202+
mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch)
195203

196204
# Reset mock objects and run/verify GT4Py program
197205
def verify_testee():
198-
mock_construct_args.reset_mock()
206+
mock_construct_arguments.reset_mock()
199207
mock_fast_call.reset_mock()
200208
cases.verify(
201209
unstructured_case,
@@ -209,4 +217,4 @@ def verify_testee():
209217

210218
verify_testee()
211219
verify_testee()
212-
mock_construct_args.assert_not_called()
220+
mock_construct_arguments.assert_not_called()

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)