-
Notifications
You must be signed in to change notification settings - Fork 54
feat[dace][next] More Roboust Calling #2353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78b8dce
e6133ce
366293b
17be22e
c228fc8
dab085b
5280c3a
4ab5227
efb562a
82fae60
b945c28
969405a
02e7375
94332f0
cec9e08
45ecdb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
|
|
||
| import dataclasses | ||
| import os | ||
| import warnings | ||
| from collections.abc import Callable, MutableSequence, Sequence | ||
| from typing import Any | ||
|
|
||
|
|
@@ -36,6 +37,14 @@ class CompiledDaceProgram(stages.CompiledProgram): | |
| None, | ||
| ] | ||
|
|
||
| # Processed argument vectors that are passed to `CompiledSDFG.fast_call()`. `None` | ||
| # means that it has not been initialized, i.e. no call was ever performed. | ||
| # - csdfg_argv: Arguments used for calling the actual compiled SDFG, will be updated. | ||
| # - csdfg_init_argv: Arguments used for initialization; used only the first time and | ||
| # never updated. | ||
| csdfg_argv: MutableSequence[Any] | None | ||
| csdfg_init_argv: Sequence[Any] | None | ||
|
|
||
| def __init__( | ||
| self, | ||
| program: dace.CompiledSDFG, | ||
|
|
@@ -56,12 +65,35 @@ def __init__( | |
| # For debug purpose, we set a unique module name on the compiled function. | ||
| self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) | ||
|
|
||
| def __call__(self, **kwargs: Any) -> None: | ||
| result = self.sdfg_program(**kwargs) | ||
| assert result is None | ||
| # Since the SDFG hasn't been called yet. | ||
| self.csdfg_argv = None | ||
| self.csdfg_init_argv = None | ||
|
|
||
| def construct_arguments(self, **kwargs: Any) -> None: | ||
| """ | ||
| This function will process the arguments and store the processed values in `self.csdfg_args`, | ||
| to call them use `self.fast_call()`. | ||
| """ | ||
| with dace.config.set_temporary("compiler", "allow_view_arguments", value=True): | ||
| csdfg_argv, csdfg_init_argv = self.sdfg_program.construct_arguments(**kwargs) | ||
| # Note we only care about `csdfg_argv` (normal call), since we have to update it, | ||
| # we ensure that it is a `list`. | ||
| self.csdfg_argv = [*csdfg_argv] | ||
| self.csdfg_init_argv = csdfg_init_argv | ||
|
|
||
| def fast_call(self) -> None: | ||
| result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs) | ||
| """Perform a call to the compiled SDFG using the processed arguments, see `self.prepare_arguments()`.""" | ||
philip-paul-mueller marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert self.csdfg_argv is not None and self.csdfg_init_argv is not None, ( | ||
| "Argument vector was not set properly." | ||
| ) | ||
| self.sdfg_program.fast_call(self.csdfg_argv, self.csdfg_init_argv, do_gpu_check=False) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we really sure we want to pass
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @philip-paul-mueller , any thoughts on this comment?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should have added something here. |
||
|
|
||
| def __call__(self, **kwargs: Any) -> None: | ||
| warnings.warn( | ||
| "Called an SDFG through the standard DaCe interface is not recommended, use `fast_call()` instead.", | ||
| stacklevel=1, | ||
| ) | ||
| result = self.sdfg_program(**kwargs) | ||
| assert result is None | ||
|
|
||
|
|
||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remember to remove these lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@philip-paul-mueller did you actually remove these lines and the "and I like that book" comment of the
dace-nextversion before merging?