Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
sdfg.save(f'{tmp_dir}/test.sdfg', hash=False)
sdfg2 = SDFG.from_file(f'{tmp_dir}/test.sdfg')
sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False)
print('Testing SDFG serialization...')
if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'):
with open(f'{tmp_dir}/test.sdfg', 'r') as f1:
with open(f'{tmp_dir}/test2.sdfg', 'r') as f2:
Expand Down
4 changes: 3 additions & 1 deletion dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def _create_einsum_internal(sdfg: SDFG,
for inp, inpname in zip(einsum.inputs, arrays):
inparr = sdfg.arrays[inpname]
if len(inp) != len(inparr.shape):
raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
raise ValueError(f'Dimensionality mismatch in input "{inpname}": '
f'einsum subscript has {len(inp)} dimensions but array has '
f'{len(inparr.shape)} dimensions')
for char, shp in zip(inp, inparr.shape):
if char in chardict and shp != chardict[char]:
raise ValueError('Dimension mismatch in einsum expression')
Expand Down
28 changes: 28 additions & 0 deletions dace/frontend/python/replacements/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,34 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

output_shape = (1, )

elif len(arr1.shape) == 1 and len(arr2.shape) > 2: # vector @ batched matrix (e.g., [k] @ [b, k, n])

res = symbolic.equal(arr1.shape[0], arr2.shape[-2])
if res is None:
warnings.warn(
f'Length of vector {arr1.shape[0]} and second-last dimension of tensor {arr2.shape[-2]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError(f"Length of vector {arr1.shape[0]} must match "
f"second-last dimension of tensor {arr2.shape[-2]}")

# Output has all batch dimensions plus the last dimension of arr2
output_shape = arr2.shape[:-2] + (arr2.shape[-1], )

elif len(arr1.shape) > 2 and len(arr2.shape) == 1: # batched matrix @ vector (e.g., [b, m, k] @ [k])

res = symbolic.equal(arr1.shape[-1], arr2.shape[0])
if res is None:
warnings.warn(
f'Last dimension of tensor {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
f'may not match', UserWarning)
elif not res:
raise SyntaxError(f"Last dimension of tensor {arr1.shape[-1]} must match "
f"length of vector {arr2.shape[0]}")

# Output has all batch dimensions plus the second-last dimension of arr1
output_shape = arr1.shape[:-1]

else: # Dunno what this is, bail

raise SyntaxError("Cannot multiply arrays with shapes: {} and {}".format(arr1.shape, arr2.shape))
Expand Down
20 changes: 17 additions & 3 deletions dace/libraries/blas/blas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,23 @@ def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]:
# | | |
# use these 3 to detect correct option

sAM, sAK = a_strides[-2:]
sBK, sBN = b_strides[-2:]
sCM, sCN = c_strides[-2:]
# Handle 1D inputs by treating them as column/row vectors
# [k] -> treat as [k, 1] with stride [1, k] for column vector
if len(a_strides) == 1:
sAM, sAK = a_strides[0], 1 # Treat as column vector [k, 1]
else:
sAM, sAK = a_strides[-2:]

# Treat as row vector [1, k] -> transposed to [k, 1]
if len(b_strides) == 1:
sBK, sBN = 1, b_strides[0]
else:
sBK, sBN = b_strides[-2:]

if len(c_strides) == 1:
sCM, sCN = c_strides[0], 1
else:
sCM, sCN = c_strides[-2:]

opts = {
'mkm': {
Expand Down
Loading