Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ class CompiledDaceProgram(stages.CompiledProgram):
]

# Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None`
# means that it has not been initialized, i.e. no call was ever performed. The
# first argument vector is used in normal calls and will get updated, to avoid
# full argument reprocessing. The second argument vector is only needed for
# initialization and will never be updated.
csdfg_args: tuple[list[Any], Sequence[Any]] | None
# means that it has not been initialized, i.e. no call was ever performed.
# - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated.
# - csdfg_init_argv: Arguments used for initialization; used only the first time and
# never updated.
csdfg_argv: MutableSequence[Any] | None
csdfg_init_argv: Sequence[Any] | None

def __init__(
self,
Expand All @@ -64,26 +65,28 @@ def __init__(
# For debug purpose, we set a unique module name on the compiled function.
self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder)

# Since the sdfg program hasn't been called yet.
self.csdfg_args = None
# Since the SDFG hasn't been called yet.
self.csdfg_argv = None
self.csdfg_init_argv = None

def prepare_arguments(self, **kwargs: Any) -> None:
def construct_arguments(self, **kwargs: Any) -> None:
"""
This function will process the arguments and store the processed values in `self.csdfg_args`,
to call them use `self.fast_call()`.
"""
with dace.config.set_temporary("compiler", "allow_view_arguments", value=True):
csdfg_agrv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs)
# Note we only care about the first argument vector, that is used in normal call.
# Since we update it, we ensure that it is a `list`.
self.csdfg_args = ([*csdfg_agrv], csdfg_init_argv)
csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs)
# Note we only care about `csdfg_argv` (normal call), since we have to update it,
# we ensure that it is a `list`.
self.csdfg_argv = [*csdfg_argv]
self.csdfg_init_argv = csdfg_init_argv

def fast_call(self) -> None:
"""Perform a call to the compiled SDFG using the processed arguments, see `self.prepare_arguments()`."""
assert isinstance(self.csdfg_args, tuple) and len(self.csdfg_args) == 2, (
assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, (
"Argument vector was not set properly."
)
self.sdfg_program.fast_call(*self.csdfg_args)
self.sdfg_program.fast_call(self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we really sure we want to pass do_gpu_check=False? Do you know what is its overhead? Maybe we could pass do_gpu_check=__debug__?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@philip-paul-mueller , any thoughts on this comment?

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2365

We should have added something here.


def __call__(self, **kwargs: Any) -> None:
warnings.warn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ def decorated_program(

try:
# Not the first call.
# We will only update the first argument vector (arguments for the normal call).
# We will only update the argument vector for the normal call.
# NOTE: If this is the first time then we will generate an exception because
# `fun.csdfg_args` is `None`
update_sdfg_call_args(args, fun.csdfg_args[0]) # type: ignore[index] # Will error out in first call.
# TODO(phimuell, edopao): Think about refactor the code such that the update
# of the argument vector is a Method of the `CompiledDaceProgram`.
update_sdfg_call_args(args, fun.csdfg_argv) # type: ignore[arg-type] # Will error out in first call.

except TypeError:
# First call. Construct the initial argument vector of the `CompiledDaceProgram`.
assert fun.csdfg_args is None
assert fun.csdfg_argv is None and fun.csdfg_init_argv is None
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(args)
this_call_args = sdfg_callable.get_sdfg_args(
fun.sdfg_program.sdfg,
Expand All @@ -64,7 +66,7 @@ def decorated_program(
gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL: config.COLLECT_METRICS_LEVEL,
gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME: collect_time_arg,
}
fun.prepare_arguments(**this_call_args)
fun.construct_arguments(**this_call_args)

# Perform the call to the SDFG.
fun.fast_call()
Expand Down