Skip to content

Commit bf97502

Browse files
committed
Fix and test MatMul accumulation
1 parent 3a88e77 commit bf97502

File tree

4 files changed

+455
-58
lines changed

4 files changed

+455
-58
lines changed

dace/libraries/blas/nodes/batched_matmul.py

Lines changed: 151 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,35 @@ def make_sdfg(node, parent_state, parent_sdfg):
7171
_, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=storage)
7272
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-3], storage=storage)
7373

74-
# Add an initialization state
75-
init_state = sdfg.add_state()
76-
init_state.add_mapped_tasklet(
77-
'batched_matmul_init', {
78-
'_o%d' % i: '0:%s' % symstr(d)
79-
for i, d in enumerate(shape_c)
80-
}, {},
81-
'out = 0', {'out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))},
82-
external_edges=True)
83-
84-
state = sdfg.add_state_after(init_state, node.label + "_state")
74+
# Handle beta factor for C
75+
# C_new = alpha * A @ B + beta * C_old
76+
if node.beta == 0:
77+
# Initialize C to 0
78+
init_state = sdfg.add_state()
79+
init_state.add_mapped_tasklet(
80+
'batched_matmul_init', {
81+
'_o%d' % i: '0:%s' % symstr(d)
82+
for i, d in enumerate(shape_c)
83+
}, {},
84+
'out = 0', {'out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))},
85+
external_edges=True)
86+
state = sdfg.add_state_after(init_state, node.label + "_state")
87+
elif node.beta != 1:
88+
# Scale C by beta before accumulation
89+
init_state = sdfg.add_state()
90+
beta_value = node.beta
91+
init_state.add_mapped_tasklet(
92+
'batched_matmul_scale_c', {
93+
'_o%d' % i: '0:%s' % symstr(d)
94+
for i, d in enumerate(shape_c)
95+
}, {'_in': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))},
96+
f'_out = {beta_value} * _in',
97+
{'_out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))},
98+
external_edges=True)
99+
state = sdfg.add_state_after(init_state, node.label + "_state")
100+
else:
101+
# beta == 1: Just accumulate into existing C values
102+
state = sdfg.add_state(node.label + "_state")
85103

86104
# Calculate number of batch dimensions in output
87105
# For 1D cases, output may have fewer dimensions
@@ -168,12 +186,19 @@ def make_sdfg(node, parent_state, parent_sdfg):
168186
c_indices_parts.append('__in')
169187
c_indices = ', '.join(c_indices_parts)
170188

189+
# Handle alpha factor in the multiplication
190+
alpha_value = node.alpha
191+
if alpha_value == 1:
192+
tasklet_code = '__c = __a * __b'
193+
else:
194+
tasklet_code = f'__c = {alpha_value} * __a * __b'
195+
171196
state.add_mapped_tasklet('_BatchedMatMult_',
172197
map_params, {
173198
'__a': dace.Memlet.simple("_a", memlet_a),
174199
'__b': dace.Memlet.simple("_b", memlet_b)
175200
},
176-
'__c = __a * __b',
201+
tasklet_code,
177202
{'__c': dace.Memlet.simple("_c", c_indices, wcr_str='lambda x, y: x + y')},
178203
external_edges=True)
179204

@@ -197,18 +222,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
197222
from dace.codegen.common import sym2cpp
198223

199224
prefix = to_blastype(dtype.type).lower()
225+
# Use node's alpha and beta values
200226
if dtype == dace.float32:
201-
alpha = "1.0f"
202-
beta = "0.0f"
227+
alpha = f"{float(node.alpha)}f"
228+
beta = f"{float(node.beta)}f"
203229
elif dtype == dace.float64:
204-
alpha = "1.0"
205-
beta = "0.0"
230+
alpha = f"{float(node.alpha)}"
231+
beta = f"{float(node.beta)}"
206232
elif dtype == dace.complex64:
207-
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
208-
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
233+
if node.alpha == 1:
234+
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
235+
else:
236+
alpha = f"dace::blas::make_cuComplex({node.alpha}, 0)"
237+
if node.beta == 0:
238+
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
239+
else:
240+
beta = f"dace::blas::make_cuComplex({node.beta}, 0)"
209241
elif dtype == dace.complex128:
210-
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
211-
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
242+
if node.alpha == 1:
243+
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
244+
else:
245+
alpha = f"dace::blas::make_cuDoubleComplex({node.alpha}, 0)"
246+
if node.beta == 0:
247+
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
248+
else:
249+
beta = f"dace::blas::make_cuDoubleComplex({node.beta}, 0)"
212250
else:
213251
raise ValueError("Unsupported type for BLAS: " + str(dtype))
214252

@@ -318,21 +356,35 @@ def expansion(node, state, sdfg):
318356
astrides, bstrides, dtype, is_a_1d, is_b_1d)
319357

320358
func = to_blastype(dtype.type).lower() + 'gemm'
359+
360+
# Use node's alpha and beta values
321361
if dtype == dace.float32:
322-
alpha = "1.0f"
323-
beta = "0.0f"
362+
alpha = f"{float(node.alpha)}f"
363+
beta = f"{float(node.beta)}f"
324364
prefix = "s"
325365
elif dtype == dace.float64:
326-
alpha = "1.0"
327-
beta = "0.0"
366+
alpha = f"{float(node.alpha)}"
367+
beta = f"{float(node.beta)}"
328368
prefix = "d"
329369
elif dtype == dace.complex64:
330-
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
331-
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
370+
if node.alpha == 1:
371+
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
372+
else:
373+
alpha = f"dace::blas::make_cuComplex({node.alpha}, 0)"
374+
if node.beta == 0:
375+
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
376+
else:
377+
beta = f"dace::blas::make_cuComplex({node.beta}, 0)"
332378
prefix = "c"
333379
elif dtype == dace.complex128:
334-
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
335-
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
380+
if node.alpha == 1:
381+
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
382+
else:
383+
alpha = f"dace::blas::make_cuDoubleComplex({node.alpha}, 0)"
384+
if node.beta == 0:
385+
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
386+
else:
387+
beta = f"dace::blas::make_cuDoubleComplex({node.beta}, 0)"
336388
prefix = "z"
337389
else:
338390
raise ValueError("Unsupported type for BLAS dot product: " + str(dtype))
@@ -393,18 +445,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
393445
from dace.codegen.common import sym2cpp
394446

395447
prefix = to_blastype(dtype.type).lower()
448+
# Use node's alpha and beta values
396449
if dtype == dace.float32:
397-
alpha = "1.0f"
398-
beta = "0.0f"
450+
alpha = f"{float(node.alpha)}f"
451+
beta = f"{float(node.beta)}f"
399452
elif dtype == dace.float64:
400-
alpha = "1.0"
401-
beta = "0.0"
453+
alpha = f"{float(node.alpha)}"
454+
beta = f"{float(node.beta)}"
402455
elif dtype == dace.complex64:
403-
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
404-
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
456+
if node.alpha == 1:
457+
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
458+
else:
459+
alpha = f"dace::blas::make_cuComplex({node.alpha}, 0)"
460+
if node.beta == 0:
461+
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
462+
else:
463+
beta = f"dace::blas::make_cuComplex({node.beta}, 0)"
405464
elif dtype == dace.complex128:
406-
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
407-
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
465+
if node.alpha == 1:
466+
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
467+
else:
468+
alpha = f"dace::blas::make_cuDoubleComplex({node.alpha}, 0)"
469+
if node.beta == 0:
470+
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
471+
else:
472+
beta = f"dace::blas::make_cuDoubleComplex({node.beta}, 0)"
408473
else:
409474
raise ValueError("Unsupported type for BLAS: " + str(dtype))
410475

@@ -514,18 +579,31 @@ def expansion(node, state, sdfg):
514579
astrides, bstrides, dtype, is_a_1d, is_b_1d)
515580

516581
func = to_blastype(dtype.type).lower() + 'gemm'
582+
# Use node's alpha and beta values
517583
if dtype == dace.float32:
518-
alpha = "1.0f"
519-
beta = "0.0f"
584+
alpha = f"{float(node.alpha)}f"
585+
beta = f"{float(node.beta)}f"
520586
elif dtype == dace.float64:
521-
alpha = "1.0"
522-
beta = "0.0"
587+
alpha = f"{float(node.alpha)}"
588+
beta = f"{float(node.beta)}"
523589
elif dtype == dace.complex64:
524-
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
525-
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
590+
if node.alpha == 1:
591+
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
592+
else:
593+
alpha = f"dace::blas::make_cuComplex({node.alpha}, 0)"
594+
if node.beta == 0:
595+
beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
596+
else:
597+
beta = f"dace::blas::make_cuComplex({node.beta}, 0)"
526598
elif dtype == dace.complex128:
527-
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
528-
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
599+
if node.alpha == 1:
600+
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
601+
else:
602+
alpha = f"dace::blas::make_cuDoubleComplex({node.alpha}, 0)"
603+
if node.beta == 0:
604+
beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
605+
else:
606+
beta = f"dace::blas::make_cuDoubleComplex({node.beta}, 0)"
529607
else:
530608
raise ValueError("Unsupported type for BLAS dot product: " + str(dtype))
531609
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)
@@ -612,26 +690,43 @@ def expansion(node, state, sdfg):
612690
1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
613691
0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
614692
}
693+
694+
# Handle alpha
615695
if node.alpha not in constants:
616696
# Deal with complex input constants
617697
if isinstance(node.alpha, complex):
618-
alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
698+
alpha_val = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
619699
else:
620-
alpha = f'{dtype.ctype}({node.alpha})'
700+
alpha_val = f'{dtype.ctype}({node.alpha})'
701+
use_host_mode_alpha = True
702+
else:
703+
alpha = constants[node.alpha]
704+
use_host_mode_alpha = False
621705

622-
# Set pointer mode to host
623-
call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
624-
{dtype.ctype} alpha = {alpha};
625-
{dtype.ctype} beta = 0;
626-
'''
706+
# Handle beta
707+
if node.beta not in constants:
708+
# Deal with complex input constants
709+
if isinstance(node.beta, complex):
710+
beta_val = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
711+
else:
712+
beta_val = f'{dtype.ctype}({node.beta})'
713+
use_host_mode_beta = True
714+
else:
715+
beta = constants[node.beta]
716+
use_host_mode_beta = False
717+
718+
# Set pointer mode to host if needed
719+
if use_host_mode_alpha or use_host_mode_beta:
720+
call_prefix += 'cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);\n'
721+
if use_host_mode_alpha:
722+
call_prefix += f' {dtype.ctype} alpha = {alpha_val};\n'
723+
alpha = f'({cdtype} *)&alpha'
724+
if use_host_mode_beta:
725+
call_prefix += f' {dtype.ctype} beta = {beta_val};\n'
726+
beta = f'({cdtype} *)&beta'
627727
call_suffix += '''
628728
cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
629729
'''
630-
beta = f'({cdtype} *)&beta'
631-
alpha = f'({cdtype} *)&alpha'
632-
else:
633-
alpha = constants[node.alpha]
634-
beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort
635730

636731
# Set up options for code formatting
637732
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)

dace/libraries/blas/nodes/gemm.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,32 @@ def expansion(node, state, sdfg):
192192
opt['alpha'] = '&__alpha'
193193
opt['beta'] = '&__beta'
194194

195+
# Handle the case when cin=True and beta != 0 (node has _c as both input and output)
196+
# Since BLAS GEMM does in-place read-modify-write on C, and tasklets cannot have
197+
# duplicate connectors, we remove the input _c connector. The BLAS call will read
198+
# from and write to the same memory location (_c output).
199+
#
200+
# We also remove the incoming edge and orphaned access node to maintain graph validity.
201+
in_connectors = {}
202+
for k, v in node.in_connectors.items():
203+
if k == '_c':
204+
# Remove the incoming edge to _c and the source access node if it becomes isolated
205+
for edge in list(state.in_edges_by_connector(node, '_c')):
206+
src_node = edge.src
207+
state.remove_edge(edge)
208+
# Remove the access node if it has no other edges
209+
if state.degree(src_node) == 0:
210+
state.remove_node(src_node)
211+
else:
212+
in_connectors[k] = v
213+
195214
code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
196215
"{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
197216
"_c, {ldc});").format_map(opt)
198217

199218
tasklet = dace.sdfg.nodes.Tasklet(
200219
node.name,
201-
node.in_connectors,
220+
in_connectors,
202221
node.out_connectors,
203222
code,
204223
language=dace.dtypes.Language.CPP,

dace/libraries/blas/nodes/gemv.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,31 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
786786
code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda},
787787
_x, {strides_x[0]}, {beta}, _y, {strides_y[0]});"""
788788

789+
# Handle the case when beta != 0 (node has _y as both input and output)
790+
# NOTE: This happens when the Gemv node is created with beta != 0 (see __init__ line 915).
791+
# The pure implementation needs y as input to explicitly scale it, but BLAS implementations
792+
# handle this internally.
793+
#
794+
# Since BLAS GEMV does in-place read-modify-write on y, and tasklets cannot have
795+
# duplicate connectors, we remove the input _y connector. The BLAS call will read
796+
# from and write to the same memory location (_y output).
797+
#
798+
# We also remove the incoming edge and orphaned access node to maintain graph validity.
799+
in_connectors = {}
800+
for k, v in node.in_connectors.items():
801+
if k == '_y':
802+
# Remove the incoming edge to _y and the source access node if it becomes isolated
803+
for edge in list(state.in_edges_by_connector(node, '_y')):
804+
src_node = edge.src
805+
state.remove_edge(edge)
806+
# Remove the access node if it has no other edges
807+
if state.degree(src_node) == 0:
808+
state.remove_node(src_node)
809+
else:
810+
in_connectors[k] = v
811+
789812
tasklet = dace.sdfg.nodes.Tasklet(node.name,
790-
node.in_connectors,
813+
in_connectors,
791814
node.out_connectors,
792815
code,
793816
language=dace.dtypes.Language.CPP)

0 commit comments

Comments
 (0)