Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dace-cartesian = [
'dace>=1.0.2,<2' # renfined in [tool.uv.sources]
]
dace-next = [
'dace==2025.10.30' # refined in [tool.uv.sources]
'dace==2025.11.04' # refined in [tool.uv.sources] and I like that book.
]
dev = [
{include-group = 'build'},
Expand Down Expand Up @@ -240,6 +240,11 @@ disallow_incomplete_defs = false
disallow_untyped_defs = false
module = 'gt4py.next.iterator.*'

[[tool.mypy.overrides]]
disallow_incomplete_defs = false
disallow_untyped_defs = false
module = 'gt4py.next.program_processors.runners.dace_iterator.*'

Comment on lines +243 to +247
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remember to remove these lines.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@philip-paul-mueller did you actually remove these lines and the "and I like that book" comment of the dace-next version before merging?

[[tool.mypy.overrides]]
ignore_errors = true
module = 'gt4py.next.iterator.runtime'
Expand Down Expand Up @@ -448,7 +453,7 @@ url = 'https://test.pypi.org/simple'
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/GridTools/dace", branch = "romanc/stree-roundtrip", group = "dace-cartesian"},
{git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_10_30", group = "dace-next"}
{git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_04", group = "dace-next"}
]

# -- versioningit --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import dataclasses
import os
import warnings
from collections.abc import Callable, MutableSequence, Sequence
from typing import Any

Expand All @@ -36,6 +37,14 @@ class CompiledDaceProgram(stages.CompiledProgram):
None,
]

# Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None`
# means that it has not been initialized, i.e. no call was ever performed.
# - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated.
# - csdfg_init_argv: Arguments used for initialization; used only the first time and
# never updated.
csdfg_argv: MutableSequence[Any] | None
csdfg_init_argv: Sequence[Any] | None

def __init__(
self,
program: dace.CompiledSDFG,
Expand All @@ -56,12 +65,35 @@ def __init__(
# For debug purpose, we set a unique module name on the compiled function.
self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder)

def __call__(self, **kwargs: Any) -> None:
result = self.sdfg_program(**kwargs)
assert result is None
# Since the SDFG hasn't been called yet.
self.csdfg_argv = None
self.csdfg_init_argv = None

def construct_arguments(self, **kwargs: Any) -> None:
"""
This function will process the arguments and store the processed values in `self.csdfg_args`,
to call them use `self.fast_call()`.
"""
with dace.config.set_temporary("compiler", "allow_view_arguments", value=True):
csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs)
# Note we only care about `csdfg_argv` (normal call), since we have to update it,
# we ensure that it is a `list`.
self.csdfg_argv = [*csdfg_argv]
self.csdfg_init_argv = csdfg_init_argv

def fast_call(self) -> None:
result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs)
"""Perform a call to the compiled SDFG using the processed arguments, see `self.prepare_arguments()`."""
assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, (
"Argument vector was not set properly."
)
self.sdfg_program.fast_call(self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we really sure we want to pass do_gpu_check=False? Do you know what is its overhead? Maybe we could pass do_gpu_check=__debug__?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@philip-paul-mueller , any thoughts on this comment?

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2365

We should have added something here.


def __call__(self, **kwargs: Any) -> None:
warnings.warn(
"Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.",
stacklevel=1,
)
result = self.sdfg_program(**kwargs)
assert result is None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import functools
from typing import Any, Sequence

import dace
import numpy as np

from gt4py._core import definitions as core_defs
Expand Down Expand Up @@ -45,15 +44,17 @@ def decorated_program(
args = (*args, out)

try:
# Initialization of `_lastargs` was done by the `CompiledSDFG` object,
# so we just update it with the current call arguments.
update_sdfg_call_args(args, fun.sdfg_program._lastargs[0])
fun.fast_call()
except IndexError:
# First call, the SDFG is not initialized, so forward the call to `CompiledSDFG`
# to properly initialize it. Later calls to this SDFG will be handled through
# the `fast_call()` API.
assert len(fun.sdfg_program._lastargs) == 0 # `fun.sdfg_program._lastargs` is empty
# Not the first call.
# We will only update the argument vector for the normal call.
# NOTE: If this is the first time then we will generate an exception because
# `fun.csdfg_args` is `None`
# TODO(phimuell, edopao): Think about refactor the code such that the update
# of the argument vector is a Method of the `CompiledDaceProgram`.
update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call.

except TypeError:
# First call. Construct the initial argument vector of the `CompiledDaceProgram`.
assert fun.csdfg_argv is None and fun.csdfg_init_argv is None
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(args)
this_call_args = sdfg_callable.get_sdfg_args(
fun.sdfg_program.sdfg,
Expand All @@ -65,8 +66,10 @@ def decorated_program(
gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL,
gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg,
}
with dace.config.set_temporary("compiler", "allow_view_arguments", value=True):
fun(**this_call_args)
fun.construct_arguments(**this_call_args)

# Perform the call to the SDFG.
fun.fast_call()

if collect_time:
metric_source = metrics.get_current_source()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,36 +81,44 @@ def unstructured_case(request, exec_alloc_descriptor, mesh_descriptor):
def make_mocks(monkeypatch):
# Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object
mock_fast_call = unittest.mock.MagicMock()
dace_fast_call = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call
gt4py_fast_call = (
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call
)

def mocked_fast_call(self, *args, **kwargs):
mock_fast_call.__call__(*args, **kwargs)
fast_call_result = dace_fast_call(self, *args, **kwargs)
def mocked_fast_call(self):
mock_fast_call.__call__()
fast_call_result = gt4py_fast_call(self)
# invalidate all scalar positional arguments to ensure that they are properly set
# next time the SDFG is executed before fast_call
positional_args = set(self.sdfg.arg_names)
sdfg_arglist = self.sdfg.arglist()
positional_args = set(self.sdfg_program.sdfg.arg_names)
sdfg_arglist = self.sdfg_program.sdfg.arglist()
for i, (arg_name, arg_type) in enumerate(sdfg_arglist.items()):
if arg_name in positional_args and isinstance(arg_type, dace.data.Scalar):
assert isinstance(self._lastargs[0][i], ctypes.c_int)
self._lastargs[0][i].value = -1
assert isinstance(self.csdfg_argv[i], ctypes.c_int)
self.csdfg_argv[i].value = -1
return fast_call_result

monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call)
monkeypatch.setattr(
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram,
"fast_call",
mocked_fast_call,
)

# Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object
mock_construct_args = unittest.mock.MagicMock()
dace_construct_args = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args
# Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object
mock_construct_arguments = unittest.mock.MagicMock()
gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments

def mocked_construct_args(self, *args, **kwargs):
mock_construct_args.__call__(*args, **kwargs)
return dace_construct_args(self, *args, **kwargs)
def mocked_construct_arguments(self, *args, **kwargs):
mock_construct_arguments.__call__(*args, **kwargs)
return gt4py_construct_arguments(self, *args, **kwargs)

monkeypatch.setattr(
dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args
gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram,
"construct_arguments",
mocked_construct_arguments,
)

return mock_fast_call, mock_construct_args
return mock_fast_call, mock_construct_arguments


def test_dace_fastcall(cartesian_case, monkeypatch):
Expand Down Expand Up @@ -139,11 +147,11 @@ def testee(
unused_field = cases.allocate(cartesian_case, testee, "unused_field")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

mock_fast_call, mock_construct_args = make_mocks(monkeypatch)
mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch)

# Reset mock objects and run/verify GT4Py program
def verify_testee():
mock_construct_args.reset_mock()
mock_construct_arguments.reset_mock()
mock_fast_call.reset_mock()
cases.verify(
cartesian_case,
Expand All @@ -159,24 +167,24 @@ def verify_testee():

# On first run, the SDFG arguments will have to be constructed
verify_testee()
mock_construct_args.assert_called_once()
mock_construct_arguments.assert_called_once()

# Now modify the scalar arguments, used and unused ones: reuse previous SDFG arguments
for i in range(4):
a_offset[i] += 1
verify_testee()
mock_construct_args.assert_not_called()
mock_construct_arguments.assert_not_called()

# Modify content of current buffer: reuse previous SDFG arguments
for buff in (a, unused_field):
buff[0] += 1
verify_testee()
mock_construct_args.assert_not_called()
mock_construct_arguments.assert_not_called()

# Pass a new buffer, fastcall API should still be used
a = cases.allocate(cartesian_case, testee, "a")()
verify_testee()
mock_construct_args.assert_not_called()
mock_construct_arguments.assert_not_called()


def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch):
Expand All @@ -191,11 +199,11 @@ def testee(a: cases.VField) -> cases.EField:
(a,), kwfields = cases.get_default_data(unstructured_case, testee)
numpy_ref = lambda a: a[connectivity_E2V[:, 0]]

mock_fast_call, mock_construct_args = make_mocks(monkeypatch)
mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch)

# Reset mock objects and run/verify GT4Py program
def verify_testee():
mock_construct_args.reset_mock()
mock_construct_arguments.reset_mock()
mock_fast_call.reset_mock()
cases.verify(
unstructured_case,
Expand All @@ -209,4 +217,4 @@ def verify_testee():

verify_testee()
verify_testee()
mock_construct_args.assert_not_called()
mock_construct_arguments.assert_not_called()
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.