Skip to content

Commit 3a88e77

Browse files
committed
Add support for 1D batched-GEMV broadcasting
1 parent 5ffe768 commit 3a88e77

File tree

5 files changed

+511
-44
lines changed

5 files changed

+511
-44
lines changed

dace/frontend/python/replacements/linalg.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,34 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op
8686

8787
output_shape = (1, )
8888

89+
elif len(arr1.shape) == 1 and len(arr2.shape) > 2: # vector @ batched matrix (e.g., [k] @ [b, k, n])
90+
91+
res = symbolic.equal(arr1.shape[0], arr2.shape[-2])
92+
if res is None:
93+
warnings.warn(
94+
f'Length of vector {arr1.shape[0]} and second-last dimension of tensor {arr2.shape[-2]} '
95+
f'may not match', UserWarning)
96+
elif not res:
97+
raise SyntaxError(f"Length of vector {arr1.shape[0]} must match "
98+
f"second-last dimension of tensor {arr2.shape[-2]}")
99+
100+
# Output has all batch dimensions plus the last dimension of arr2
101+
output_shape = arr2.shape[:-2] + (arr2.shape[-1], )
102+
103+
elif len(arr1.shape) > 2 and len(arr2.shape) == 1: # batched matrix @ vector (e.g., [b, m, k] @ [k])
104+
105+
res = symbolic.equal(arr1.shape[-1], arr2.shape[0])
106+
if res is None:
107+
warnings.warn(
108+
f'Last dimension of tensor {arr1.shape[-1]} and length of vector {arr2.shape[0]} '
109+
f'may not match', UserWarning)
110+
elif not res:
111+
raise SyntaxError(f"Last dimension of tensor {arr1.shape[-1]} must match "
112+
f"length of vector {arr2.shape[0]}")
113+
114+
# Output has all batch dimensions plus the second-last dimension of arr1
115+
output_shape = arr1.shape[:-1]
116+
89117
else: # Dunno what this is, bail
90118

91119
raise SyntaxError("Cannot multiply arrays with shapes: {} and {}".format(arr1.shape, arr2.shape))

dace/libraries/blas/blas_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,23 @@ def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]:
110110
# | | |
111111
# use these 3 to detect correct option
112112

113-
sAM, sAK = a_strides[-2:]
114-
sBK, sBN = b_strides[-2:]
115-
sCM, sCN = c_strides[-2:]
113+
# Handle 1D inputs by treating them as column/row vectors
114+
# [k] -> treat as [k, 1] with strides (stride, 1), i.e. a column vector
115+
if len(a_strides) == 1:
116+
sAM, sAK = a_strides[0], 1 # Treat as column vector [k, 1]
117+
else:
118+
sAM, sAK = a_strides[-2:]
119+
120+
# Treat as row vector [1, k] -> transposed to [k, 1]
121+
if len(b_strides) == 1:
122+
sBK, sBN = 1, b_strides[0]
123+
else:
124+
sBK, sBN = b_strides[-2:]
125+
126+
if len(c_strides) == 1:
127+
sCM, sCN = c_strides[0], 1
128+
else:
129+
sCM, sCN = c_strides[-2:]
116130

117131
opts = {
118132
'mkm': {

0 commit comments

Comments
 (0)