Skip to content

Commit 2772a69

Browse files
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
stack-info: PR: #2848, branch: danielvegamyhre/stack/55
1 parent b663faf commit 2772a69

File tree

6 files changed

+272
-54
lines changed

6 files changed

+272
-54
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,27 @@ def compute_reference_forward(
230230
@pytest.mark.parametrize("num_experts", (1, 8, 16))
231231
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
232232
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
233-
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
233+
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
234234
offs = generate_jagged_offs(num_experts, M)
235-
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
235+
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
236236

237237
# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
238238
block_size = 32
239-
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
239+
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
240240

241241
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
242-
w_scale, w_mx = to_mx(
243-
w_t.transpose(-2, -1).contiguous(),
242+
w_scale, w_fp8 = to_mx(
243+
w,
244244
elem_dtype=torch.float8_e4m3fn,
245245
block_size=block_size,
246246
)
247-
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
247+
w_t_scale, w_t_fp8 = w_scale.transpose(-2, -1), w_fp8.transpose(-2, -1)
248248

249-
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
249+
ref_out = torch._grouped_mm(
250+
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
251+
)
250252
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
251-
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
253+
x_fp8, x_scale, w_t_fp8, w_t_scale, offs=offs, out_dtype=torch.bfloat16
252254
)
253255

254256
sqnr = compute_error(ref_out, out)
@@ -305,19 +307,26 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
305307

306308

307309
@skip_if_rocm("ROCm not supported")
308-
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
309-
@pytest.mark.parametrize("num_experts", (1, 8, 16))
310+
@pytest.mark.parametrize("M,K,N", [(256, 512, 512)])
311+
@pytest.mark.parametrize("num_experts", (2,))
310312
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
311313
from torchao.prototype.moe_training.scaled_grouped_mm import (
312314
_MXFP8GroupedMM,
313315
)
314316

315317
block_size = 32
316318
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
317-
w_t = torch.randn(
318-
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
319+
w = torch.randn(
320+
num_experts,
321+
N,
322+
K,
323+
dtype=torch.bfloat16,
324+
device="cuda",
319325
)
320-
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
326+
w_t = w.transpose(-2, -1).requires_grad_(True)
327+
# offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
328+
group_size = M // num_experts
329+
offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)
321330
x_ref, w_t_ref, offs_ref = (
322331
x.clone().detach().requires_grad_(True),
323332
w_t.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
136136
["does.not.exist"],
137137
],
138138
)
139-
@pytest.mark.parametrize("compile", [False, True])
139+
@pytest.mark.parametrize("compile", [False])
140140
def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
141141
block_size = 32
142142

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10+
from torchao.prototype.moe_training.kernels.mxfp8 import (
11+
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
12+
)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import logging
2+
3+
import torch
4+
5+
from torchao.prototype.mx_formats.utils import (
6+
to_blocked_per_group_2d,
7+
to_blocked_per_group_3d,
8+
)
9+
10+
logger: logging.Logger = logging.getLogger(__name__)
11+
12+
try:
13+
import fbgemm_gpu.experimental.gen_ai # noqa: F401
14+
except Exception as e:
15+
logging.warning(
16+
f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}"
17+
"Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
18+
"pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129"
19+
"If errors persist, please file a bug report."
20+
)
21+
22+
23+
@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
24+
def fbgemm_mxfp8_grouped_mm_2d_3d(
25+
A_mx: torch.Tensor,
26+
A_scale: torch.Tensor,
27+
B_t_mx_dim1: torch.Tensor,
28+
B_t_scales_dim1: torch.Tensor,
29+
offs: torch.Tensor,
30+
block_size: int = 32,
31+
out_dtype: torch.dtype = torch.bfloat16,
32+
) -> torch.Tensor:
33+
assert A_mx.ndim == 2, "A_mx tensor must be 2D"
34+
assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
35+
assert block_size == 32, "Only block_size=32 is supported"
36+
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
37+
38+
# Convert scales for each group to blocked format.
39+
Mg, K = A_mx.shape
40+
A_scale_blocked, starting_row_after_padding = to_blocked_per_group_2d(
41+
A_scale, offs, Mg, K
42+
)
43+
B_t_scales_dim1_blocked = to_blocked_per_group_3d(B_t_scales_dim1)
44+
45+
# From this, we compute `group_sizes` and `starting_row_after_padding`:
46+
# group_sizes = [32, 32, 64]
47+
# starting_row_after_padding = [0, 32, 64, 128]
48+
group_sizes = torch.diff(starting_row_after_padding).to(torch.int64)
49+
50+
# TODO: remove debug logging once prototype is more mature.
51+
logger.debug("A_mx.shape", A_mx.shape, "stride", A_mx.stride(), "dtype", A_mx.dtype)
52+
logger.debug(
53+
"B_t_mx_dim1.shape",
54+
B_t_mx_dim1.shape,
55+
"stride",
56+
B_t_mx_dim1.stride(),
57+
"dtype",
58+
B_t_mx_dim1.dtype,
59+
)
60+
logger.debug(
61+
"A_scales_blocked.shape",
62+
A_scale_blocked.shape,
63+
"stride",
64+
A_scale_blocked.stride(),
65+
"dtype",
66+
A_scale_blocked.dtype,
67+
)
68+
logger.debug(
69+
"B_t_scales_dim (non-blocked)",
70+
B_t_scales_dim1.shape,
71+
"stride",
72+
B_t_scales_dim1.stride(),
73+
"dtype",
74+
B_t_scales_dim1.dtype,
75+
)
76+
logger.debug(
77+
"B_t_scales_dim1_blocked.shape",
78+
B_t_scales_dim1_blocked.shape,
79+
"stride",
80+
B_t_scales_dim1_blocked.stride(),
81+
"dtype",
82+
B_t_scales_dim1_blocked.dtype,
83+
)
84+
logger.debug(
85+
"group_sizes",
86+
group_sizes,
87+
"group_sizes.stride",
88+
group_sizes.stride(),
89+
"dtype",
90+
group_sizes.dtype,
91+
)
92+
logger.debug(
93+
"starting_row_after_padding",
94+
starting_row_after_padding,
95+
"stride",
96+
starting_row_after_padding.stride(),
97+
"dtype",
98+
starting_row_after_padding.dtype,
99+
)
100+
101+
out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
102+
A_mx,
103+
B_t_mx_dim1,
104+
A_scale_blocked,
105+
B_t_scales_dim1_blocked,
106+
group_sizes,
107+
starting_row_after_padding=starting_row_after_padding,
108+
)
109+
return out
110+
111+
112+
@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
113+
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
114+
A_mx: torch.Tensor,
115+
B_t_mx_dim1: torch.Tensor,
116+
A_scale: torch.Tensor,
117+
B_t_scales_dim1: torch.Tensor,
118+
offs: torch.Tensor,
119+
) -> torch.Tensor:
120+
assert A_mx.ndim == 2, "A_mx tensor must be 2D"
121+
assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
122+
mg, k = A_mx.shape
123+
e, k, n = B_t_mx_dim1.shape
124+
n_groups = offs.numel()
125+
assert n_groups == e, (
126+
"Size of `offs` (number of groups) must match first dim of `B_t_mx_dim1`"
127+
)
128+
output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_mx.device)
129+
return output

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
1414
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
1515
from torchao.prototype.moe_training.kernels import (
16+
fbgemm_mxfp8_grouped_mm_2d_3d,
1617
triton_fp8_per_group_colwise_scales,
1718
triton_fp8_per_group_rowwise_scales,
1819
triton_fp8_rowwise_3d_transpose_rhs,
@@ -277,52 +278,46 @@ def forward(
277278
offs: Optional[torch.Tensor] = None,
278279
block_size: int = 32,
279280
out_dtype: Optional[torch.dtype] = torch.bfloat16,
280-
emulated: bool = True,
281+
emulated: bool = False,
281282
) -> torch.Tensor:
282283
# torchao _scaled_grouped_mm only supports A=2D and B=3D.
283284
assert A.ndim == 2, "A must be 2D"
284285
assert B_t.ndim == 3, "B must be 3D"
285286
assert block_size == 32, "Only block_size=32 is supported"
286-
assert emulated, "Only emulated mxfp8 grouped gemm is supported"
287+
288+
# Store what we need for backward.
289+
ctx.save_for_backward(A, B_t, offs)
290+
ctx.block_size = block_size
291+
ctx.out_dtype = out_dtype
292+
ctx.emulated = emulated
287293

288294
# Cast to mxpf8 across dim -1.
289295
# A_mx shape: (M, K)
290296
# A_scale shape: (M, K//block_size)
291297
A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
292298

293-
# Cast B_t per-expert to mxfp8 across dim1.
294-
# B_t_mx shape: (E, K, N)
295-
# B_t_scale shape: (E, K//block_size, N)
296-
297-
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
299+
# Cast B_t per-expert to mxfp8 across K dim.
298300
# B_mx shape: (E, N, K)
299301
# B_scale shape: (E, N, K//block_size)
300-
B_scales_dim2, B_mx_dim2 = to_mx(
301-
B_t.transpose(-2, -1), # (E,K,N) -> (E,N,K)
302+
B_scales, B_mx = to_mx(
303+
B_t.transpose(-2, -1).contiguous(),
302304
elem_dtype=torch.float8_e4m3fn,
303305
block_size=block_size,
304306
)
305307

306-
# B_t_mx shape: (E, K, N)
307-
# B_t_scale shape: (E, K//block_size, N)
308-
B_t_scales_dim1 = B_scales_dim2.transpose(
309-
-2, -1
310-
) # (E,N,K//block_size) -> (E,K//block_size,N)
311-
B_t_mx_dim1 = B_mx_dim2.transpose(-2, -1) # (E,N,K) -> (E,K,N)
312-
313-
# Store what we need for backward.
314-
ctx.save_for_backward(A, B_t, offs)
315-
ctx.block_size = block_size
316-
ctx.out_dtype = out_dtype
317-
318308
# Perform scaled grouped GEMM and return result.
319309
# output = input @ weight.T
320310
# output shape: (M, N)
321-
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
311+
mxfp8_2d_3d_grouped_mm = (
312+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
313+
if emulated
314+
else fbgemm_mxfp8_grouped_mm_2d_3d
315+
)
316+
out = mxfp8_2d_3d_grouped_mm(
322317
A_mx,
323318
A_scale,
324-
B_t_mx_dim1,
325-
B_t_scales_dim1,
319+
B_mx,
320+
B_scales,
326321
offs=offs,
327322
block_size=block_size,
328323
out_dtype=out_dtype,
@@ -334,6 +329,7 @@ def backward(ctx, grad_out: torch.Tensor):
334329
A, B_t, offs = ctx.saved_tensors
335330
block_size = ctx.block_size
336331
out_dtype = ctx.out_dtype
332+
emulated = ctx.emulated
337333

338334
# grad_out_mx shape: (M, N)
339335
# grad_out_scale shape: (M, N//block_size)
@@ -343,23 +339,24 @@ def backward(ctx, grad_out: torch.Tensor):
343339

344340
# B_mx shape: (E, K, N)
345341
# B_scale shape: (E, K, N//block_size)
346-
B_t_scale_dim2, B_t_mx_dim2 = to_mx(
342+
B_scales, B_mx = to_mx(
347343
B_t.contiguous(),
348344
elem_dtype=torch.float8_e4m3fn,
349345
block_size=block_size,
350346
)
351-
B_scale_dim1 = B_t_scale_dim2.transpose(
352-
-2, -1
353-
) # (E,K,N//block_size) -> (E,N//block_size,K)
354-
B_mx_dim1 = B_t_mx_dim2.transpose(-2, -1) # (E,K,N) -> (E,N,K)
355347

356348
# Compute grad_A.
357349
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
358-
grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
350+
mxfp8_2d_3d_grouped_mm = (
351+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
352+
if emulated
353+
else fbgemm_mxfp8_grouped_mm_2d_3d
354+
)
355+
grad_A = mxfp8_2d_3d_grouped_mm(
359356
grad_out_mx,
360357
grad_out_scale,
361-
B_mx_dim1,
362-
B_scale_dim1,
358+
B_mx,
359+
B_scales,
363360
offs=offs,
364361
out_dtype=out_dtype,
365362
)
@@ -385,7 +382,6 @@ def backward(ctx, grad_out: torch.Tensor):
385382
# Compute grad_B = grad_output_t @ A
386383
# grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
387384
# grad_B = grad_B_t.transpose(-2, -1) = (E,K,N)
388-
389385
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
390386
grad_out_t_mx,
391387
grad_out_t_scales,
@@ -402,12 +398,24 @@ def backward(ctx, grad_out: torch.Tensor):
402398
def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
403399
A_mx: torch.Tensor,
404400
A_scale: torch.Tensor,
405-
B_t_mx: torch.Tensor,
406-
B_t_scale: torch.Tensor,
401+
B_mx: torch.Tensor,
402+
B_scale: torch.Tensor,
407403
offs: Optional[torch.Tensor] = None,
408404
out_dtype: Optional[torch.dtype] = torch.bfloat16,
409405
block_size: int = 32,
410406
) -> torch.Tensor:
407+
assert A_mx.ndim == 2, "A must be 2D"
408+
assert B_mx.ndim == 3, "B must be 3D"
409+
assert A_scale.shape[0] == A_mx.shape[0], "A_scale must have same M dim as A_mx"
410+
assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
411+
"A_scale dim1 should be size K//block_size"
412+
)
413+
assert B_scale.shape[0] == B_mx.shape[0], "B_scale must have same E dim as B_mx"
414+
assert B_scale.shape[1] == B_mx.shape[1], "B_scale must have same N dim as B_mx"
415+
assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
416+
"B_scale dim2 should be size K//block_size"
417+
)
418+
411419
# Dequantize input
412420
# A_mx shape: (M, K)
413421
# A_scale shape: (M, K//block_size)
@@ -427,14 +435,10 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
427435
A = A.reshape(A_orig_shape)
428436

429437
# Dequantize weights
430-
# B_t_mx shape: (E, K, N)
431-
# B_t_scale shape: (E, K//block_size, N)
432-
E, K, N = B_t_mx.shape
433-
434438
# Tranpose to get block_size on rightmost dim
435439
# B_mx shape: (E, N, K)
436440
# B_scale shape: (E, N, K//block_size)
437-
B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)
441+
E, N, K = B_mx.shape
438442

439443
# Reshape to be able to do per-scaling group multiplication
440444
# B_mx shape: (E, N, K//block_size, block_size)

0 commit comments

Comments
 (0)