Skip to content

Commit f5f64e0

Browse files
[moe fp8 training] test and bench new faster method for per group rowwise scaling
stack-info: PR: #2863, branch: danielvegamyhre/stack/57
1 parent 253d65a commit f5f64e0

File tree

3 files changed

+353
-68
lines changed

3 files changed

+353
-68
lines changed

benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py renamed to benchmarks/prototype/moe_training/benchmark_per_group_colwise_scaling_kernels.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616

1717
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
1818
triton_fp8_per_group_colwise_scales,
19-
triton_fp8_per_group_rowwise_scales,
2019
)
2120
from torchao.prototype.moe_training.utils import (
2221
generate_jagged_offs,
2322
torch_to_float8_per_group_colwise,
24-
torch_to_float8_per_group_rowwise,
2523
)
2624

2725
device = torch.device("cuda")
@@ -39,7 +37,7 @@ class ExperimentConfig:
3937

4038
@dataclass(frozen=True)
4139
class ExperimentResult:
42-
torch_time_us: float
40+
torch_loop_time_us: float
4341
triton_time_us: float
4442
torch_mem_bw_gbps: float
4543
triton_mem_bw_gbps: float
@@ -53,7 +51,7 @@ class Experiment:
5351

5452
def get_configs() -> List[ExperimentConfig]:
5553
input_shapes = [(16640, 5120)] # (Mg, K)
56-
n_groups_list = [1, 16, 128]
54+
n_groups_list = [1, 16, 64]
5755
high_precision_dtypes = [torch.bfloat16]
5856
configs = []
5957
for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -70,85 +68,88 @@ def get_configs() -> List[ExperimentConfig]:
7068

7169

7270
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
73-
# define test inputs
74-
input_tensor = torch.randn(
75-
*config.input_shape,
76-
dtype=config.high_precision_dtype,
77-
device=device,
71+
# Define test inputs
72+
Mg, K = config.input_shape
73+
74+
# Column major input tensor.
75+
# Right operand in grad_weight = grad_output_t @ input
76+
input_tensor = (
77+
torch.randn(
78+
Mg,
79+
K,
80+
dtype=config.high_precision_dtype,
81+
device=device,
82+
)
83+
.transpose(-2, -1)
84+
.contiguous()
85+
.transpose(-2, -1)
7886
)
79-
input_row_major = input_tensor.clone().detach()
80-
input_col_major = input_tensor.clone().detach().t()
8187

8288
# - configure input to be row-major with groups divided along the column dimension,
8389
# representing the left operand of grad_weight = grad_output_t @ input
8490
# that occurs in the backward pass of the differentiable scaled grouped mm.
8591
# - the transposed tensor in col-major format with groups along the row dimension,
8692
# which represents the right operand.
8793
n_groups = config.n_groups
88-
Mg = input_row_major.shape[0]
8994
offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
9095

9196
def warmup(func, *args, **kwargs):
9297
for _ in range(10):
9398
func(*args, **kwargs)
9499

95-
def run_torch(
96-
input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
97-
):
98-
_ = torch_to_float8_per_group_rowwise(
99-
input_row_major,
100-
offs,
101-
target_dtype=torch.float8_e4m3fn,
102-
round_scales_to_power_of_2=True,
103-
)
104-
_ = torch_to_float8_per_group_colwise(
105-
input_col_major,
106-
offs,
107-
target_dtype=torch.float8_e4m3fn,
108-
round_scales_to_power_of_2=True,
109-
)
110-
111-
def run_triton(
112-
input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
113-
):
114-
_ = triton_fp8_per_group_rowwise_scales(
115-
input_row_major,
116-
offs,
117-
output_dtype=torch.float8_e4m3fn,
118-
round_scales_to_power_of_2=True,
119-
)
120-
_ = triton_fp8_per_group_colwise_scales(
121-
input_col_major,
122-
offs,
123-
output_dtype=torch.float8_e4m3fn,
124-
round_scales_to_power_of_2=True,
125-
)
126-
127-
# bench torch
128-
compiled_run_torch = torch.compile(run_torch)
129-
warmup(compiled_run_torch, input_row_major, input_col_major, offs)
130-
torch_time_us = benchmark_cuda_function_in_microseconds(
131-
compiled_run_torch, input_row_major, input_col_major, offs
100+
# Bench torch per group colwise
101+
torch_to_float8_per_group_colwise_c = torch.compile(
102+
torch_to_float8_per_group_colwise
103+
)
104+
warmup(
105+
torch_to_float8_per_group_colwise_c,
106+
input_tensor,
107+
offs,
108+
target_dtype=torch.float8_e4m3fn,
109+
)
110+
torch_loop_time_us = benchmark_cuda_function_in_microseconds(
111+
torch_to_float8_per_group_colwise_c,
112+
input_tensor,
113+
offs,
114+
target_dtype=torch.float8_e4m3fn,
132115
)
133116

134-
# bench triton
135-
warmup(run_triton, input_row_major, input_col_major, offs)
117+
# Bench triton per group colwise
118+
warmup(
119+
triton_fp8_per_group_colwise_scales,
120+
input_tensor,
121+
offs,
122+
output_dtype=torch.float8_e4m3fn,
123+
round_scales_to_power_of_2=True,
124+
)
136125
triton_time_us = benchmark_cuda_function_in_microseconds(
137-
run_triton, input_row_major, input_col_major, offs
126+
triton_fp8_per_group_colwise_scales,
127+
input_tensor,
128+
offs,
129+
output_dtype=torch.float8_e4m3fn,
130+
round_scales_to_power_of_2=True,
138131
)
139132

140-
# mem bw calculations - excluding scales to simplify calculation
141-
# but still get an accurate estimate.
133+
# Mem bw calculations
142134
bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
143135
num_elements = input_tensor.numel()
144-
read_bytes = num_elements * bytes_per_input_el
145-
write_bytes = num_elements # 1 byte per element in float8_e4m3fn
136+
read_bytes = (
137+
2 * num_elements * bytes_per_input_el # read input tensor twice
138+
+ 4 * (n_groups * K) # read scales tensor once, 4 bytes per fp32 scale
139+
)
140+
write_bytes = (
141+
# 1 byte per output elem in fp8
142+
num_elements
143+
+
144+
# write scales tensor, 4 bytes per fp32 scale (we actually do this write once per block along the reduction dim using atomics, but this is an approximation)
145+
4 * (n_groups * K)
146+
)
146147
read_write_bytes = read_bytes + write_bytes
147-
torch_mem_bw_gbps = (read_write_bytes) / (torch_time_us / 1e6) / 1e9
148+
torch_mem_bw_gbps = (read_write_bytes) / (torch_loop_time_us / 1e6) / 1e9
148149
triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9
149150

150151
return ExperimentResult(
151-
torch_time_us=torch_time_us,
152+
torch_loop_time_us=torch_loop_time_us,
152153
triton_time_us=triton_time_us,
153154
torch_mem_bw_gbps=torch_mem_bw_gbps,
154155
triton_mem_bw_gbps=triton_mem_bw_gbps,
@@ -157,10 +158,10 @@ def run_triton(
157158

158159
def print_results(experiments: List[Experiment]):
159160
headers = [
160-
"input_shape",
161+
"Mg,K",
161162
"n_groups",
162163
"high_precision_dtype",
163-
"torch_time_us",
164+
"torch_loop_time_us",
164165
"triton_time_us",
165166
"torch_mem_bw_gbps",
166167
"triton_mem_bw_gbps",
@@ -176,18 +177,18 @@ def print_results(experiments: List[Experiment]):
176177
input_shape,
177178
experiment.config.n_groups,
178179
experiment.config.high_precision_dtype,
179-
experiment.result.torch_time_us,
180+
experiment.result.torch_loop_time_us,
180181
experiment.result.triton_time_us,
181182
round(experiment.result.torch_mem_bw_gbps, 3),
182183
round(experiment.result.triton_mem_bw_gbps, 3),
183-
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
184+
f"{experiment.result.torch_loop_time_us / experiment.result.triton_time_us:.2f}x",
184185
]
185186
)
186187
print(tabulate(rows, headers=headers))
187188

188189

189-
def benchmark_cuda_function_in_microseconds(f, *args):
190-
return do_bench(lambda: f(*args), return_mode="median") * 1e3
190+
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
191+
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
191192

192193

193194
def main():

0 commit comments

Comments
 (0)