Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing)
from dace.sdfg import SDFG, utils as sdutils
from dace.data import create_datadescriptor, Data
from dace.sdfg.dealias import dealias

try:
from typing import get_origin, get_args
Expand Down Expand Up @@ -286,6 +287,8 @@ def to_sdfg(self, *args, simplify=None, save=False, validate=False, use_cache=Fa
# Add to cache
self._cache.add(cachekey, sdfg, None)

dealias(sdfg)

return sdfg

def __sdfg__(self, *args, **kwargs) -> SDFG:
Expand Down
153 changes: 153 additions & 0 deletions dace/sdfg/dealias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import copy
from typing import Set, Dict

from dace.sdfg.graph import MultiConnectorEdge
from dace import SDFGState

FULL_VIEW_SUFFIX = "fullview"
SLICE_SUFFIX = "slice"


def _get_new_connector_name(edge: MultiConnectorEdge, repldict: Dict[str, str], other_repldict: Dict[str, str],
                            state: SDFGState, sdfg: dace.SDFG) -> str:
    """
    Pick a fresh, human-readable connector name for an edge based on its data access pattern.

    If the memlet covers the complete shape of the accessed array, the array's own name is
    reused (falling back to ``<name>_fullview_<i>`` when that would collide). Otherwise a
    unique ``<name>_slice_<i>`` name is generated.

    Args:
        edge: The edge carrying the memlet whose access pattern is inspected.
        repldict: Replacements already chosen for this connector direction; consulted to
            avoid duplicate names.
        other_repldict: Replacements chosen for the opposite direction (inputs vs. outputs).
            NOTE(review): only the partial-access branch consults this dictionary —
            presumably so that a full in- and out-access of the same array may share one
            connector name; confirm this asymmetry is intended.
        state: The SDFG state containing the edge.
        sdfg: The nested SDFG whose data-descriptor names must not clash with the result.

    Returns:
        str: The new connector name — the original array name (full access) or a unique
        slice name (partial access).
    """
    array_desc = state.sdfg.arrays[edge.data.data]

    # The access is "complete" when the memlet subset spans every dimension entirely.
    complete_range = dace.subsets.Range([(0, extent - 1, 1) for extent in array_desc.shape])
    accesses_whole_array = edge.data.subset == complete_range

    combined_repldict = repldict | other_repldict

    if accesses_whole_array:
        suffix_id = 1
        name = edge.data.data
        # Prefer the plain array name; disambiguate only on collision.
        while name in sdfg.arrays or name in repldict.values():
            name = f"{edge.data.data}_{FULL_VIEW_SUFFIX}_{suffix_id}"
            suffix_id += 1
        return name

    # Partial access: number the slices so multiple slices of one array stay distinct.
    suffix_id = 1
    name = f"{edge.data.data}_{SLICE_SUFFIX}_{suffix_id}"
    while name in combined_repldict.values() or name in sdfg.arrays:
        suffix_id += 1
        name = f"{edge.data.data}_{SLICE_SUFFIX}_{suffix_id}"
    return name


def dealias(sdfg: dace.SDFG):
    """
    Remove aliasing in nested SDFG connectors by replacing temporary names with meaningful ones.

    Temporary connector names (e.g., tmpxceX) are replaced with names that reflect the actual
    data being accessed (e.g. ``<data_name>_slice_<id>`` or ``<data_name>``), depending on
    applicability. The function handles two main cases:

    1. Full array access: ``A[::]`` -> connector gets named ``A``
    2. Partial array access: ``A[i:j]`` -> connector gets named ``A_slice_<id>``; ``<id>`` is
       needed in case multiple slices of the same array are used.

    The routine is applied recursively, so nested SDFGs contained inside nested SDFGs are
    de-aliased as well.

    Args:
        sdfg (dace.SDFG): Modified in-place.
    """
    # Nested SDFG nodes discovered in this SDFG; recursed into at the end.
    recurse_in: Set[dace.nodes.NestedSDFG] = set()

    for state in sdfg.all_states():
        for node in state.nodes():
            if not isinstance(node, dace.nodes.NestedSDFG):
                continue
            recurse_in.add(node)

            in_edges = state.in_edges(node)
            out_edges = state.out_edges(node)

            # Gather all replacements we need
            # E.g. A[::] -> tmpxceX (NestedSDFG)
            # Needs to be replaced with A[::] -> A_slice (NestedSDFG)
            # A_slice is chosen if the subset is different than the complete shape A
            # Otherwise A is chosen
            # Also consider the case where A[i] -> tmp1 (NestedSDFG)
            #                              A[j] -> tmp2
            # In this case we need not map them to A twice but to A_slice1, A_slice2
            input_repldict: Dict[str, str] = dict()
            output_repldict: Dict[str, str] = dict()
            for in_edge in in_edges:
                # Skip "__return" (reserved name for returned data)
                if in_edge.data is not None and in_edge.data.data == "__return":
                    continue
                if in_edge.data is not None and in_edge.data.data != in_edge.dst_conn:
                    new_connector = _get_new_connector_name(in_edge, input_repldict, output_repldict, state,
                                                            node.sdfg)
                    input_repldict[in_edge.dst_conn] = new_connector

            for out_edge in out_edges:
                if out_edge.data is not None and out_edge.data.data == "__return":
                    continue
                if out_edge.data is not None and out_edge.data.data != out_edge.src_conn:
                    new_connector = _get_new_connector_name(out_edge, output_repldict, input_repldict, state,
                                                            node.sdfg)
                    output_repldict[out_edge.src_conn] = new_connector

            # Replace connectors: remove the temporary connector (tmpxceX) and add the
            # meaningful one (A / A_slice_<id>).
            for dst_name in set(input_repldict.keys()):
                rmed = node.remove_in_connector(dst_name)
                assert rmed
            for dst_name in set(output_repldict.keys()):
                rmed = node.remove_out_connector(dst_name)
                assert rmed
            for src_name in set(input_repldict.values()):
                added = node.add_in_connector(src_name, force=True)
                assert added
            for src_name in set(output_repldict.values()):
                added = node.add_out_connector(src_name, force=True)
                assert added

            # Update edges to attach to the renamed connectors.
            for in_edge in state.in_edges(node):
                if in_edge.dst_conn in input_repldict:
                    state.remove_edge(in_edge)
                    state.add_edge(in_edge.src, in_edge.src_conn, in_edge.dst, input_repldict[in_edge.dst_conn],
                                   copy.deepcopy(in_edge.data))
            for out_edge in state.out_edges(node):
                if out_edge.src_conn in output_repldict:
                    state.remove_edge(out_edge)
                    state.add_edge(out_edge.src, output_repldict[out_edge.src_conn], out_edge.dst,
                                   out_edge.dst_conn, copy.deepcopy(out_edge.data))

            # Replace the data containers
            # If data / access nodes are not manually changed beforehand,
            # DaCe will try to assign to scalars from a symbolic value and crash
            replace_dict = (input_repldict | output_repldict)
            # Names this routine has itself introduced into node.sdfg.arrays; a pre-existing
            # descriptor with a chosen name would indicate a naming conflict.
            added_arrays: Set[str] = set()
            for dst_name, src_name in replace_dict.items():
                if src_name in node.sdfg.arrays:
                    # Only legal if *we* created it (e.g. one array feeding both an input
                    # and an output connector with the same full-view name).
                    assert src_name in added_arrays, \
                        f"{src_name} is in sdfg.arrays but has not been added by dealias for replacements: {replace_dict}."
                else:
                    desc: dace.data.Data = node.sdfg.arrays[dst_name]
                    node.sdfg.remove_data(dst_name, validate=False)
                    node.sdfg.add_datadesc(name=src_name, datadesc=desc, find_new_name=False)
                    # Record *after* the descriptor is actually added, so the assert above
                    # genuinely validates the invariant instead of passing vacuously.
                    added_arrays.add(src_name)

            # Necessary for DaCe to try assign the value to the missing access node from a tasklet
            for inner_state in node.sdfg.all_states():
                for inner_node in inner_state.nodes():
                    if isinstance(inner_node, dace.nodes.AccessNode) and inner_node.data in replace_dict:
                        inner_node.data = replace_dict[inner_node.data]

            node.sdfg.replace_dict(repldict=replace_dict)

    # Recurse so that deeper nesting levels (nested SDFGs within nested SDFGs)
    # are de-aliased as well.
    for nested_node in recurse_in:
        dealias(nested_node.sdfg)
10 changes: 5 additions & 5 deletions tests/memlet_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]):
if outer_in.subset[0] != (0, M - 1, 1) or outer_in.subset[1] != (0, N - 1, 1):
raise RuntimeError('Expected subset of outer in memlet to be [0:M, 0:N], found ' + str(outer_in.subset))

inner_in = map_state.edges()[1].data
inner_in = map_state.edges()[2].data
if inner_in.volume != 1:
raise RuntimeError('Expected a volume of 1 on the inner input memlet')
raise RuntimeError(f'Expected a volume of 1 on the inner input memlet, got: {inner_in.volume}')
if inner_in.subset[0] != (i, i, 1) or inner_in.subset[1] != (j, j, 1):
raise RuntimeError('Expected subset of inner in memlet to be [i, j], found ' + str(inner_in.subset))

inner_out = map_state.edges()[2].data
inner_out = map_state.edges()[3].data
if inner_out.volume != 1:
raise RuntimeError('Expected a volume of 1 on the inner output memlet')
raise RuntimeError(f'Expected a volume of 1 on the inner output memlet, got: {inner_out.volume}')
if inner_out.subset[0] != (0, i, 1) or inner_out.subset[1] != (0, N - 1, 1):
raise RuntimeError('Expected subset of inner out memlet to be [0:i+1, 0:N], found ' + str(inner_out.subset))

outer_out = map_state.edges()[3].data
outer_out = map_state.edges()[1].data
if outer_out.volume != M * N:
raise RuntimeError('Expected a volume of M*N on the outer output memlet')
if outer_out.subset[0] != (0, M - 1, 1) or outer_out.subset[1] != (0, N - 1, 1):
Expand Down
Loading