Skip to content

Commit 5ffe768

Browse files
committed
Add 3D-4D broadcasting implementation and tests
1 parent d71d388 commit 5ffe768

File tree

3 files changed

+252
-35
lines changed

3 files changed

+252
-35
lines changed

dace/libraries/blas/nodes/batched_matmul.py

Lines changed: 93 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,22 @@ def make_sdfg(node, parent_state, parent_sdfg):
8686
if len(array_a.shape) == 2:
8787
memlet_a = '__im, __ik'
8888
else:
89-
# Use output batch indices
90-
a_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_a.shape) - 2)])
89+
# Align input batch dims to output batch dims
90+
num_a_batch = len(array_a.shape) - 2
91+
# Start from the rightmost batch dimension of output and work backwards
92+
offset = num_batch_dims - num_a_batch
93+
a_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_a_batch)])
9194
memlet_a = f'{a_batch_indices}, __im, __ik'
9295

9396
# For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N]
9497
if len(array_b.shape) == 2:
9598
memlet_b = '__ik, __in'
9699
else:
97-
b_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_b.shape) - 2)])
100+
# Align input batch dims to output batch dims
101+
num_b_batch = len(array_b.shape) - 2
102+
# Start from the rightmost batch dimension of output and work backwards
103+
offset = num_batch_dims - num_b_batch
104+
b_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_b_batch)])
98105
memlet_b = f'{b_batch_indices}, __ik, __in'
99106

100107
# For C: always has batch dimensions
@@ -172,8 +179,11 @@ def expansion(node, state, sdfg):
172179
const {dtype}** __mkl_BMM_B = new const {dtype}*[{BATCH}];
173180
{dtype}** __mkl_BMM_C = new {dtype}*[{BATCH}];
174181
for (int __ib = 0; __ib < {BATCH}; __ib++) {{
175-
__mkl_BMM_A[__ib] = (({dtype}*){x}) + __ib*{stride_a};
176-
__mkl_BMM_B[__ib] = (({dtype}*){y}) + __ib*{stride_b};
182+
// Handle broadcasting - compute correct index for inputs with fewer batch dimensions
183+
int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0;
184+
int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? (__ib % {b_batch_size}) : __ib) : 0;
185+
__mkl_BMM_A[__ib] = (({dtype}*){x}) + __a_idx*{stride_a};
186+
__mkl_BMM_B[__ib] = (({dtype}*){y}) + __b_idx*{stride_b};
177187
__mkl_BMM_C[__ib] = (({dtype}*)_c) + __ib*{stride_c};
178188
}}
179189
@@ -227,9 +237,12 @@ def expansion(node, state, sdfg):
227237

228238
code = '''
229239
for (int __ib = 0; __ib < {BATCH}; ++__ib) {{
240+
// Handle broadcasting - compute correct index for inputs with fewer batch dimensions
241+
int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0;
242+
int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? (__ib % {b_batch_size}) : __ib) : 0;
230243
cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha},
231-
(({dtype}*){x}) + __ib*{stride_a}, {lda},
232-
(({dtype}*){y}) + __ib*{stride_b}, {ldb},
244+
(({dtype}*){x}) + __a_idx*{stride_a}, {lda},
245+
(({dtype}*){y}) + __b_idx*{stride_b}, {ldb},
233246
{beta},
234247
(({dtype}*)_c) + __ib*{stride_c}, {ldc});
235248
}}'''.format_map(opt)
@@ -325,17 +338,38 @@ def expansion(node, state, sdfg):
325338
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)
326339
opt['array_prefix'] = '_' if needs_copy else ''
327340

341+
# Check if we need broadcasting (non-uniform strides)
342+
needs_broadcasting = (opt.get('a_batch_size') and opt.get('b_batch_size')
343+
and (opt['a_batch_size'] != opt['BATCH'] or opt['b_batch_size'] != opt['BATCH']))
344+
328345
# Matrix multiplication
329346
if (node.compute_type is None and node.accumulator_type is None and node.algorithm is None):
330-
call = '''cublas{func}StridedBatched(__dace_cublas_handle,
331-
CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
332-
{M}, {N}, {K},
333-
{alpha},
334-
({dtype}*){array_prefix}{x}, {lda}, {stride_a},
335-
({dtype}*){array_prefix}{y}, {ldb}, {stride_b},
336-
{beta},
337-
({dtype}*){array_prefix}_c, {ldc}, {stride_c},
338-
{BATCH});'''.format_map(opt)
347+
if needs_broadcasting:
348+
# Use manual loop for broadcasting cases
349+
call = '''
350+
for (int __ib = 0; __ib < {BATCH}; ++__ib) {{
351+
int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0;
352+
int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? (__ib % {b_batch_size}) : __ib) : 0;
353+
cublas{func}(__dace_cublas_handle,
354+
CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
355+
{M}, {N}, {K},
356+
{alpha},
357+
({dtype}*){array_prefix}{x} + __a_idx*{stride_a}, {lda},
358+
({dtype}*){array_prefix}{y} + __b_idx*{stride_b}, {ldb},
359+
{beta},
360+
({dtype}*){array_prefix}_c + __ib*{stride_c}, {ldc});
361+
}}'''.format_map(opt)
362+
else:
363+
# Use StridedBatched for uniform case
364+
call = '''cublas{func}StridedBatched(__dace_cublas_handle,
365+
CUBLAS_OP_{ta}, CUBLAS_OP_{tb},
366+
{M}, {N}, {K},
367+
{alpha},
368+
({dtype}*){array_prefix}{x}, {lda}, {stride_a},
369+
({dtype}*){array_prefix}{y}, {ldb}, {stride_b},
370+
{beta},
371+
({dtype}*){array_prefix}_c, {ldc}, {stride_c},
372+
{BATCH});'''.format_map(opt)
339373
else:
340374
if node.compute_type is not None:
341375
acctype = node.compute_type
@@ -349,24 +383,49 @@ def expansion(node, state, sdfg):
349383
if node.algorithm is not None:
350384
algorithm = node.algorithm
351385

352-
call = f'''
353-
cublasGemmStridedBatchedEx(__dace_cublas_handle,
354-
CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']},
355-
{opt['M']}, {opt['N']}, {opt['K']},
356-
{alpha},
357-
{opt['array_prefix']}{opt['x']},
358-
{dtype_to_cudadatatype(opt['xdtype'])},
359-
{opt['lda']}, {opt['stride_a']},
360-
{opt['array_prefix']}{opt['y']},
361-
{dtype_to_cudadatatype(opt['ydtype'])},
362-
{opt['ldb']}, {opt['stride_b']},
363-
{beta},
364-
{opt['array_prefix']}_c,
365-
{dtype_to_cudadatatype(opt['cdtype'])},
366-
{opt['ldc']}, {opt['stride_c']},
367-
{opt['BATCH']},
368-
{acctype}, {algorithm});
369-
'''
386+
if needs_broadcasting:
387+
# Use manual loop for broadcasting cases with GemmEx
388+
call = f'''
389+
for (int __ib = 0; __ib < {opt['BATCH']}; ++__ib) {{{{
390+
int __a_idx = ({opt['stride_a']} > 0) ? (({opt['a_batch_size']} < {opt['BATCH']}) ? (__ib % {opt['a_batch_size']}) : __ib) : 0;
391+
int __b_idx = ({opt['stride_b']} > 0) ? (({opt['b_batch_size']} < {opt['BATCH']}) ? (__ib % {opt['b_batch_size']}) : __ib) : 0;
392+
cublasGemmEx(__dace_cublas_handle,
393+
CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']},
394+
{opt['M']}, {opt['N']}, {opt['K']},
395+
{alpha},
396+
{opt['array_prefix']}{opt['x']} + __a_idx*{opt['stride_a']},
397+
{dtype_to_cudadatatype(opt['xdtype'])},
398+
{opt['lda']},
399+
{opt['array_prefix']}{opt['y']} + __b_idx*{opt['stride_b']},
400+
{dtype_to_cudadatatype(opt['ydtype'])},
401+
{opt['ldb']},
402+
{beta},
403+
{opt['array_prefix']}_c + __ib*{opt['stride_c']},
404+
{dtype_to_cudadatatype(opt['cdtype'])},
405+
{opt['ldc']},
406+
{acctype}, {algorithm});
407+
}}}}
408+
'''
409+
else:
410+
# Use StridedBatchedEx for uniform case
411+
call = f'''
412+
cublasGemmStridedBatchedEx(__dace_cublas_handle,
413+
CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']},
414+
{opt['M']}, {opt['N']}, {opt['K']},
415+
{alpha},
416+
{opt['array_prefix']}{opt['x']},
417+
{dtype_to_cudadatatype(opt['xdtype'])},
418+
{opt['lda']}, {opt['stride_a']},
419+
{opt['array_prefix']}{opt['y']},
420+
{dtype_to_cudadatatype(opt['ydtype'])},
421+
{opt['ldb']}, {opt['stride_b']},
422+
{beta},
423+
{opt['array_prefix']}_c,
424+
{dtype_to_cudadatatype(opt['cdtype'])},
425+
{opt['ldc']}, {opt['stride_c']},
426+
{opt['BATCH']},
427+
{acctype}, {algorithm});
428+
'''
370429

371430
code = call_prefix + call + call_suffix
372431
tasklet = dace.sdfg.nodes.Tasklet(node.name,

dace/libraries/blas/nodes/matmul.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides
104104
# Calculate strides for batched operations
105105
# For a tensor with shape [B1, B2, ..., M, K], the stride for batched operations
106106
# should be M*K (the size of each matrix) to iterate through all matrices in the flattened batch
107+
#
108+
# For broadcasting cases (e.g., A - [b1, b2, m, k] @ B - [b2, k, n]):
109+
# - The flattened batch is b1*b2
110+
# - B needs special handling: we need to compute which of the b2 matrices to use
111+
# For batch index i in [0, b1*b2), the B matrix index is (i % b2)
112+
# This can be expressed as: if A has more batch dims than B, use modulo arithmetic
113+
107114
stride_a = 0
108115
stride_b = 0
109116
stride_c = 0
@@ -125,10 +132,35 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides
125132
if res is False:
126133
raise ValueError(f'Output batch dimension mismatch: {c_dim} vs {r_dim} at position {i}')
127134

135+
# For partial broadcasting (3D-4D cases), we need to track additional information
136+
# to properly index into the smaller batch dimension tensor
137+
a_batch_multiplier = 1 # How many times to cycle through A's batch
138+
b_batch_multiplier = 1 # How many times to cycle through B's batch
139+
140+
if len(a_batch_dims) < len(result_batch_dims):
141+
# A has fewer batch dimensions, so it will be broadcast
142+
# Calculate the size of the leading dimensions that A doesn't have
143+
a_batch_multiplier = prod(result_batch_dims[:len(result_batch_dims) - len(a_batch_dims)])
144+
145+
if len(b_batch_dims) < len(result_batch_dims):
146+
# B has fewer batch dimensions, so it will be broadcast
147+
# Calculate the size of the leading dimensions that B doesn't have
148+
b_batch_multiplier = prod(result_batch_dims[:len(result_batch_dims) - len(b_batch_dims)])
149+
128150
if batch_size == 1 and not result_batch_dims:
129151
return {}
130152

131-
return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch_size, 'batch_dims': result_batch_dims}
153+
return {
154+
'sa': stride_a,
155+
'sb': stride_b,
156+
'sc': stride_c,
157+
'b': batch_size,
158+
'batch_dims': result_batch_dims,
159+
'a_batch_size': prod(a_batch_dims) if a_batch_dims else 1,
160+
'b_batch_size': prod(b_batch_dims) if b_batch_dims else 1,
161+
'a_batch_multiplier': a_batch_multiplier,
162+
'b_batch_multiplier': b_batch_multiplier
163+
}
132164

133165

134166
def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]:
@@ -165,6 +197,7 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
165197
if opt['swap']:
166198
if bopt:
167199
bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa']
200+
bopt['a_batch_size'], bopt['b_batch_size'] = bopt['b_batch_size'], bopt['a_batch_size']
168201
opt['lda'], opt['ldb'] = opt['ldb'], opt['lda']
169202
opt['x'], opt['y'] = opt['y'], opt['x']
170203
opt['xdtype'], opt['ydtype'] = opt['ydtype'], opt['xdtype']
@@ -180,6 +213,8 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
180213
opt['stride_b'] = sym2cpp(bopt['sb'])
181214
opt['stride_c'] = sym2cpp(bopt['sc'])
182215
opt['BATCH'] = sym2cpp(bopt['b'])
216+
opt['a_batch_size'] = sym2cpp(bopt['a_batch_size'])
217+
opt['b_batch_size'] = sym2cpp(bopt['b_batch_size'])
183218
else:
184219
opt['BATCH'] = None
185220

tests/library/batched_matmul_test.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,111 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m,
224224
assert np.allclose(ref, z)
225225

226226

227+
@pytest.mark.parametrize("implementation, dtype", [
228+
pytest.param("pure", dace.float32),
229+
pytest.param("pure", dace.float64),
230+
pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
231+
pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
232+
pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
233+
pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
234+
pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
235+
pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
236+
])
237+
def test_batchmm_3d_4d_broadcast(implementation: str, dtype):
238+
"""Test 4D batched matmul with broadcast on LHS: [b2, m, k] @ [b1, b2, k, n]"""
239+
b1, b2, m, n, k = 4, 2, 64, 128, 64
240+
241+
@dace.program
242+
def bmm_3d_4d_broadcast(A: dtype[b2, m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m, n]):
243+
C[:] = A @ B
244+
245+
with change_default(blas, implementation):
246+
sdfg = bmm_3d_4d_broadcast.to_sdfg()
247+
sdfg.simplify()
248+
sdfg.expand_library_nodes()
249+
250+
x = np.random.rand(b2, m, k).astype(dtype.as_numpy_dtype())
251+
y = np.random.rand(b1, b2, k, n).astype(dtype.as_numpy_dtype())
252+
z = np.zeros([b1, b2, m, n]).astype(dtype.as_numpy_dtype())
253+
254+
csdfg = sdfg.compile()
255+
csdfg(A=x, B=y, C=z)
256+
257+
ref = x @ y
258+
259+
assert np.allclose(ref, z)
260+
261+
262+
@pytest.mark.parametrize("implementation, dtype", [
263+
pytest.param("pure", dace.float32),
264+
pytest.param("pure", dace.float64),
265+
pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
266+
pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
267+
pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
268+
pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
269+
pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
270+
pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
271+
])
272+
def test_batchmm_4d_3d_broadcast(implementation: str, dtype):
273+
"""Test 4D batched matmul with broadcast on RHS: [b1, b2, m, k] @ [b2, k, n]"""
274+
b1, b2, m, n, k = 4, 2, 64, 128, 64
275+
276+
@dace.program
277+
def bmm_4d_3d_broadcast(A: dtype[b1, b2, m, k], B: dtype[b2, k, n], C: dtype[b1, b2, m, n]):
278+
C[:] = A @ B
279+
280+
with change_default(blas, implementation):
281+
sdfg = bmm_4d_3d_broadcast.to_sdfg()
282+
sdfg.simplify()
283+
sdfg.expand_library_nodes()
284+
285+
x = np.random.rand(b1, b2, m, k).astype(dtype.as_numpy_dtype())
286+
y = np.random.rand(b2, k, n).astype(dtype.as_numpy_dtype())
287+
z = np.zeros([b1, b2, m, n]).astype(dtype.as_numpy_dtype())
288+
289+
csdfg = sdfg.compile()
290+
csdfg(A=x, B=y, C=z)
291+
292+
ref = x @ y
293+
294+
assert np.allclose(ref, z)
295+
296+
297+
@pytest.mark.parametrize("implementation, dtype", [
298+
pytest.param("pure", dace.float32),
299+
pytest.param("pure", dace.float64),
300+
pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
301+
pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
302+
pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
303+
pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
304+
pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
305+
pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
306+
])
307+
def test_batchmm_5d_3d_broadcast(implementation: str, dtype):
308+
"""Test 5D batched matmul with broadcast on RHS: [b1, b2, b3, m, k] @ [b3, k, n]"""
309+
b1, b2, b3, m, n, k = 4, 2, 3, 64, 128, 64
310+
311+
@dace.program
312+
def bmm_5d_3d_broadcast(A: dtype[b1, b2, b3, m, k], B: dtype[b3, k, n], C: dtype[b1, b2, b3, m, n]):
313+
C[:] = A @ B
314+
315+
with change_default(blas, implementation):
316+
sdfg = bmm_5d_3d_broadcast.to_sdfg()
317+
sdfg.simplify()
318+
sdfg.expand_library_nodes()
319+
320+
x = np.random.rand(b1, b2, b3, m, k).astype(dtype.as_numpy_dtype())
321+
y = np.random.rand(b3, k, n).astype(dtype.as_numpy_dtype())
322+
z = np.zeros([b1, b2, b3, m, n]).astype(dtype.as_numpy_dtype())
323+
324+
csdfg = sdfg.compile()
325+
csdfg(A=x, B=y, C=z)
326+
327+
ref = x @ y
328+
329+
assert np.allclose(ref, z)
330+
331+
227332
if __name__ == "__main__":
228333
test_batchmm("pure", dace.float32)
229334
test_batchmm("pure", dace.float64)
@@ -261,3 +366,21 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m,
261366
test_batchmm_4d_broadcast_lhs("MKL", dace.float64)
262367
test_batchmm_4d_broadcast_lhs("cuBLAS", dace.float32)
263368
test_batchmm_4d_broadcast_lhs("cuBLAS", dace.float64)
369+
test_batchmm_3d_4d_broadcast("pure", dace.float32)
370+
test_batchmm_3d_4d_broadcast("pure", dace.float64)
371+
test_batchmm_3d_4d_broadcast("MKL", dace.float32)
372+
test_batchmm_3d_4d_broadcast("MKL", dace.float64)
373+
test_batchmm_3d_4d_broadcast("cuBLAS", dace.float32)
374+
test_batchmm_3d_4d_broadcast("cuBLAS", dace.float64)
375+
test_batchmm_4d_3d_broadcast("pure", dace.float32)
376+
test_batchmm_4d_3d_broadcast("pure", dace.float64)
377+
test_batchmm_4d_3d_broadcast("MKL", dace.float32)
378+
test_batchmm_4d_3d_broadcast("MKL", dace.float64)
379+
test_batchmm_4d_3d_broadcast("cuBLAS", dace.float32)
380+
test_batchmm_4d_3d_broadcast("cuBLAS", dace.float64)
381+
test_batchmm_5d_3d_broadcast("pure", dace.float32)
382+
test_batchmm_5d_3d_broadcast("pure", dace.float64)
383+
test_batchmm_5d_3d_broadcast("MKL", dace.float32)
384+
test_batchmm_5d_3d_broadcast("MKL", dace.float64)
385+
test_batchmm_5d_3d_broadcast("cuBLAS", dace.float32)
386+
test_batchmm_5d_3d_broadcast("cuBLAS", dace.float64)

0 commit comments

Comments (0)