36 commits
96f296f
Add mempcy and memset library nodes for expansion
ThrudPrimrose Sep 9, 2025
69cae77
Merge branch 'main' into memcpy_map_to_libnode_pass
ThrudPrimrose Sep 18, 2025
58d3d81
Add tests
ThrudPrimrose Sep 18, 2025
716f926
Fix lbi
ThrudPrimrose Sep 18, 2025
b998a3f
Fix
ThrudPrimrose Sep 18, 2025
5e7f13b
Fix yield edge
ThrudPrimrose Sep 18, 2025
f479f34
Fix yield edge for states
ThrudPrimrose Sep 18, 2025
8a4e60a
Merge branch 'yield-edge-fix' into memcpy_map_to_libnode_pass
ThrudPrimrose Sep 18, 2025
ea91324
F
ThrudPrimrose Sep 18, 2025
7c244eb
Fix
ThrudPrimrose Sep 18, 2025
6c2f2eb
Fix
ThrudPrimrose Sep 18, 2025
6185d8b
Fix things
ThrudPrimrose Sep 18, 2025
b7c95f7
Finalize pass cleanup
ThrudPrimrose Sep 19, 2025
936dd26
Fix
ThrudPrimrose Sep 22, 2025
b56da93
Refactor
ThrudPrimrose Sep 22, 2025
c5659f9
Refactor
ThrudPrimrose Sep 22, 2025
6c2f6ad
Rm unnecessary test
ThrudPrimrose Sep 22, 2025
0d99997
Run refactor
ThrudPrimrose Sep 22, 2025
cc4068c
Naming fixes, copyright
ThrudPrimrose Sep 22, 2025
fc8cb69
Fix
ThrudPrimrose Sep 23, 2025
864525c
Array dimension utility extension
ThrudPrimrose Nov 2, 2025
d68f46b
is packed and is contigous functions
ThrudPrimrose Nov 2, 2025
5089fa6
Run precommit
ThrudPrimrose Nov 2, 2025
b60a74c
Copyright and documentation
ThrudPrimrose Nov 2, 2025
be9b6a8
Update
ThrudPrimrose Nov 2, 2025
5e58dae
Merge branch 'is_packed_storage_utility' into memcpy_map_to_libnode_pass
ThrudPrimrose Nov 2, 2025
d2912d6
Update
ThrudPrimrose Nov 2, 2025
cf26a0f
Update, improve support for dynamic in connectors
ThrudPrimrose Nov 2, 2025
6ccb8b3
Add environment
ThrudPrimrose Nov 2, 2025
40c45f8
Asignment map to kernel fixes
ThrudPrimrose Nov 9, 2025
82a71b6
Fix minor issue in tests
ThrudPrimrose Nov 9, 2025
4cf3a8c
Attempt fixes
ThrudPrimrose Nov 10, 2025
291e621
Fix stuff
ThrudPrimrose Nov 10, 2025
bc3a8b9
Disable autoopt for the assignment / memcpy map to libnode stuff
ThrudPrimrose Nov 10, 2025
fe01c11
Fixes
ThrudPrimrose Nov 10, 2025
f8b6f13
things
ThrudPrimrose Nov 10, 2025
31 changes: 30 additions & 1 deletion dace/data.py
@@ -1,4 +1,4 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
import aenum
import copy as cp
import ctypes
@@ -1680,6 +1680,35 @@ def set_shape(
        self._set_shape_dependent_properties(new_shape, strides, total_size, offset)
        self.validate()

    def _get_packed_fortran_strides(self) -> Tuple[int, ...]:
        """Compute the strides the array would have if packed in Fortran order (column-major)."""
        accum = 1
        strides = []
        for dim in self.shape:
            strides.append(accum)
            accum *= dim
        return tuple(strides)

    def _get_packed_c_strides(self) -> Tuple[int, ...]:
        """Compute the strides the array would have if packed in C order (row-major)."""
        accum = 1
        strides = []
        # Same as Fortran order, but iterating over the shape in reverse
        for dim in reversed(self.shape):
            strides.append(accum)
            accum *= dim
        return tuple(reversed(strides))

    def is_packed_fortran_strides(self) -> bool:
        """Return True if the strides match a Fortran-contiguous (column-major) packed layout."""
        return self._get_packed_fortran_strides() == tuple(self.strides)

    def is_packed_c_strides(self) -> bool:
        """Return True if the strides match a C-contiguous (row-major) packed layout."""
        return self._get_packed_c_strides() == tuple(self.strides)


@make_properties
class Stream(Data):
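For orientation, a minimal sketch of what the new stride checks report (illustrative, not part of the diff; assumes dace.data.Array accepts explicit strides):

import dace

# A 3x4 array with row-major strides: packed C strides are (4, 1),
# while packed Fortran strides would be (1, 3).
desc = dace.data.Array(dace.float64, shape=(3, 4), strides=(4, 1))
assert desc.is_packed_c_strides()
assert not desc.is_packed_fortran_strides()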
1 change: 1 addition & 0 deletions dace/libraries/standard/environments/__init__.py
@@ -1,3 +1,4 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from .cuda import CUDA
from .hptt import HPTT
from .cpu import CPU
21 changes: 21 additions & 0 deletions dace/libraries/standard/environments/cpu.py
@@ -0,0 +1,21 @@
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library


@dace.library.environment
class CPU:

    cmake_minimum_version = None
    cmake_packages = []
    cmake_variables = {}
    cmake_includes = []
    cmake_libraries = []
    cmake_compile_flags = []
    cmake_link_flags = []
    cmake_files = []

    headers = []
    state_fields = []
    init_code = ""
    finalize_code = ""
    dependencies = []
2 changes: 1 addition & 1 deletion dace/libraries/standard/environments/cuda.py
@@ -14,7 +14,7 @@ class CUDA:
    cmake_link_flags = []
    cmake_files = []

    headers = []
    headers = {'frame': ["cuda_runtime.h"]}
    state_fields = []
    init_code = ""
    finalize_code = ""
247 changes: 247 additions & 0 deletions dace/libraries/standard/nodes/copy_node.py
@@ -0,0 +1,247 @@
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace import library, nodes
from dace.transformation.transformation import ExpandTransformation
from .. import environments
from functools import reduce
import operator
from dace.codegen.common import sym2cpp
import copy


# Compute the collapsed shape and strides, removing singleton dimensions (length == 1)
def collapse_shape_and_strides(subset, strides):
    collapsed_shape = []
    collapsed_strides = []
    for (b, e, s), stride in zip(subset, strides):
        length = (e + 1 - b) // s
        if length != 1:
            collapsed_shape.append(length)
            collapsed_strides.append(stride)
    return collapsed_shape, collapsed_strides
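
# Example (illustrative, not in the original diff): subset [(0, 9, 1), (0, 0, 1)]
# with strides (10, 1) collapses to shape [10] and strides [10]; the length-1
# second dimension is dropped.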


def add_dynamic_inputs(dynamic_inputs, sdfg: dace.SDFG, in_subset: dace.subsets.Range, state: dace.SDFGState):
    # Add dynamic inputs as scalar input data; clashing SDFG symbols are renamed to "sym_<name>"
    pre_assignments = dict()
    map_lengths = [dace.symbolic.SymExpr((e + 1 - b) // s) for (b, e, s) in in_subset]

    for dynamic_input_name, datadesc in dynamic_inputs.items():
        if dynamic_input_name in sdfg.arrays:
            continue

        if dynamic_input_name in sdfg.symbols:
            sdfg.replace(str(dynamic_input_name), "sym_" + str(dynamic_input_name))
            ndesc = copy.deepcopy(datadesc)
            ndesc.transient = False
            sdfg.add_datadesc(dynamic_input_name, ndesc)
            # Should be a scalar (enforced by CopyLibraryNode.validate)
            if isinstance(ndesc, dace.data.Scalar):
                pre_assignments["sym_" + dynamic_input_name] = f"{dynamic_input_name}"
            else:
                assert ndesc.shape == (1, ) or ndesc.shape == [1]
                pre_assignments["sym_" + dynamic_input_name] = f"{dynamic_input_name}[0]"

            new_map_lengths = []
            for ml in map_lengths:
                nml = ml.subs({str(dynamic_input_name): "sym_" + str(dynamic_input_name)})
                new_map_lengths.append(nml)
            map_lengths = new_map_lengths

    if pre_assignments:
        # Add a state for the assignments at the beginning
        sdfg.add_state_before(state=state, label="pre_assign", is_start_block=True, assignments=pre_assignments)

    return map_lengths
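
# Example (illustrative): a dynamic input `N` that also appears as a symbol in
# the map lengths is renamed to `sym_N` inside the nested SDFG, `N` is added as
# a non-transient scalar input, and a "pre_assign" start state sets sym_N = N.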


@library.expansion
class ExpandPure(ExpandTransformation):
    environments = []

    @staticmethod
    def expansion(node, parent_state, parent_sdfg):
        inp_name, inp, in_subset, out_name, out, out_subset, dynamic_inputs = node.validate(parent_sdfg, parent_state)

        in_shape_collapsed, in_strides_collapsed = collapse_shape_and_strides(in_subset, inp.strides)
        out_shape_collapsed, out_strides_collapsed = collapse_shape_and_strides(out_subset, out.strides)

        sdfg = dace.SDFG(f"{node.label}_sdfg")
        sdfg.add_array(inp_name, in_shape_collapsed, inp.dtype, inp.storage, strides=in_strides_collapsed)
        sdfg.add_array(out_name, out_shape_collapsed, out.dtype, out.storage, strides=out_strides_collapsed)

        state = sdfg.add_state(f"{node.label}_state", is_start_block=True)

        map_lengths = add_dynamic_inputs(dynamic_inputs, sdfg, in_subset, state)

        sdfg.schedule = dace.dtypes.ScheduleType.Default

        map_params = [f"__i{i}" for i in range(len(map_lengths))]
        map_rng = {i: f"0:{s}" for i, s in zip(map_params, map_lengths)}
        in_access_expr = ','.join(map_params)
        out_access_expr = ','.join(map_params)
        inputs = {"_memcpy_inp": dace.memlet.Memlet(f"{inp_name}[{in_access_expr}]")}
        outputs = {"_memcpy_out": dace.memlet.Memlet(f"{out_name}[{out_access_expr}]")}
        code = "_memcpy_out = _memcpy_inp"
        if inp.storage == dace.dtypes.StorageType.GPU_Global:
            schedule = dace.dtypes.ScheduleType.GPU_Device
        else:
            schedule = dace.dtypes.ScheduleType.Default
        state.add_mapped_tasklet(f"{node.label}_tasklet",
                                 map_rng,
                                 inputs,
                                 code,
                                 outputs,
                                 schedule=schedule,
                                 external_edges=True)

        return sdfg


@library.expansion
class ExpandCUDA(ExpandTransformation):
    environments = [environments.CUDA]

    @staticmethod
    def expansion(node, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG):
        inp_name, inp, in_subset, out_name, out, out_subset, dynamic_inputs = node.validate(parent_sdfg, parent_state)

        map_lengths = [(e + 1 - b) // s for (b, e, s) in in_subset]
        cp_size = reduce(operator.mul, map_lengths, 1)
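        # Assumption: a single flat memcpy of cp_size elements is only correct when
        # source and destination are packed/contiguous; presumably the pass that
        # introduces this node checks this via Array.is_packed_c_strides /
        # Array.is_packed_fortran_strides.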

        in_shape_collapsed, in_strides_collapsed = collapse_shape_and_strides(in_subset, inp.strides)
        out_shape_collapsed, out_strides_collapsed = collapse_shape_and_strides(out_subset, out.strides)

        sdfg = dace.SDFG(f"{node.label}_sdfg")
        sdfg.add_array(inp_name, in_shape_collapsed, inp.dtype, inp.storage, strides=in_strides_collapsed)
        sdfg.add_array(out_name, out_shape_collapsed, out.dtype, out.storage, strides=out_strides_collapsed)

        state = sdfg.add_state(f"{node.label}_state")

        # Add dynamic inputs (this needs the state to exist, so it runs after add_state)
        map_lengths = add_dynamic_inputs(dynamic_inputs, sdfg, in_subset, state)

        in_access = state.add_access(inp_name)
        out_access = state.add_access(out_name)
        tasklet = state.add_tasklet(
            name="memcpy_tasklet",
            inputs={"_memcpy_in"},
            outputs={"_memcpy_out"},
            code=f"cudaMemcpyAsync(_memcpy_out, _memcpy_in, {sym2cpp(cp_size)} * sizeof({inp.dtype.ctype}), cudaMemcpyDeviceToDevice, __dace_current_stream);",
            language=dace.Language.CPP,
            code_global="#include <cuda_runtime.h>\n")

        tasklet.schedule = dace.dtypes.ScheduleType.GPU_Device

        state.add_edge(
            in_access, None, tasklet, "_memcpy_in",
            dace.memlet.Memlet(data=inp_name, subset=dace.subsets.Range([(0, e - 1, 1) for e in map_lengths])))
        state.add_edge(
            tasklet, "_memcpy_out", out_access, None,
            dace.memlet.Memlet(data=out_name, subset=dace.subsets.Range([(0, e - 1, 1) for e in map_lengths])))

        return sdfg


@library.expansion
class ExpandCPU(ExpandTransformation):
    environments = [environments.CPU]

    @staticmethod
    def expansion(node, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG):
        inp_name, inp, in_subset, out_name, out, out_subset, dynamic_inputs = node.validate(parent_sdfg, parent_state)
        map_lengths = [(e + 1 - b) // s for (b, e, s) in in_subset]
        cp_size = reduce(operator.mul, map_lengths, 1)

        in_shape_collapsed, in_strides_collapsed = collapse_shape_and_strides(in_subset, inp.strides)
        out_shape_collapsed, out_strides_collapsed = collapse_shape_and_strides(out_subset, out.strides)

        sdfg = dace.SDFG(f"{node.label}_sdfg")
        sdfg.add_array(inp_name, in_shape_collapsed, inp.dtype, inp.storage, strides=in_strides_collapsed)
        sdfg.add_array(out_name, out_shape_collapsed, out.dtype, out.storage, strides=out_strides_collapsed)

        state = sdfg.add_state(f"{node.label}_state")

        # Add dynamic inputs
        map_lengths = add_dynamic_inputs(dynamic_inputs, sdfg, in_subset, state)

        # Add CPU access nodes
        in_access = state.add_access(inp_name)
        out_access = state.add_access(out_name)

        # Tasklet performing a standard CPU memcpy
        tasklet = state.add_tasklet(
            name="memcpy_tasklet",
            inputs={"_memcpy_in"},
            outputs={"_memcpy_out"},
            code=f"memcpy(_memcpy_out, _memcpy_in, {sym2cpp(cp_size)} * sizeof({inp.dtype.ctype}));",
            language=dace.Language.CPP,
            code_global="#include <cstring>")

        # Connect input and output to the tasklet
        state.add_edge(
            in_access, None, tasklet, "_memcpy_in",
            dace.memlet.Memlet(data=inp_name, subset=dace.subsets.Range([(0, e - 1, 1) for e in map_lengths])))
        state.add_edge(
            tasklet, "_memcpy_out", out_access, None,
            dace.memlet.Memlet(data=out_name, subset=dace.subsets.Range([(0, e - 1, 1) for e in map_lengths])))

        return sdfg


@library.node
class CopyLibraryNode(nodes.LibraryNode):
    implementations = {"pure": ExpandPure, "CUDA": ExpandCUDA, "CPU": ExpandCPU}
    default_implementation = 'pure'

    def __init__(self, name, *args, **kwargs):
        super().__init__(name, *args, **kwargs)

    def validate(self, sdfg, state):
        """
        Validates the copy operation.

        :return: A tuple (inp_name, inp, in_subset, out_name, out, out_subset, dynamic_inputs)
                 for the data descriptors in the parent SDFG.
        """

        if len(state.out_edges(self)) != 1:
            raise ValueError("Expected exactly one output edge")

        oe = next(iter(state.out_edges(self)))
        out = sdfg.arrays[oe.data.data]
        out_subset = oe.data.subset
        out_name = oe.src_conn

        # Collect dynamic (non-data) input connectors
        dynamic_ies = {ie for ie in state.in_edges(self) if ie.dst_conn != "_in"}
        dynamic_inputs = dict()
        for ie in dynamic_ies:
            dataname = ie.data.data
            datadesc = state.sdfg.arrays[dataname]
            if not isinstance(datadesc, dace.data.Scalar):
                raise ValueError("Dynamic inputs (those not connected to `_in`) must all be scalars")
            dynamic_inputs[ie.dst_conn] = datadesc

        data_ies = {ie for ie in state.in_edges(self) if ie.dst_conn == "_in"}
        if len(data_ies) != 1:
            raise ValueError("Exactly one input edge must connect to the destination connector `_in`")
        ie = data_ies.pop()
        inp = sdfg.arrays[ie.data.data]

        in_subset = ie.data.subset
        inp_name = ie.dst_conn
        if not inp:
            raise ValueError("Missing the input tensor.")
        if not out:
            raise ValueError("Missing the output tensor.")

        if inp.dtype != out.dtype:
            raise ValueError("The datatype of the input and output tensors must match.")

        if inp.storage != out.storage:
            raise ValueError("The storage of the input and output tensors must match.")

        return inp_name, inp, in_subset, out_name, out, out_subset, dynamic_inputs
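
For reference, a minimal usage sketch of the new node (illustrative; the wiring follows validate() above, but the SDFG, array names, and sizes are hypothetical):

import dace
from dace.libraries.standard.nodes.copy_node import CopyLibraryNode

sdfg = dace.SDFG("copy_example")
sdfg.add_array("A", [64], dace.float64)
sdfg.add_array("B", [64], dace.float64)
state = sdfg.add_state()

# The data input must arrive at connector `_in`; any other input connector
# is treated as a dynamic (scalar) input.
node = CopyLibraryNode("copy")
node.add_in_connector("_in")
node.add_out_connector("_out")
state.add_node(node)
state.add_edge(state.add_read("A"), None, node, "_in", dace.Memlet("A[0:64]"))
state.add_edge(node, "_out", state.add_write("B"), None, dace.Memlet("B[0:64]"))

node.implementation = "CPU"  # or "pure" / "CUDA"
sdfg.expand_library_nodes()
sdfg.validate()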