9 changes: 4 additions & 5 deletions dace/frontend/python/replacements/linalg.py
@@ -28,9 +28,6 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

if len(arr1.shape) > 1 and len(arr2.shape) > 1: # matrix * matrix

if len(arr1.shape) > 3 or len(arr2.shape) > 3:
raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 not supported')

res = symbolic.equal(arr1.shape[-1], arr2.shape[-2])
if res is None:
warnings.warn(
@@ -41,10 +38,12 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

# Determine batched multiplication
# Determine batched multiplication (supports N-D tensors)
bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides, None, None)
if bopt:
output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
# Multi-dimensional batch: use batch_dims if available, otherwise use flattened batch size
batch_dims = bopt.get('batch_dims', [bopt['b']])
output_shape = tuple(batch_dims) + (arr1.shape[-2], arr2.shape[-1])
else:
output_shape = (arr1.shape[-2], arr2.shape[-1])
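Note: a minimal sketch of the output-shape rule above, assuming `_get_batchmm_opts` returns a `batch_dims` entry as introduced in this PR; `matmul_output_shape` is a hypothetical helper for illustration, not DaCe API.

```python
# Hypothetical helper mirroring the frontend logic above (not part of DaCe).
def matmul_output_shape(a_shape, b_shape, bopt):
    if bopt:
        # Fall back to the flattened batch size if 'batch_dims' is absent.
        batch_dims = bopt.get('batch_dims', [bopt['b']])
        return tuple(batch_dims) + (a_shape[-2], b_shape[-1])
    return (a_shape[-2], b_shape[-1])

# (2, 3, 4, 5) @ (2, 3, 5, 6) -> (2, 3, 4, 6)
print(matmul_output_shape((2, 3, 4, 5), (2, 3, 5, 6), {'b': 6, 'batch_dims': [2, 3]}))
```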

81 changes: 63 additions & 18 deletions dace/libraries/blas/nodes/batched_matmul.py
@@ -32,8 +32,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
UserWarning)
elif not res:
raise SyntaxError("Matrix sizes must match")

# Determine output shape based on batch options
if bopt:
shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
# Use batch dimensions from bopt (may be multi-dimensional)
batch_dims = bopt.get('batch_dims', [bopt['b']])
shape_c = tuple(batch_dims) + (shape_a[-2], shape_b[-1])
else:
shape_c = (shape_a[-2], shape_b[-1])

@@ -64,16 +68,46 @@ def make_sdfg(node, parent_state, parent_sdfg):

state = sdfg.add_state_after(init_state, node.label + "_state")

state.add_mapped_tasklet(
'_BatchedBatchedMatMult_', {
'__i%d' % i: '0:%s' % s
for i, s in enumerate([bopt['b'], array_a.shape[-2], array_b.shape[-1], array_a.shape[-1]])
}, {
'__a': dace.Memlet.simple("_a", ('__i1, __i3' if len(array_a.shape) == 2 else '__i0, __i1, __i3')),
'__b': dace.Memlet.simple("_b", ('__i3, __i2' if len(array_b.shape) == 2 else '__i0, __i3, __i2'))
},
'__c = __a * __b', {'__c': dace.Memlet.simple("_c", '__i0, __i1, __i2', wcr_str='lambda x, y: x + y')},
external_edges=True)
# Calculate number of batch dimensions in output
num_batch_dims = len(shape_c) - 2

# Build map parameters: batch dimensions + M, N, K
map_params = {}
for i in range(num_batch_dims):
map_params['__i%d' % i] = '0:%s' % symstr(shape_c[i])

# M, N, K dimensions
map_params['__im'] = '0:%s' % symstr(shape_a[-2])
map_params['__in'] = '0:%s' % symstr(shape_b[-1])
map_params['__ik'] = '0:%s' % symstr(shape_a[-1])

# Build memlet access patterns
# For A: if 2D, use [M, K]; if 3D+, use [batch_indices..., M, K]
if len(array_a.shape) == 2:
memlet_a = '__im, __ik'
else:
# Align A's batch indices with the trailing output batch dimensions
# (broadcasting aligns batch shapes from the right)
a_offset = num_batch_dims - (len(array_a.shape) - 2)
a_batch_indices = ', '.join('__i%d' % (a_offset + i) for i in range(len(array_a.shape) - 2))
memlet_a = f'{a_batch_indices}, __im, __ik'

# For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N]
if len(array_b.shape) == 2:
memlet_b = '__ik, __in'
else:
# Align B's batch indices with the trailing output batch dimensions
b_offset = num_batch_dims - (len(array_b.shape) - 2)
b_batch_indices = ', '.join('__i%d' % (b_offset + i) for i in range(len(array_b.shape) - 2))
memlet_b = f'{b_batch_indices}, __ik, __in'

# For C: always has batch dimensions
c_indices = ', '.join(['__i%d' % i for i in range(num_batch_dims)]) + ', __im, __in'

state.add_mapped_tasklet('_BatchedMatMult_',
map_params, {
'__a': dace.Memlet.simple("_a", memlet_a),
'__b': dace.Memlet.simple("_b", memlet_b)
},
'__c = __a * __b',
{'__c': dace.Memlet.simple("_c", c_indices, wcr_str='lambda x, y: x + y')},
external_edges=True)

return sdfg
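Note: a self-contained sketch of the map-parameter and memlet-string construction above for an assumed broadcast case (4-D A, 2-D B); it mirrors, rather than calls, the library code.

```python
# Assumed shapes: A = [2, 3, 4, 5], B = [5, 6], C = [2, 3, 4, 6].
shape_a, shape_b, shape_c = (2, 3, 4, 5), (5, 6), (2, 3, 4, 6)
num_batch_dims = len(shape_c) - 2  # two batch dimensions

# Map ranges: output batch dims first, then M, N, K.
map_params = {'__i%d' % i: '0:%d' % shape_c[i] for i in range(num_batch_dims)}
map_params.update({'__im': '0:%d' % shape_a[-2],
                   '__in': '0:%d' % shape_b[-1],
                   '__ik': '0:%d' % shape_a[-1]})

# A carries both batch indices; 2-D B gets none.
memlet_a = ', '.join('__i%d' % i for i in range(len(shape_a) - 2)) + ', __im, __ik'
memlet_b = '__ik, __in'
memlet_c = ', '.join('__i%d' % i for i in range(num_batch_dims)) + ', __im, __in'

print(map_params)   # {'__i0': '0:2', '__i1': '0:3', '__im': '0:4', '__in': '0:6', '__ik': '0:5'}
print(memlet_a)     # __i0, __i1, __im, __ik
print(memlet_b)     # __ik, __in
print(memlet_c)     # __i0, __i1, __im, __in
```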

@@ -441,20 +475,31 @@ def validate(self, sdfg, state):
raise ValueError("Expected exactly one output from "
"batched matrix-matrix product")
out_memlet = out_edges[0].data
# Function is symmetric, edge order does not matter
if len(size0) not in [2, 3]:
raise ValueError("Batched matrix-matrix product only supported on matrices")
if len(size1) != 3:
raise ValueError("Batched matrix-matrix product only supported on matrices")

# Both inputs must be at least 2D
if len(size0) < 2:
raise ValueError(f"First input must be at least 2D, got shape with {len(size0)} dimensions")
if len(size1) < 2:
raise ValueError(f"Second input must be at least 2D, got shape with {len(size1)} dimensions")

# At least one input must have batch dimensions (3D or higher) for batched operation
if len(size0) <= 2 and len(size1) <= 2:
raise ValueError(
"Batched matrix-matrix product requires at least one input to have batch dimensions (3D or higher)")

# Validate K-dimension compatibility
res = equal(size0[-1], size1[-2])
if res is None:
warnings.warn(
f'First tensor\'s last mode {size0[-1]} and second tensor\'s second-last mode {size1[-2]} '
f'may not match', UserWarning)
elif not res:
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
if len(out_memlet.subset) != 3:
raise ValueError("batched matrix-matrix product only supported on matrices")

# Output must have batch dimensions
if len(out_memlet.subset) < 3:
raise ValueError(
f"Batched matrix-matrix product output must be at least 3D, got {len(out_memlet.subset)} dimensions")


# Numpy replacement
115 changes: 83 additions & 32 deletions dace/libraries/blas/nodes/matmul.py
@@ -1,9 +1,10 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace import properties, symbolic
from copy import deepcopy as dc
from typing import Any, Dict
import warnings
from math import prod


def _get_matmul_operands(node, state, sdfg, name_lhs="_a", name_rhs="_b", name_out="_c"):
@@ -46,43 +47,88 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides
and returns its parameters (strides, batch size), or an empty dictionary if
batched multiplication is not detected.

:param a: Data descriptor for the first tensor.
:param b: Data descriptor for the second tensor.
:param c: Data descriptor for the output tensor (optional).
Supports N-dimensional tensors where all leading dimensions (except last 2) are treated
as batch dimensions. For example:
- [b, m, k] @ [b, k, n] -> [b, m, n]
- [b1, b2, m, k] @ [b1, b2, k, n] -> [b1, b2, m, n] (batch flattened to b1*b2 internally)
- [b, m, k] @ [k, n] -> [b, m, n]
- [m, k] @ [b, k, n] -> [b, m, n]

:param a_shape: Shape of the first tensor.
:param a_strides: Strides of the first tensor.
:param b_shape: Shape of the second tensor.
:param b_strides: Strides of the second tensor.
:param c_shape: Shape of the output tensor (optional).
:param c_strides: Strides of the output tensor (optional).
:return: A dictionary with the following keys: sa,sb,sc (strides for a, b,
and c); and b (batch size).
and c); b (flattened batch size); and batch_dims (output batch shape). Empty dict if not batched.
"""
if len(a_shape) > 3 or len(b_shape) > 3 or (c_shape and len(c_shape) > 3):
raise ValueError('Tensor dimensions too large for (batched) matrix '
'multiplication')
# Both inputs must be at least 2D, and at least one must have batch dimensions
if len(a_shape) <= 2 and len(b_shape) <= 2:
return {}

batch = None
stride_a, stride_b, stride_c = 0, 0, 0
if len(a_shape) == 3:
batch = a_shape[0]
stride_a = a_strides[0]
if len(b_shape) == 3:
if batch is not None:
res = symbolic.equal(batch, b_shape[0])
# Calculate batch dimensions (all dimensions except last 2)
a_batch_dims = a_shape[:-2] if len(a_shape) > 2 else ()
b_batch_dims = b_shape[:-2] if len(b_shape) > 2 else ()
c_batch_dims = c_shape[:-2] if (c_shape and len(c_shape) > 2) else ()

# Determine the output batch shape using broadcasting rules
# Start with the longer batch shape and validate compatibility
if len(a_batch_dims) >= len(b_batch_dims):
result_batch_dims = list(a_batch_dims)
shorter_dims = b_batch_dims
longer_dims = a_batch_dims
else:
result_batch_dims = list(b_batch_dims)
shorter_dims = a_batch_dims
longer_dims = b_batch_dims

# Validate broadcasting compatibility for batch dimensions
if shorter_dims:
offset = len(longer_dims) - len(shorter_dims)
for i, (s_dim, l_dim) in enumerate(zip(shorter_dims, longer_dims[offset:])):
res = symbolic.equal(s_dim, l_dim)
if res is False and s_dim != 1 and l_dim != 1:
raise ValueError(f'Batch dimension mismatch: {s_dim} vs {l_dim} at position {i}')
if res is None:
warnings.warn(f'Batch size of first tensor ({batch}) may not match second tensor ({b_shape[0]})',
UserWarning)
elif not res:
raise ValueError('Batch size mismatch for matrix multiplication')
batch = b_shape[0]
stride_b = b_strides[0]
if c_shape and len(c_shape) == 3:
if batch and batch != c_shape[0]:
raise ValueError('Batch size mismatch for matrix multiplication')
batch = c_shape[0]
stride_c = c_strides[0]

if batch is None:
warnings.warn(f'Batch dimension {s_dim} may not match {l_dim} at position {i}', UserWarning)
# Use the non-1 dimension for broadcasting
if s_dim == 1 and l_dim != 1:
result_batch_dims[offset + i] = l_dim
elif l_dim == 1 and s_dim != 1:
result_batch_dims[offset + i] = s_dim

# Calculate total flattened batch size
batch_size = prod(result_batch_dims) if result_batch_dims else 1

# Calculate strides for batched operations
# For a tensor with shape [B1, B2, ..., M, K], the stride for batched operations
# should be M*K (the size of each matrix) to iterate through all matrices in the flattened batch
stride_a = 0
stride_b = 0
stride_c = 0

if len(a_shape) > 2:
# Stride for accessing each matrix: product of last two dimensions (M x K)
stride_a = prod(a_shape[-2:])

if len(b_shape) > 2:
# Stride for accessing each matrix: product of last two dimensions (K x N)
stride_b = prod(b_shape[-2:])

if c_shape and len(c_shape) > 2:
# Stride for accessing each matrix: product of last two dimensions (M x N)
stride_c = prod(c_shape[-2:])
# Validate output batch dimensions
for i, (c_dim, r_dim) in enumerate(zip(c_batch_dims, result_batch_dims)):
res = symbolic.equal(c_dim, r_dim)
if res is False:
raise ValueError(f'Output batch dimension mismatch: {c_dim} vs {r_dim} at position {i}')

if batch_size == 1 and not result_batch_dims:
return {}

return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch}
return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch_size, 'batch_dims': result_batch_dims}
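Note: an illustrative call, assuming the signature shown above with concrete row-major shapes and strides; the expected result follows from the stride formulas in this hunk.

```python
from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

opts = _get_batchmm_opts((2, 3, 4, 5), (60, 20, 5, 1),  # A: [b1, b2, m, k]
                         (5, 6), (6, 1),                # B: [k, n], broadcast over batch
                         None, None)                    # C not yet known
# Expected: {'sa': 20, 'sb': 0, 'sc': 0, 'b': 6, 'batch_dims': [2, 3]}
print(opts)
```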


def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]:
@@ -151,6 +197,10 @@ def expansion(node, state, sdfg):
size_a = a[4]
size_b = b[4]
size_c = c[4]

# Check if this is a batched operation (at least one input has 3+ dimensions)
is_batched = len(size_a) >= 3 or len(size_b) >= 3

if len(size_c) == 2 and ((len(size_a) == 2 and len(size_b) == 2) or (len(a[2]) == 2 and len(b[2]) == 2)):
# Matrix and matrix -> GEMM
from dace.libraries.blas.nodes.gemm import Gemm
@@ -169,8 +219,9 @@
"library node: {}".format(c[0].data.wcr))
gemm = Gemm(node.name + 'gemm', location=node.location, alpha=node.alpha, beta=beta, cin=cin)
return gemm
elif len(size_b) == 3 and (len(size_a) in [2, 3]):
# Batched matrix and matrix -> batched matrix multiplication
elif is_batched and len(size_a) >= 2 and len(size_b) >= 2:
# Batched matrix multiplication with broadcasting support
# Handles: [b, m, k] @ [b, k, n], [b, m, k] @ [k, n], [m, k] @ [b, k, n], [b1, b2, m, k] @ [b1, b2, k, n], etc.
from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul
result = BatchedMatMul(node.name + 'bmm', location=node.location)
elif len(size_a) == 2 and len(size_b) == 1:
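Note: a rough sketch of the rank-based dispatch above (simplified; the real check also inspects memlet subset sizes, and the vector cases are elided in this diff).

```python
def choose_node(rank_a, rank_b, rank_c):
    if rank_c == 2 and rank_a == 2 and rank_b == 2:
        return 'Gemm'
    if (rank_a >= 3 or rank_b >= 3) and rank_a >= 2 and rank_b >= 2:
        return 'BatchedMatMul'
    return 'vector/other (not shown in this diff)'

assert choose_node(2, 2, 2) == 'Gemm'
assert choose_node(4, 2, 4) == 'BatchedMatMul'  # broadcast B over batch dims
```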