@@ -224,6 +224,111 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m,
224224 assert np .allclose (ref , z )
225225
226226
@pytest.mark.parametrize("implementation, dtype", [
    pytest.param("pure", dace.float32),
    pytest.param("pure", dace.float64),
    pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
    pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
    pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
    pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
    pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
    pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
])
def test_batchmm_3d_4d_broadcast(implementation: str, dtype):
    """Batched matmul broadcasting a 3D LHS over a 4D RHS: [b2, m, k] @ [b1, b2, k, n]."""
    b1, b2, m, n, k = 4, 2, 64, 128, 64

    @dace.program
    def bmm_3d_4d_broadcast(A: dtype[b2, m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m, n]):
        C[:] = A @ B

    # Lower the program with the requested BLAS backend as the default expansion.
    with change_default(blas, implementation):
        sdfg = bmm_3d_4d_broadcast.to_sdfg()
        sdfg.simplify()
        sdfg.expand_library_nodes()

    np_dtype = dtype.as_numpy_dtype()
    lhs = np.random.rand(b2, m, k).astype(np_dtype)
    rhs = np.random.rand(b1, b2, k, n).astype(np_dtype)
    out = np.zeros((b1, b2, m, n), dtype=np_dtype)

    compiled = sdfg.compile()
    compiled(A=lhs, B=rhs, C=out)

    # NumPy applies the same batch-dimension broadcasting; use it as the oracle.
    assert np.allclose(lhs @ rhs, out)
260+
261+
@pytest.mark.parametrize("implementation, dtype", [
    pytest.param("pure", dace.float32),
    pytest.param("pure", dace.float64),
    pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
    pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
    pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
    pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
    pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
    pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
])
def test_batchmm_4d_3d_broadcast(implementation: str, dtype):
    """Batched matmul broadcasting a 3D RHS over a 4D LHS: [b1, b2, m, k] @ [b2, k, n]."""
    b1, b2, m, n, k = 4, 2, 64, 128, 64

    @dace.program
    def bmm_4d_3d_broadcast(A: dtype[b1, b2, m, k], B: dtype[b2, k, n], C: dtype[b1, b2, m, n]):
        C[:] = A @ B

    # Lower the program with the requested BLAS backend as the default expansion.
    with change_default(blas, implementation):
        sdfg = bmm_4d_3d_broadcast.to_sdfg()
        sdfg.simplify()
        sdfg.expand_library_nodes()

    np_dtype = dtype.as_numpy_dtype()
    lhs = np.random.rand(b1, b2, m, k).astype(np_dtype)
    rhs = np.random.rand(b2, k, n).astype(np_dtype)
    out = np.zeros((b1, b2, m, n), dtype=np_dtype)

    compiled = sdfg.compile()
    compiled(A=lhs, B=rhs, C=out)

    # NumPy applies the same batch-dimension broadcasting; use it as the oracle.
    assert np.allclose(lhs @ rhs, out)
295+
296+
@pytest.mark.parametrize("implementation, dtype", [
    pytest.param("pure", dace.float32),
    pytest.param("pure", dace.float64),
    pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
    pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
    pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
    pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
    pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
    pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
])
def test_batchmm_5d_3d_broadcast(implementation: str, dtype):
    """Batched matmul broadcasting a 3D RHS over a 5D LHS: [b1, b2, b3, m, k] @ [b3, k, n]."""
    b1, b2, b3, m, n, k = 4, 2, 3, 64, 128, 64

    @dace.program
    def bmm_5d_3d_broadcast(A: dtype[b1, b2, b3, m, k], B: dtype[b3, k, n], C: dtype[b1, b2, b3, m, n]):
        C[:] = A @ B

    # Lower the program with the requested BLAS backend as the default expansion.
    with change_default(blas, implementation):
        sdfg = bmm_5d_3d_broadcast.to_sdfg()
        sdfg.simplify()
        sdfg.expand_library_nodes()

    np_dtype = dtype.as_numpy_dtype()
    lhs = np.random.rand(b1, b2, b3, m, k).astype(np_dtype)
    rhs = np.random.rand(b3, k, n).astype(np_dtype)
    out = np.zeros((b1, b2, b3, m, n), dtype=np_dtype)

    compiled = sdfg.compile()
    compiled(A=lhs, B=rhs, C=out)

    # NumPy applies the same batch-dimension broadcasting; use it as the oracle.
    assert np.allclose(lhs @ rhs, out)
330+
331+
227332if __name__ == "__main__" :
228333 test_batchmm ("pure" , dace .float32 )
229334 test_batchmm ("pure" , dace .float64 )
@@ -261,3 +366,21 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m,
261366 test_batchmm_4d_broadcast_lhs ("MKL" , dace .float64 )
262367 test_batchmm_4d_broadcast_lhs ("cuBLAS" , dace .float32 )
263368 test_batchmm_4d_broadcast_lhs ("cuBLAS" , dace .float64 )
369+ test_batchmm_3d_4d_broadcast ("pure" , dace .float32 )
370+ test_batchmm_3d_4d_broadcast ("pure" , dace .float64 )
371+ test_batchmm_3d_4d_broadcast ("MKL" , dace .float32 )
372+ test_batchmm_3d_4d_broadcast ("MKL" , dace .float64 )
373+ test_batchmm_3d_4d_broadcast ("cuBLAS" , dace .float32 )
374+ test_batchmm_3d_4d_broadcast ("cuBLAS" , dace .float64 )
375+ test_batchmm_4d_3d_broadcast ("pure" , dace .float32 )
376+ test_batchmm_4d_3d_broadcast ("pure" , dace .float64 )
377+ test_batchmm_4d_3d_broadcast ("MKL" , dace .float32 )
378+ test_batchmm_4d_3d_broadcast ("MKL" , dace .float64 )
379+ test_batchmm_4d_3d_broadcast ("cuBLAS" , dace .float32 )
380+ test_batchmm_4d_3d_broadcast ("cuBLAS" , dace .float64 )
381+ test_batchmm_5d_3d_broadcast ("pure" , dace .float32 )
382+ test_batchmm_5d_3d_broadcast ("pure" , dace .float64 )
383+ test_batchmm_5d_3d_broadcast ("MKL" , dace .float32 )
384+ test_batchmm_5d_3d_broadcast ("MKL" , dace .float64 )
385+ test_batchmm_5d_3d_broadcast ("cuBLAS" , dace .float32 )
386+ test_batchmm_5d_3d_broadcast ("cuBLAS" , dace .float64 )
0 commit comments