diff --git a/pyproject.toml b/pyproject.toml index 3c4d11cdfb..2741d35bbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dace-cartesian = [ 'dace>=1.0.2,<2' # renfined in [tool.uv.sources] ] dace-next = [ - 'dace==2025.10.30' # refined in [tool.uv.sources] + 'dace==2025.11.04' # refined in [tool.uv.sources] and I like that book. ] dev = [ {include-group = 'build'}, @@ -240,6 +240,11 @@ disallow_incomplete_defs = false disallow_untyped_defs = false module = 'gt4py.next.iterator.*' +[[tool.mypy.overrides]] +disallow_incomplete_defs = false +disallow_untyped_defs = false +module = 'gt4py.next.program_processors.runners.dace_iterator.*' + [[tool.mypy.overrides]] ignore_errors = true module = 'gt4py.next.iterator.runtime' @@ -448,7 +453,7 @@ url = 'https://test.pypi.org/simple' atlas4py = {index = "test.pypi"} dace = [ {git = "https://github.com/GridTools/dace", branch = "romanc/stree-roundtrip", group = "dace-cartesian"}, - {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_10_30", group = "dace-next"} + {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_04", group = "dace-next"} ] # -- versioningit -- diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 3379ee8be3..9ffafeab0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -10,6 +10,7 @@ import dataclasses import os +import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any @@ -36,6 +37,14 @@ class CompiledDaceProgram(stages.CompiledProgram): None, ] + # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` + # means that it has not been initialized, i.e. no call was ever performed. + # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. + # - csdfg_init_argv: Arguments used for initialization; used only the first time and + # never updated. + csdfg_argv: MutableSequence[Any] | None + csdfg_init_argv: Sequence[Any] | None + def __init__( self, program: dace.CompiledSDFG, @@ -56,12 +65,35 @@ def __init__( # For debug purpose, we set a unique module name on the compiled function. self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) - def __call__(self, **kwargs: Any) -> None: - result = self.sdfg_program(**kwargs) - assert result is None + # Since the SDFG hasn't been called yet. + self.csdfg_argv = None + self.csdfg_init_argv = None + + def construct_arguments(self, **kwargs: Any) -> None: + """ + This function will process the arguments and store the processed values in `self.csdfg_args`, + to call them use `self.fast_call()`. + """ + with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): + csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) + # Note we only care about `csdfg_argv` (normal call), since we have to update it, + # we ensure that it is a `list`. + self.csdfg_argv = [*csdfg_argv] + self.csdfg_init_argv = csdfg_init_argv def fast_call(self) -> None: - result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs) + """Perform a call to the compiled SDFG using the processed arguments, see `self.prepare_arguments()`.""" + assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( + "Argument vector was not set properly." + ) + self.sdfg_program.fast_call(self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=False) + + def __call__(self, **kwargs: Any) -> None: + warnings.warn( + "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", + stacklevel=1, + ) + result = self.sdfg_program(**kwargs) assert result is None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index c11f691826..996ba7a095 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -11,7 +11,6 @@ import functools from typing import Any, Sequence -import dace import numpy as np from gt4py._core import definitions as core_defs @@ -45,15 +44,17 @@ def decorated_program( args = (*args, out) try: - # Initialization of `_lastargs` was done by the `CompiledSDFG` object, - # so we just update it with the current call arguments. - update_sdfg_call_args(args, fun.sdfg_program._lastargs[0]) - fun.fast_call() - except IndexError: - # First call, the SDFG is not initialized, so forward the call to `CompiledSDFG` - # to properly initialize it. Later calls to this SDFG will be handled through - # the `fast_call()` API. - assert len(fun.sdfg_program._lastargs) == 0 # `fun.sdfg_program._lastargs` is empty + # Not the first call. + # We will only update the argument vector for the normal call. + # NOTE: If this is the first time then we will generate an exception because + # `fun.csdfg_args` is `None` + # TODO(phimuell, edopao): Think about refactor the code such that the update + # of the argument vector is a Method of the `CompiledDaceProgram`. + update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call. + + except TypeError: + # First call. Construct the initial argument vector of the `CompiledDaceProgram`. + assert fun.csdfg_argv is None and fun.csdfg_init_argv is None flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(args) this_call_args = sdfg_callable.get_sdfg_args( fun.sdfg_program.sdfg, @@ -65,8 +66,10 @@ def decorated_program( gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL, gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg, } - with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): - fun(**this_call_args) + fun.construct_arguments(**this_call_args) + + # Perform the call to the SDFG. + fun.fast_call() if collect_time: metric_source = metrics.get_current_source() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index a2ac5af494..a204886690 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -81,36 +81,44 @@ def unstructured_case(request, exec_alloc_descriptor, mesh_descriptor): def make_mocks(monkeypatch): # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object mock_fast_call = unittest.mock.MagicMock() - dace_fast_call = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call + gt4py_fast_call = ( + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.fast_call + ) - def mocked_fast_call(self, *args, **kwargs): - mock_fast_call.__call__(*args, **kwargs) - fast_call_result = dace_fast_call(self, *args, **kwargs) + def mocked_fast_call(self): + mock_fast_call.__call__() + fast_call_result = gt4py_fast_call(self) # invalidate all scalar positional arguments to ensure that they are properly set # next time the SDFG is executed before fast_call - positional_args = set(self.sdfg.arg_names) - sdfg_arglist = self.sdfg.arglist() + positional_args = set(self.sdfg_program.sdfg.arg_names) + sdfg_arglist = self.sdfg_program.sdfg.arglist() for i, (arg_name, arg_type) in enumerate(sdfg_arglist.items()): if arg_name in positional_args and isinstance(arg_type, dace.data.Scalar): - assert isinstance(self._lastargs[0][i], ctypes.c_int) - self._lastargs[0][i].value = -1 + assert isinstance(self.csdfg_argv[i], ctypes.c_int) + self.csdfg_argv[i].value = -1 return fast_call_result - monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call) + monkeypatch.setattr( + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + "fast_call", + mocked_fast_call, + ) - # Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object - mock_construct_args = unittest.mock.MagicMock() - dace_construct_args = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args + # Wrap `compiled_sdfg.CompiledSDFG.construct_arguments` with mock object + mock_construct_arguments = unittest.mock.MagicMock() + gt4py_construct_arguments = gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram.construct_arguments - def mocked_construct_args(self, *args, **kwargs): - mock_construct_args.__call__(*args, **kwargs) - return dace_construct_args(self, *args, **kwargs) + def mocked_construct_arguments(self, *args, **kwargs): + mock_construct_arguments.__call__(*args, **kwargs) + return gt4py_construct_arguments(self, *args, **kwargs) monkeypatch.setattr( - dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args + gtx.program_processors.runners.dace.workflow.compilation.CompiledDaceProgram, + "construct_arguments", + mocked_construct_arguments, ) - return mock_fast_call, mock_construct_args + return mock_fast_call, mock_construct_arguments def test_dace_fastcall(cartesian_case, monkeypatch): @@ -139,11 +147,11 @@ def testee( unused_field = cases.allocate(cartesian_case, testee, "unused_field")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - mock_fast_call, mock_construct_args = make_mocks(monkeypatch) + mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch) # Reset mock objects and run/verify GT4Py program def verify_testee(): - mock_construct_args.reset_mock() + mock_construct_arguments.reset_mock() mock_fast_call.reset_mock() cases.verify( cartesian_case, @@ -159,24 +167,24 @@ def verify_testee(): # On first run, the SDFG arguments will have to be constructed verify_testee() - mock_construct_args.assert_called_once() + mock_construct_arguments.assert_called_once() # Now modify the scalar arguments, used and unused ones: reuse previous SDFG arguments for i in range(4): a_offset[i] += 1 verify_testee() - mock_construct_args.assert_not_called() + mock_construct_arguments.assert_not_called() # Modify content of current buffer: reuse previous SDFG arguments for buff in (a, unused_field): buff[0] += 1 verify_testee() - mock_construct_args.assert_not_called() + mock_construct_arguments.assert_not_called() # Pass a new buffer, fastcall API should still be used a = cases.allocate(cartesian_case, testee, "a")() verify_testee() - mock_construct_args.assert_not_called() + mock_construct_arguments.assert_not_called() def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): @@ -191,11 +199,11 @@ def testee(a: cases.VField) -> cases.EField: (a,), kwfields = cases.get_default_data(unstructured_case, testee) numpy_ref = lambda a: a[connectivity_E2V[:, 0]] - mock_fast_call, mock_construct_args = make_mocks(monkeypatch) + mock_fast_call, mock_construct_arguments = make_mocks(monkeypatch) # Reset mock objects and run/verify GT4Py program def verify_testee(): - mock_construct_args.reset_mock() + mock_construct_arguments.reset_mock() mock_fast_call.reset_mock() cases.verify( unstructured_case, @@ -209,4 +217,4 @@ def verify_testee(): verify_testee() verify_testee() - mock_construct_args.assert_not_called() + mock_construct_arguments.assert_not_called() diff --git a/uv.lock b/uv.lock index 1a387af01a..fcde25b0cb 100644 --- a/uv.lock +++ b/uv.lock @@ -947,8 +947,8 @@ dependencies = [ [[package]] name = "dace" -version = "2025.10.30" -source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_10_30#6ff7bf16174434ba7d842f2a176789829440c131" } +version = "25.11.4" +source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_04#4f9c76f91de70b99a3ed483c349af30c25648ec1" } resolution-markers = [ "python_full_version >= '3.13'", "python_full_version == '3.12.*'", @@ -1429,7 +1429,7 @@ dace-cartesian = [ { name = "dace", version = "1.0.2", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-roundtrip#1033dfcf9d118856d82c6ee8d6f6cfacec662335" } }, ] dace-next = [ - { name = "dace", version = "2025.10.30", source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_10_30#6ff7bf16174434ba7d842f2a176789829440c131" } }, + { name = "dace", version = "25.11.4", source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_04#4f9c76f91de70b99a3ed483c349af30c25648ec1" } }, ] dev = [ { name = "atlas4py" }, @@ -1572,7 +1572,7 @@ build = [ { name = "wheel", specifier = ">=0.33.6" }, ] dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-roundtrip" }] -dace-next = [{ name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_10_30" }] +dace-next = [{ name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_04" }] dev = [ { name = "atlas4py", specifier = ">=0.41", index = "https://test.pypi.org/simple" }, { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" },