@@ -71,17 +71,35 @@ def make_sdfg(node, parent_state, parent_sdfg):
7171 _ , array_b = sdfg .add_array ("_b" , shape_b , dtype_b , strides = strides_b , storage = storage )
7272 _ , array_c = sdfg .add_array ("_c" , shape_c , dtype_c , strides = cdata [- 3 ], storage = storage )
7373
74- # Add an initialization state
75- init_state = sdfg .add_state ()
76- init_state .add_mapped_tasklet (
77- 'batched_matmul_init' , {
78- '_o%d' % i : '0:%s' % symstr (d )
79- for i , d in enumerate (shape_c )
80- }, {},
81- 'out = 0' , {'out' : dace .Memlet .simple ('_c' , ',' .join (['_o%d' % i for i in range (len (shape_c ))]))},
82- external_edges = True )
83-
84- state = sdfg .add_state_after (init_state , node .label + "_state" )
74+ # Handle beta factor for C
75+ # C_new = alpha * A @ B + beta * C_old
76+ if node .beta == 0 :
77+ # Initialize C to 0
78+ init_state = sdfg .add_state ()
79+ init_state .add_mapped_tasklet (
80+ 'batched_matmul_init' , {
81+ '_o%d' % i : '0:%s' % symstr (d )
82+ for i , d in enumerate (shape_c )
83+ }, {},
84+ 'out = 0' , {'out' : dace .Memlet .simple ('_c' , ',' .join (['_o%d' % i for i in range (len (shape_c ))]))},
85+ external_edges = True )
86+ state = sdfg .add_state_after (init_state , node .label + "_state" )
87+ elif node .beta != 1 :
88+ # Scale C by beta before accumulation
89+ init_state = sdfg .add_state ()
90+ beta_value = node .beta
91+ init_state .add_mapped_tasklet (
92+ 'batched_matmul_scale_c' , {
93+ '_o%d' % i : '0:%s' % symstr (d )
94+ for i , d in enumerate (shape_c )
95+ }, {'_in' : dace .Memlet .simple ('_c' , ',' .join (['_o%d' % i for i in range (len (shape_c ))]))},
96+ f'_out = { beta_value } * _in' ,
97+ {'_out' : dace .Memlet .simple ('_c' , ',' .join (['_o%d' % i for i in range (len (shape_c ))]))},
98+ external_edges = True )
99+ state = sdfg .add_state_after (init_state , node .label + "_state" )
100+ else :
101+ # beta == 1: Just accumulate into existing C values
102+ state = sdfg .add_state (node .label + "_state" )
85103
86104 # Calculate number of batch dimensions in output
87105 # For 1D cases, output may have fewer dimensions
@@ -168,12 +186,19 @@ def make_sdfg(node, parent_state, parent_sdfg):
168186 c_indices_parts .append ('__in' )
169187 c_indices = ', ' .join (c_indices_parts )
170188
189+ # Handle alpha factor in the multiplication
190+ alpha_value = node .alpha
191+ if alpha_value == 1 :
192+ tasklet_code = '__c = __a * __b'
193+ else :
194+ tasklet_code = f'__c = { alpha_value } * __a * __b'
195+
171196 state .add_mapped_tasklet ('_BatchedMatMult_' ,
172197 map_params , {
173198 '__a' : dace .Memlet .simple ("_a" , memlet_a ),
174199 '__b' : dace .Memlet .simple ("_b" , memlet_b )
175200 },
176- '__c = __a * __b' ,
201+ tasklet_code ,
177202 {'__c' : dace .Memlet .simple ("_c" , c_indices , wcr_str = 'lambda x, y: x + y' )},
178203 external_edges = True )
179204
@@ -197,18 +222,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
197222 from dace .codegen .common import sym2cpp
198223
199224 prefix = to_blastype (dtype .type ).lower ()
225+ # Use node's alpha and beta values
200226 if dtype == dace .float32 :
201- alpha = "1.0f "
202- beta = "0.0f "
227+ alpha = f" { float ( node . alpha ) } f "
228+ beta = f" { float ( node . beta ) } f "
203229 elif dtype == dace .float64 :
204- alpha = "1.0 "
205- beta = "0.0 "
230+ alpha = f" { float ( node . alpha ) } "
231+ beta = f" { float ( node . beta ) } "
206232 elif dtype == dace .complex64 :
207- alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
208- beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
233+ if node .alpha == 1 :
234+ alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
235+ else :
236+ alpha = f"dace::blas::make_cuComplex({ node .alpha } , 0)"
237+ if node .beta == 0 :
238+ beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
239+ else :
240+ beta = f"dace::blas::make_cuComplex({ node .beta } , 0)"
209241 elif dtype == dace .complex128 :
210- alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
211- beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
242+ if node .alpha == 1 :
243+ alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
244+ else :
245+ alpha = f"dace::blas::make_cuDoubleComplex({ node .alpha } , 0)"
246+ if node .beta == 0 :
247+ beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
248+ else :
249+ beta = f"dace::blas::make_cuDoubleComplex({ node .beta } , 0)"
212250 else :
213251 raise ValueError ("Unsupported type for BLAS: " + str (dtype ))
214252
@@ -318,21 +356,35 @@ def expansion(node, state, sdfg):
318356 astrides , bstrides , dtype , is_a_1d , is_b_1d )
319357
320358 func = to_blastype (dtype .type ).lower () + 'gemm'
359+
360+ # Use node's alpha and beta values
321361 if dtype == dace .float32 :
322- alpha = "1.0f "
323- beta = "0.0f "
362+ alpha = f" { float ( node . alpha ) } f "
363+ beta = f" { float ( node . beta ) } f "
324364 prefix = "s"
325365 elif dtype == dace .float64 :
326- alpha = "1.0 "
327- beta = "0.0 "
366+ alpha = f" { float ( node . alpha ) } "
367+ beta = f" { float ( node . beta ) } "
328368 prefix = "d"
329369 elif dtype == dace .complex64 :
330- alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
331- beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
370+ if node .alpha == 1 :
371+ alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
372+ else :
373+ alpha = f"dace::blas::make_cuComplex({ node .alpha } , 0)"
374+ if node .beta == 0 :
375+ beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
376+ else :
377+ beta = f"dace::blas::make_cuComplex({ node .beta } , 0)"
332378 prefix = "c"
333379 elif dtype == dace .complex128 :
334- alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
335- beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
380+ if node .alpha == 1 :
381+ alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
382+ else :
383+ alpha = f"dace::blas::make_cuDoubleComplex({ node .alpha } , 0)"
384+ if node .beta == 0 :
385+ beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
386+ else :
387+ beta = f"dace::blas::make_cuDoubleComplex({ node .beta } , 0)"
336388 prefix = "z"
337389 else :
338390 raise ValueError ("Unsupported type for BLAS dot product: " + str (dtype ))
@@ -393,18 +445,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
393445 from dace .codegen .common import sym2cpp
394446
395447 prefix = to_blastype (dtype .type ).lower ()
448+ # Use node's alpha and beta values
396449 if dtype == dace .float32 :
397- alpha = "1.0f "
398- beta = "0.0f "
450+ alpha = f" { float ( node . alpha ) } f "
451+ beta = f" { float ( node . beta ) } f "
399452 elif dtype == dace .float64 :
400- alpha = "1.0 "
401- beta = "0.0 "
453+ alpha = f" { float ( node . alpha ) } "
454+ beta = f" { float ( node . beta ) } "
402455 elif dtype == dace .complex64 :
403- alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
404- beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
456+ if node .alpha == 1 :
457+ alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
458+ else :
459+ alpha = f"dace::blas::make_cuComplex({ node .alpha } , 0)"
460+ if node .beta == 0 :
461+ beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
462+ else :
463+ beta = f"dace::blas::make_cuComplex({ node .beta } , 0)"
405464 elif dtype == dace .complex128 :
406- alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
407- beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
465+ if node .alpha == 1 :
466+ alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
467+ else :
468+ alpha = f"dace::blas::make_cuDoubleComplex({ node .alpha } , 0)"
469+ if node .beta == 0 :
470+ beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
471+ else :
472+ beta = f"dace::blas::make_cuDoubleComplex({ node .beta } , 0)"
408473 else :
409474 raise ValueError ("Unsupported type for BLAS: " + str (dtype ))
410475
@@ -514,18 +579,31 @@ def expansion(node, state, sdfg):
514579 astrides , bstrides , dtype , is_a_1d , is_b_1d )
515580
516581 func = to_blastype (dtype .type ).lower () + 'gemm'
582+ # Use node's alpha and beta values
517583 if dtype == dace .float32 :
518- alpha = "1.0f "
519- beta = "0.0f "
584+ alpha = f" { float ( node . alpha ) } f "
585+ beta = f" { float ( node . beta ) } f "
520586 elif dtype == dace .float64 :
521- alpha = "1.0 "
522- beta = "0.0 "
587+ alpha = f" { float ( node . alpha ) } "
588+ beta = f" { float ( node . beta ) } "
523589 elif dtype == dace .complex64 :
524- alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
525- beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
590+ if node .alpha == 1 :
591+ alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
592+ else :
593+ alpha = f"dace::blas::make_cuComplex({ node .alpha } , 0)"
594+ if node .beta == 0 :
595+ beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
596+ else :
597+ beta = f"dace::blas::make_cuComplex({ node .beta } , 0)"
526598 elif dtype == dace .complex128 :
527- alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
528- beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
599+ if node .alpha == 1 :
600+ alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
601+ else :
602+ alpha = f"dace::blas::make_cuDoubleComplex({ node .alpha } , 0)"
603+ if node .beta == 0 :
604+ beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
605+ else :
606+ beta = f"dace::blas::make_cuDoubleComplex({ node .beta } , 0)"
529607 else :
530608 raise ValueError ("Unsupported type for BLAS dot product: " + str (dtype ))
531609 opt = _get_codegen_gemm_opts (node , state , sdfg , adesc , bdesc , cdesc , alpha , beta , cdesc .dtype .ctype , func )
@@ -612,26 +690,43 @@ def expansion(node, state, sdfg):
612690 1.0 : f"__state->cublas_handle.Constants(__dace_cuda_device).{ factort } Pone()" ,
613691 0.0 : f"__state->cublas_handle.Constants(__dace_cuda_device).{ factort } Zero()" ,
614692 }
693+
694+ # Handle alpha
615695 if node .alpha not in constants :
616696 # Deal with complex input constants
617697 if isinstance (node .alpha , complex ):
618- alpha = f'{ dtype .ctype } ({ node .alpha .real } , { node .alpha .imag } )'
698+ alpha_val = f'{ dtype .ctype } ({ node .alpha .real } , { node .alpha .imag } )'
619699 else :
620- alpha = f'{ dtype .ctype } ({ node .alpha } )'
700+ alpha_val = f'{ dtype .ctype } ({ node .alpha } )'
701+ use_host_mode_alpha = True
702+ else :
703+ alpha = constants [node .alpha ]
704+ use_host_mode_alpha = False
621705
622- # Set pointer mode to host
623- call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
624- { dtype .ctype } alpha = { alpha } ;
625- { dtype .ctype } beta = 0;
626- '''
706+ # Handle beta
707+ if node .beta not in constants :
708+ # Deal with complex input constants
709+ if isinstance (node .beta , complex ):
710+ beta_val = f'{ dtype .ctype } ({ node .beta .real } , { node .beta .imag } )'
711+ else :
712+ beta_val = f'{ dtype .ctype } ({ node .beta } )'
713+ use_host_mode_beta = True
714+ else :
715+ beta = constants [node .beta ]
716+ use_host_mode_beta = False
717+
718+ # Set pointer mode to host if needed
719+ if use_host_mode_alpha or use_host_mode_beta :
720+ call_prefix += 'cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);\n '
721+ if use_host_mode_alpha :
722+ call_prefix += f' { dtype .ctype } alpha = { alpha_val } ;\n '
723+ alpha = f'({ cdtype } *)&alpha'
724+ if use_host_mode_beta :
725+ call_prefix += f' { dtype .ctype } beta = { beta_val } ;\n '
726+ beta = f'({ cdtype } *)&beta'
627727 call_suffix += '''
628728 cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
629729 '''
630- beta = f'({ cdtype } *)&beta'
631- alpha = f'({ cdtype } *)&alpha'
632- else :
633- alpha = constants [node .alpha ]
634- beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort
635730
636731 # Set up options for code formatting
637732 opt = _get_codegen_gemm_opts (node , state , sdfg , adesc , bdesc , cdesc , alpha , beta , cdtype , func )
0 commit comments