
Conversation

@affifboudaoud

Extended Batched Matrix Multiplication Support

Summary

Extend batched matrix multiplication to support N-dimensional tensors with NumPy-style broadcasting across all implementations (Pure, MKL, OpenBLAS, cuBLAS).

Changes

  • N-D tensor support: Handles tensors with arbitrary batch dimensions (e.g., [12, 2, 64, 64] @ [12, 2, 64, 128])
  • Broadcasting: Supports broadcasting patterns like [b, m, k] @ [k, n] and [m, k] @ [b, k, n]
  • Batch flattening: Multi-dimensional batches are flattened internally for efficient BLAS operations
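The flattening strategy described above can be sketched in plain NumPy. This is an illustrative model of the approach, not the DaCe implementation: the function name `batched_matmul_broadcast` and the explicit Python loop (a stand-in for a strided-batched BLAS call) are assumptions for the sketch.

```python
import numpy as np

def batched_matmul_broadcast(a, b):
    # The last two dimensions are the matrix dims; everything before
    # them is treated as batch dimensions.
    a_batch, b_batch = a.shape[:-2], b.shape[:-2]
    batch = np.broadcast_shapes(a_batch, b_batch)  # NumPy-style broadcasting
    m, k = a.shape[-2:]
    k2, n = b.shape[-2:]
    assert k == k2, "inner dimensions must match"
    # Broadcast each operand to the full batch shape, then flatten the
    # batch dimensions so a single loop of 2D GEMMs (standing in for a
    # strided-batched BLAS call) can process them.
    a_full = np.broadcast_to(a, batch + (m, k)).reshape(-1, m, k)
    b_full = np.broadcast_to(b, batch + (k, n)).reshape(-1, k, n)
    out = np.empty((a_full.shape[0], m, n), dtype=np.result_type(a, b))
    for i in range(a_full.shape[0]):
        out[i] = a_full[i] @ b_full[i]
    return out.reshape(batch + (m, n))
```

For example, multiplying a `[12, 2, 64, 8]` tensor by an `[8, 16]` matrix broadcasts the second operand across all 24 batch entries and yields a `[12, 2, 64, 16]` result.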

New Capabilities

# 3D broadcasting
[b, m, k] @ [k, n] → [b, m, n]
[m, k] @ [b, k, n] → [b, m, n]

# 4D batched matmul
[12, 2, 64, 64] @ [12, 2, 64, 128] → [12, 2, 64, 128]
[12, 2, m, k] @ [k, n] → [12, 2, m, n]  # with broadcasting
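These shape rules mirror NumPy's `np.matmul` broadcasting semantics, so they can be checked directly against NumPy (smaller sizes used here for speed; the concrete sizes are arbitrary):

```python
import numpy as np

# 3D broadcasting: [b, m, k] @ [k, n] -> [b, m, n]
a = np.ones((4, 5, 6))
w = np.ones((6, 7))
assert np.matmul(a, w).shape == (4, 5, 7)

# [m, k] @ [b, k, n] -> [b, m, n]
x = np.ones((5, 6))
y = np.ones((4, 6, 7))
assert np.matmul(x, y).shape == (4, 5, 7)

# 4D batched matmul: [12, 2, 64, 64] @ [12, 2, 64, 128] -> [12, 2, 64, 128]
p = np.ones((12, 2, 64, 64))
q = np.ones((12, 2, 64, 128))
assert np.matmul(p, q).shape == (12, 2, 64, 128)
```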

Files Modified

  • dace/libraries/blas/nodes/matmul.py: Extended _get_batchmm_opts() for N-D tensors and broadcasting
  • dace/libraries/blas/nodes/batched_matmul.py: Updated validation and Pure expansion for dynamic dimensions
  • dace/frontend/python/replacements/linalg.py: Removed 3D tensor check
  • tests/library/batched_matmul_test.py: Added tests for newly supported 3D/4D batched matmuls with broadcasting

@affifboudaoud affifboudaoud marked this pull request as ready for review October 17, 2025 09:23

@phschaad phschaad left a comment


Nice addition, looks good to me

@phschaad phschaad added this pull request to the merge queue Oct 18, 2025
Merged via the queue into main with commit d32f51d Oct 18, 2025
10 checks passed
@phschaad phschaad deleted the batched_matmul_improvements branch October 18, 2025 09:11
sophieblock pushed a commit to sophieblock/dace that referenced this pull request Oct 20, 2025
