Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def _array_interface_ptr(array: Any, storage: dtypes.StorageType) -> int:


class CompiledSDFG(object):
""" A compiled SDFG object that can be called through Python. """
""" A compiled SDFG object that can be called through Python.

Todo:
Scalar return values are not handled properly, this is a code gen issue.
"""

def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None):
from dace.sdfg import SDFG
Expand Down Expand Up @@ -675,6 +679,7 @@ def _initialize_return_values(self, kwargs):

def _convert_return_values(self):
# Return the values as they would be from a Python function
# NOTE: Currently it is not possible to return a scalar value, see `tests/sdfg/scalar_return.py`
if self._return_arrays is None or len(self._return_arrays) == 0:
return None
elif len(self._return_arrays) == 1:
Expand Down
9 changes: 7 additions & 2 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import os
from typing import TYPE_CHECKING, Dict, List, Set
import warnings
from dace import dtypes, subsets
from dace import symbolic
from dace import dtypes, subsets, symbolic

if TYPE_CHECKING:
import dace
Expand Down Expand Up @@ -185,6 +184,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
on failure.
"""
# Avoid import loop
from dace import data as dt
from dace.codegen.targets import fpga
from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga

Expand Down Expand Up @@ -215,6 +215,11 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
'rather than using multiple references to the same one', sdfg, None)
references.add(id(desc))

# Because of how the code generator works Scalars can not be return values.
# TODO: Remove this limitation as the CompiledSDFG contains logic for that.
if isinstance(desc, dt.Scalar) and name.startswith("__return") and not desc.transient:
raise InvalidSDFGError(f'Can not use scalar "{name}" as return value.', sdfg, None)

# Validate array names
if name is not None and not dtypes.validate_name(name):
raise InvalidSDFGError("Invalid array name %s" % name, sdfg, None)
Expand Down
116 changes: 116 additions & 0 deletions tests/sdfg/scalar_return.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
import pytest
from typing import Tuple

from dace.sdfg.validation import InvalidSDFGError

def single_retval_sdfg() -> dace.SDFG:

@dace.program(auto_optimize=False, recreate_sdfg=True)
def testee(
A: dace.float64[20],
) -> dace.float64:
return dace.float64(A[3])

return testee.to_sdfg(validate=False)


def tuple_retval_sdfg() -> dace.SDFG:

# This can not be used, as the frontend promotes the two scalars inside the tuple
# to arrays of length one.
#@dace.program(auto_optimize=False, recreate_sdfg=True)
#def testee(
# a: dace.float64,
# b: dace.float64,
#) -> Tuple[dace.float64, dace.float64]:
# return a + b, a - b

sdfg = dace.SDFG("scalar_tuple_return")
state = sdfg.add_state("init", is_start_block=True)
anames = ["a", "b"]
sdfg.add_scalar(anames[0], dace.float64)
sdfg.add_scalar(anames[1], dace.float64)
sdfg.add_scalar("__return_0", dace.float64)
sdfg.add_scalar("__return_1", dace.float64)
acnodes = {aname: state.add_access(aname) for aname in anames}

for iout, ops in enumerate(["+", "-"]):
tskl = state.add_tasklet(
"work",
inputs={"__in0", "__in1"},
outputs={"__out"},
code=f"__out0 = __in0 {ops} __in1",
)
for isrc, src in enumerate(anames):
state.add_edge(
acnodes[src],
None,
tskl,
f"__in{isrc}",
dace.Memlet.simple(src, "0")
)
state.add_edge(
tskl,
"__out",
state.add_write(f"__return_{iout}"),
None,
dace.Memlet.simple(f"__return_{iout}", "0"),
)
return sdfg


@pytest.mark.skip("Scalar return is not implement.")
def test_scalar_return():

sdfg = single_retval_sdfg()
assert isinstance(sdfg.arrays["__return"], dace.data.Scalar)

sdfg.validate()
A = np.random.rand(20)
res = sdfg(A=A)
assert isinstance(res, np.float64)
assert A[3] == res


@pytest.mark.skip("Scalar return is not implement.")
def test_scalar_return_tuple():

sdfg = tuple_retval_sdfg()
assert all(
isinstance(desc, dace.data.Scalar)
for name, desc in sdfg.arrays.items()
if name.startswith("__return")
)

sdfg.validate()
a, b = np.float64(23.9), np.float64(10.0)
res1, res2 = sdfg(a=a, b=b)
assert all(isinstance(res, np.float64) for res in (ret1, ret2))
assert np.isclose(res1 == (a + b))
assert np.isclose(res2 == (a - b))


def test_scalar_return_validation():
"""Test if the validation actually works.

Todo:
Remove this test after scalar return values are implemented and enable
the `test_scalar_return` and `test_scalar_return_tuple()` tests.
"""

sdfg = single_retval_sdfg()
with pytest.raises(
InvalidSDFGError,
match='Can not use scalar "__return" as return value.',
):
sdfg.validate()

sdfg = tuple_retval_sdfg()
with pytest.raises(
InvalidSDFGError,
match='Can not use scalar "__return_(0|1)" as return value.',
):
sdfg.validate()