16
16
17
17
from torchao .prototype .moe_training .kernels .jagged_float8_scales import (
18
18
triton_fp8_per_group_colwise_scales ,
19
- triton_fp8_per_group_rowwise_scales ,
20
19
)
21
20
from torchao .prototype .moe_training .utils import (
22
21
generate_jagged_offs ,
23
22
torch_to_float8_per_group_colwise ,
24
- torch_to_float8_per_group_rowwise ,
25
23
)
26
24
27
25
device = torch .device ("cuda" )
@@ -39,7 +37,7 @@ class ExperimentConfig:
39
37
40
38
@dataclass (frozen = True )
41
39
class ExperimentResult :
42
- torch_time_us : float
40
+ torch_loop_time_us : float
43
41
triton_time_us : float
44
42
torch_mem_bw_gbps : float
45
43
triton_mem_bw_gbps : float
@@ -53,7 +51,7 @@ class Experiment:
53
51
54
52
def get_configs () -> List [ExperimentConfig ]:
55
53
input_shapes = [(16640 , 5120 )] # (Mg, K)
56
- n_groups_list = [1 , 16 , 128 ]
54
+ n_groups_list = [1 , 16 , 64 ]
57
55
high_precision_dtypes = [torch .bfloat16 ]
58
56
configs = []
59
57
for input_shape , n_groups , high_precision_dtype in itertools .product (
@@ -70,85 +68,88 @@ def get_configs() -> List[ExperimentConfig]:
70
68
71
69
72
70
def run_experiment (config : ExperimentConfig ) -> ExperimentResult :
73
- # define test inputs
74
- input_tensor = torch .randn (
75
- * config .input_shape ,
76
- dtype = config .high_precision_dtype ,
77
- device = device ,
71
+ # Define test inputs
72
+ Mg , K = config .input_shape
73
+
74
+ # Column major input tensor.
75
+ # Right operand in grad_weight = grad_output_t @ input
76
+ input_tensor = (
77
+ torch .randn (
78
+ Mg ,
79
+ K ,
80
+ dtype = config .high_precision_dtype ,
81
+ device = device ,
82
+ )
83
+ .transpose (- 2 , - 1 )
84
+ .contiguous ()
85
+ .transpose (- 2 , - 1 )
78
86
)
79
- input_row_major = input_tensor .clone ().detach ()
80
- input_col_major = input_tensor .clone ().detach ().t ()
81
87
82
88
# - configure input to be row-major with groups divided along the column dimension,
83
89
# representing the left operand of grad_weight = grad_output_t @ input
84
90
# that occurs in the backward pass of the differentiable scaled grouped mm.
85
91
# - the transposed tensor in col-major format with groups along the row dimension,
86
92
# which represents the right operand.
87
93
n_groups = config .n_groups
88
- Mg = input_row_major .shape [0 ]
89
94
offs = generate_jagged_offs (n_groups , Mg , multiple_of = 16 )
90
95
91
96
def warmup (func , * args , ** kwargs ):
92
97
for _ in range (10 ):
93
98
func (* args , ** kwargs )
94
99
95
- def run_torch (
96
- input_row_major : torch .Tensor , input_col_major : torch .Tensor , offs : torch .Tensor
97
- ):
98
- _ = torch_to_float8_per_group_rowwise (
99
- input_row_major ,
100
- offs ,
101
- target_dtype = torch .float8_e4m3fn ,
102
- round_scales_to_power_of_2 = True ,
103
- )
104
- _ = torch_to_float8_per_group_colwise (
105
- input_col_major ,
106
- offs ,
107
- target_dtype = torch .float8_e4m3fn ,
108
- round_scales_to_power_of_2 = True ,
109
- )
110
-
111
- def run_triton (
112
- input_row_major : torch .Tensor , input_col_major : torch .Tensor , offs : torch .Tensor
113
- ):
114
- _ = triton_fp8_per_group_rowwise_scales (
115
- input_row_major ,
116
- offs ,
117
- output_dtype = torch .float8_e4m3fn ,
118
- round_scales_to_power_of_2 = True ,
119
- )
120
- _ = triton_fp8_per_group_colwise_scales (
121
- input_col_major ,
122
- offs ,
123
- output_dtype = torch .float8_e4m3fn ,
124
- round_scales_to_power_of_2 = True ,
125
- )
126
-
127
- # bench torch
128
- compiled_run_torch = torch .compile (run_torch )
129
- warmup (compiled_run_torch , input_row_major , input_col_major , offs )
130
- torch_time_us = benchmark_cuda_function_in_microseconds (
131
- compiled_run_torch , input_row_major , input_col_major , offs
100
+ # Bench torch per group colwise
101
+ torch_to_float8_per_group_colwise_c = torch .compile (
102
+ torch_to_float8_per_group_colwise
103
+ )
104
+ warmup (
105
+ torch_to_float8_per_group_colwise_c ,
106
+ input_tensor ,
107
+ offs ,
108
+ target_dtype = torch .float8_e4m3fn ,
109
+ )
110
+ torch_loop_time_us = benchmark_cuda_function_in_microseconds (
111
+ torch_to_float8_per_group_colwise_c ,
112
+ input_tensor ,
113
+ offs ,
114
+ target_dtype = torch .float8_e4m3fn ,
132
115
)
133
116
134
- # bench triton
135
- warmup (run_triton , input_row_major , input_col_major , offs )
117
+ # Bench triton per group colwise
118
+ warmup (
119
+ triton_fp8_per_group_colwise_scales ,
120
+ input_tensor ,
121
+ offs ,
122
+ output_dtype = torch .float8_e4m3fn ,
123
+ round_scales_to_power_of_2 = True ,
124
+ )
136
125
triton_time_us = benchmark_cuda_function_in_microseconds (
137
- run_triton , input_row_major , input_col_major , offs
126
+ triton_fp8_per_group_colwise_scales ,
127
+ input_tensor ,
128
+ offs ,
129
+ output_dtype = torch .float8_e4m3fn ,
130
+ round_scales_to_power_of_2 = True ,
138
131
)
139
132
140
- # mem bw calculations - excluding scales to simplify calculation
141
- # but still get an accurate estimate.
133
+ # Mem bw calculations
142
134
bytes_per_input_el = torch .finfo (config .high_precision_dtype ).bits / 8
143
135
num_elements = input_tensor .numel ()
144
- read_bytes = num_elements * bytes_per_input_el
145
- write_bytes = num_elements # 1 byte per element in float8_e4m3fn
136
+ read_bytes = (
137
+ 2 * num_elements * bytes_per_input_el # read input tensor twice
138
+ + 4 * (n_groups * K ) # read scales tensor once, 4 bytes per fp32 scale
139
+ )
140
+ write_bytes = (
141
+ # 1 byte per output elem in fp8
142
+ num_elements
143
+ +
144
+ # write scales tensor, 4 bytes per fp32 scale (we actually do this write once per blong along the reduction dim using atomics, but this is an approximation)
145
+ 4 * (n_groups * K )
146
+ )
146
147
read_write_bytes = read_bytes + write_bytes
147
- torch_mem_bw_gbps = (read_write_bytes ) / (torch_time_us / 1e6 ) / 1e9
148
+ torch_mem_bw_gbps = (read_write_bytes ) / (torch_loop_time_us / 1e6 ) / 1e9
148
149
triton_mem_bw_gbps = (read_write_bytes ) / (triton_time_us / 1e6 ) / 1e9
149
150
150
151
return ExperimentResult (
151
- torch_time_us = torch_time_us ,
152
+ torch_loop_time_us = torch_loop_time_us ,
152
153
triton_time_us = triton_time_us ,
153
154
torch_mem_bw_gbps = torch_mem_bw_gbps ,
154
155
triton_mem_bw_gbps = triton_mem_bw_gbps ,
@@ -157,10 +158,10 @@ def run_triton(
157
158
158
159
def print_results (experiments : List [Experiment ]):
159
160
headers = [
160
- "input_shape " ,
161
+ "Mg,K " ,
161
162
"n_groups" ,
162
163
"high_precision_dtype" ,
163
- "torch_time_us " ,
164
+ "torch_loop_time_us " ,
164
165
"triton_time_us" ,
165
166
"torch_mem_bw_gbps" ,
166
167
"triton_mem_bw_gbps" ,
@@ -176,18 +177,18 @@ def print_results(experiments: List[Experiment]):
176
177
input_shape ,
177
178
experiment .config .n_groups ,
178
179
experiment .config .high_precision_dtype ,
179
- experiment .result .torch_time_us ,
180
+ experiment .result .torch_loop_time_us ,
180
181
experiment .result .triton_time_us ,
181
182
round (experiment .result .torch_mem_bw_gbps , 3 ),
182
183
round (experiment .result .triton_mem_bw_gbps , 3 ),
183
- f"{ experiment .result .torch_time_us / experiment .result .triton_time_us :.2f} x" ,
184
+ f"{ experiment .result .torch_loop_time_us / experiment .result .triton_time_us :.2f} x" ,
184
185
]
185
186
)
186
187
print (tabulate (rows , headers = headers ))
187
188
188
189
189
- def benchmark_cuda_function_in_microseconds (f , * args ):
190
- return do_bench (lambda : f (* args ), return_mode = "median" ) * 1e3
190
+ def benchmark_cuda_function_in_microseconds (f , * args , ** kwargs ):
191
+ return do_bench (lambda : f (* args , ** kwargs ), return_mode = "median" ) * 1e3
191
192
192
193
193
194
def main ():
0 commit comments