
Commit ee9f562

[moe fp8 training] use transpose method when quantizing to avoid uncoalesced gmem accesses
stack-info: PR: #2864, branch: danielvegamyhre/stack/58
1 parent f5f64e0 commit ee9f562
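
The core of this change: before quantizing, the 2D operands are rewritten into column-major storage with a .t().contiguous().t() round trip, so the quantization kernels read down columns with unit stride (coalesced global memory accesses). A minimal sketch of that layout trick, independent of the kernels in this PR (tensor shape is illustrative):

import torch

# Illustrative only: ".t().contiguous().t()" keeps the logical shape and values
# but rewrites the storage as column-major, so a kernel reading down a column
# touches consecutive addresses (coalesced) instead of strided ones.
x = torch.randn(16640, 5120)            # row-major: stride (5120, 1)
x_col_major = x.t().contiguous().t()    # same shape/values: stride (1, 16640)
assert torch.equal(x, x_col_major)
print(x.stride(), x_col_major.stride())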

File tree

6 files changed: +44 -44 lines changed

benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py

Lines changed: 10 additions & 7 deletions
@@ -31,6 +31,7 @@
 class ExperimentConfig:
     high_precision_dtype: torch.dtype
     input_shape: tuple[int]
+    power_of_2_scales: bool


 @dataclass(frozen=True)
@@ -48,7 +49,7 @@ class Experiment:


 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 shapes
+    # Llama4 shapes (E, N, K)
     input_shapes = [
         (1, 8192, 5120),  # w1, w3
         (1, 5120, 8192),  # w2
@@ -58,14 +59,16 @@ def get_configs() -> List[ExperimentConfig]:
         (128, 5120, 8192),  # w2
     ]
     high_precision_dtypes = [torch.bfloat16]
+    power_of_2_scales = [True, False]
     configs = []
-    for input_shape, high_precision_dtype in itertools.product(
-        input_shapes, high_precision_dtypes
+    for input_shape, high_precision_dtype, power_of_2_scale in itertools.product(
+        input_shapes, high_precision_dtypes, power_of_2_scales
     ):
         configs.append(
             ExperimentConfig(
                 input_shape=input_shape,
                 high_precision_dtype=high_precision_dtype,
+                power_of_2_scales=power_of_2_scale,
             )
         )
     return configs
@@ -87,18 +90,16 @@ def run_torch(input_tensor: torch.Tensor):
         out = torch_to_3d_rowwise_float8_transpose_rhs(
             input_tensor,
             target_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
+            round_scales_to_power_of_2=config.power_of_2_scales,
         )
-        torch.cuda.synchronize()
         return out

     def run_triton(input_tensor: torch.Tensor):
         out = triton_fp8_rowwise_3d_transpose_rhs(
             input_tensor,
             output_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
+            round_scales_to_power_of_2=config.power_of_2_scales,
         )
-        torch.cuda.synchronize()
         return out

     # bench torch
@@ -141,6 +142,7 @@ def run_triton(input_tensor: torch.Tensor):
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
+        "power_of_2_scales",
         "torch_time_us",
         "triton_time_us",
         "torch_mem_bw_gbps",
@@ -153,6 +155,7 @@ def print_results(experiments: List[Experiment]):
         rows.append(
             [
                 input_shape,
+                experiment.config.power_of_2_scales,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
                 round(experiment.result.torch_mem_bw_gbps, 3),
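
The benchmark now sweeps round_scales_to_power_of_2 instead of hard-coding it to True. As a rough illustration of what that flag controls (a hypothetical helper, not torchao's exact implementation), rounding a scale down to the nearest power of two means multiplying or dividing by the scale only changes the exponent, so the rescaling step itself introduces no extra mantissa rounding error:

import torch

# Hypothetical helper (not torchao's implementation): round each positive
# scale down to the nearest power of two.
def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    return torch.exp2(torch.floor(torch.log2(scale)))

s = torch.tensor([0.75, 1.3, 3.9])
print(round_scale_down_to_power_of_2(s))  # tensor([0.5000, 1.0000, 2.0000])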

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def get_configs() -> List[ExperimentConfig]:
     # Llama4 shapes
     A_shapes = [(16640, 5120)]
     B_shapes = [(16, 8192, 5120)]
-    recipes = [MoEScalingType.MXFP8, MoEScalingType.FP8_ROWWISE]
+    recipes = [MoEScalingType.FP8_ROWWISE]  # MoEScalingType.MXFP8,
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 10 additions & 17 deletions
@@ -26,10 +26,10 @@
     torch.float64: tl.float64,
 }

-block_sizes_n = [32, 128, 256]  # large dim (output_features)
-block_sizes_k = [32, 128, 256]  # small dim (input_features)
-num_warps = [2, 4]
-num_stages = [2, 3, 4, 5, 6]
+block_sizes_n = [128]  # large dim (output_features)
+block_sizes_k = [128]  # small dim (input_features)
+num_warps = [4]
+num_stages = [4]
 kernel_configs_2D = [
     triton.Config(
         {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
@@ -172,9 +172,7 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
         + (n_offs[None, :] * stride_input_dim2)
     )
     input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
-    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
-        input_dtype
-    )
+    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0)

     # In a normal torch implementation, we should transpose the tensor then compute the amax
     # along the dim1 (N), to compute colwise scales for a RHS operand of a scaled grouped gemm:
@@ -243,25 +241,20 @@ def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
         + (n_offs[None, :] * stride_input_dim2)
     )
     input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
-    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
-        input_dtype
-    )
+    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0)
     input_data = input_data.trans(1, 0)  # (K, N) -> (N, K)

     # load global scales for this block of the given expert - shape (1, K)
     scales_offs = (
         expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
     )
     scales_mask = k_offs[None, :] < K
-    scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0).to(
-        tl.float32
-    )
+    scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0)

     # transpose data and apply scales - shape (N,K) * (1,K) = (N,K)
-    scaled_data = input_data * scales
-    output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
-        output_dtype
-    )
+    output_data = tl.clamp(
+        input_data * scales, min=fp8_dtype_min, max=fp8_dtype_max
+    ).to(output_dtype)

     # store transpose and store output data - shape (N, K)
     output_offs = (
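
For reference, the scale-clamp-cast sequence that the cast kernel performs per tile has this shape in plain PyTorch (a simplified analogue, not the Triton kernel; the helper name is made up):

import torch

# Simplified analogue of the kernel's per-tile step: apply precomputed rowwise
# scales, clamp to the representable float8 range, then cast (a "saturated"
# cast, so out-of-range values map to +/- max instead of inf/NaN).
def saturated_cast(x: torch.Tensor, scales: torch.Tensor,
                   fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    fp8_max = torch.finfo(fp8_dtype).max
    scaled = x.to(torch.float32) * scales
    return scaled.clamp(min=-fp8_max, max=fp8_max).to(fp8_dtype)

x = torch.randn(128, 256, dtype=torch.bfloat16)
scales = torch.finfo(torch.float8_e4m3fn).max / x.abs().amax(dim=1, keepdim=True).float()
print(saturated_cast(x, scales).dtype)  # torch.float8_e4m3fn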

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@
     torch.float64: tl.float64,
 }

-block_sizes = [1, 16, 32, 64]
-block_sizes_iter = [64, 128, 256]
+block_sizes = [32]  # [16, 32, 64]
+block_sizes_iter = [128]  # [64, 128, 256]
 num_warps = [4]
 num_stages = [3]
 kernel_configs_2D = [
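
Both kernel files prune their autotune search space the same way: the lists are expanded with itertools.product into triton.Config objects, so each removed list entry multiplicatively shrinks the number of configurations the autotuner has to benchmark. A condensed sketch of that pattern (the block-size key names here are illustrative, not necessarily the exact ones used in jagged_float8_scales.py):

import itertools
import triton

block_sizes = [32]
block_sizes_iter = [128]
num_warps = [4]
num_stages = [3]

# One triton.Config per combination; with the pruned lists above this is a
# single config instead of the previous 4 x 3 x 1 x 1 = 12-entry search space.
kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
        num_warps=warps,
        num_stages=stages,
    )
    for block_size, block_size_iter, warps, stages in itertools.product(
        block_sizes, block_sizes_iter, num_warps, num_stages
    )
]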

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 15 additions & 11 deletions
@@ -14,7 +14,6 @@
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_per_group_colwise_scales,
-    triton_fp8_per_group_rowwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
@@ -174,8 +173,8 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
         #
-        # grad_output shape: (M, N)
-        # grad_output_scale shape: (M, 1)
+        # grad_output shape: (Mg, N)
+        # grad_output_scale shape: (Mg, 1)
         grad_output_scales = tensor_to_scale(
             grad_output,
             torch.float8_e4m3fn,
@@ -226,17 +225,22 @@ def backward(ctx, grad_output: torch.Tensor):

         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
-        grad_output_t_fp8_row_major, grad_output_t_scales = (
-            triton_fp8_per_group_rowwise_scales(
-                grad_output.transpose(-2, -1),
-                offs,
-                torch.float8_e4m3fn,
-                round_scales_to_power_of_2=True,
-            )
+        # Use transpose method to avoid uncoalesced memory accesses.
+        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+            grad_output.t()
+            .contiguous()
+            .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
+            offs,
+            torch.float8_e4m3fn,
+            round_scales_to_power_of_2=True,
         )
+        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_scales = grad_out_scales.t()

         A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
-            A,
+            A.t()
+            .contiguous()
+            .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
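
With this change the backward no longer needs triton_fp8_per_group_rowwise_scales at all: quantizing column-wise and then transposing the resulting float8 data and scales yields the same values the rowwise path produced for the transposed operand. A plain (non per-group) PyTorch sketch of that equivalence, with made-up helper names:

import torch

# Plain-torch sketch (no per-group/jagged handling): column-wise quantization
# of X, transposed, equals row-wise quantization of X.t(), so one colwise
# kernel can produce both operand layouts.
def quantize_axiswise(x: torch.Tensor, dim: int, fp8_dtype=torch.float8_e4m3fn):
    fp8_max = torch.finfo(fp8_dtype).max
    scales = fp8_max / x.abs().amax(dim=dim, keepdim=True)
    x_fp8 = (x.to(torch.float32) * scales).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_fp8, scales

grad_output = torch.randn(256, 512)
colwise_fp8, colwise_scales = quantize_axiswise(grad_output, dim=0)      # scales (1, 512)
rowwise_fp8, rowwise_scales = quantize_axiswise(grad_output.t(), dim=1)  # scales (512, 1)
assert torch.equal(colwise_fp8.t().float(), rowwise_fp8.float())
assert torch.equal(colwise_scales.t(), rowwise_scales)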

torchao/prototype/moe_training/utils.py

Lines changed: 6 additions & 6 deletions
@@ -163,21 +163,21 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
     Scales shape: (E, 1, K)
     """
     assert _is_column_major(input_hp_t), "input tensor must be column-major"
-    input_hp = input_hp_t.transpose(-2, -1)  # (E, N, K)
     scales = tensor_to_scale(
-        input_hp,
+        input_hp_t,
         target_dtype,
         scaling_granularity=ScalingGranularity.AXISWISE,
-        axiswise_dim=-2,
+        axiswise_dim=-1,
         round_scales_to_power_of_2=round_scales_to_power_of_2,
-    )  # (E, 1, K)
+    )  # (E, K, 1)

     # Apply scales to tensor and convert to float8.
-    tensor_scaled = input_hp.to(torch.float32) * scales
+    tensor_scaled = input_hp_t.to(torch.float32) * scales
     float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype)

     # To column major
-    float8_tensor = float8_tensor.transpose(-2, -1).contiguous().transpose(-2, -1)
+    float8_tensor = float8_tensor.contiguous().transpose(-2, -1)
+    scales = scales.transpose(-2, -1)
     return float8_tensor, scales
