@affifboudaoud commented Oct 21, 2025

This PR adds additional tests and special-case fixes for the Batched Matrix Multiplication implementation in DaCe, related to PR #2180. The overall goal is full NumPy/PyTorch matmul broadcasting support, including multi-dimensional batch broadcasting, 1D vector operations, and accumulation modes.

Changes

Broadcasting Support

  • 3D-4D broadcasting (5ffe768): Handle operands with differing numbers of batch dimensions (e.g., [b1, b2, m, k] @ [b2, k, n] = [b1, b2, m, n])
  • 1D vector broadcasting (3a88e77): Support 1D inputs as row/column vectors (e.g., [k] @ [b, k, n], [b, m, k] @ [k])
  • Added batch-dimension alignment and indexing logic for partial broadcasting (a shape sketch follows this list)
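
For reference, the NumPy matmul broadcasting semantics these changes target can be illustrated with a small shape sketch (the concrete sizes below are arbitrary examples, not taken from the PR's tests):

```python
import numpy as np

# 4D @ 3D: batch dims [b1, b2] and [b2] broadcast to [b1, b2]
A = np.ones((2, 3, 4, 5))   # [b1, b2, m, k]
B = np.ones((3, 5, 6))      # [b2, k, n]
assert np.matmul(A, B).shape == (2, 3, 4, 6)   # [b1, b2, m, n]

# 1D left operand: treated as a row vector, the prepended dim is dropped
v = np.ones(5)              # [k]
C = np.ones((3, 5, 6))      # [b, k, n]
assert np.matmul(v, C).shape == (3, 6)         # [b, n]

# 1D right operand: treated as a column vector, the appended dim is dropped
D = np.ones((3, 4, 5))      # [b, m, k]
w = np.ones(5)              # [k]
assert np.matmul(D, w).shape == (3, 4)         # [b, m]
```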

Accumulation Support

  • Beta parameter (bf97502): Fixed beta scaling for C = alpha*A@B + beta*C in batched matmul, GEMM, and GEMV
  • Handles beta=0 (zero-initialize C), beta=1 (accumulate into C), and all other values (scale C, then accumulate); a reference sketch follows this list
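
A minimal NumPy sketch of the intended accumulation semantics (alpha/beta follow the BLAS-style convention named above; the helper name is illustrative):

```python
import numpy as np

def matmul_accumulate(A, B, C, alpha=1.0, beta=0.0):
    """Reference semantics for C = alpha * (A @ B) + beta * C."""
    if beta == 0.0:
        # beta == 0: existing contents of C are ignored (zero-initialization)
        C[...] = alpha * np.matmul(A, B)
    elif beta == 1.0:
        # beta == 1: accumulate the scaled product into C
        C += alpha * np.matmul(A, B)
    else:
        # any other beta: scale C first, then accumulate
        C *= beta
        C += alpha * np.matmul(A, B)
    return C
```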

Test Files

  • batched_matmul_test.py: Tests for all broadcasting cases (an illustrative test sketch follows this list)
    • 3D standard batched matmul
    • 3D-3D broadcasting (different batch sizes)
    • 3D-4D, 4D-5D broadcasting
    • 1D-3D, 1D-4D vector broadcasting
  • test_matmul_accumulate.py: Tests for the beta parameter
    • Batched matmul accumulation
    • GEMM accumulation
    • GEMV accumulation
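
For context, here is a sketch of the kind of check these tests perform against the DaCe numpy frontend; the program, shapes, and test name are illustrative assumptions, not the PR's actual test code:

```python
import dace
import numpy as np

@dace.program
def bmm_broadcast(A: dace.float64[2, 3, 4, 5], B: dace.float64[3, 5, 6]):
    # 4D @ 3D batched matmul, broadcasting over the leading batch dimension
    return A @ B

def test_4d_3d_broadcast():
    rng = np.random.default_rng(0)
    A = rng.random((2, 3, 4, 5))
    B = rng.random((3, 5, 6))
    assert np.allclose(bmm_broadcast(A, B), np.matmul(A, B))
```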

@affifboudaoud marked this pull request as ready for review on October 29, 2025, 17:13
@phschaad requested a review from tbennun on October 30, 2025, 09:10