 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import itertools
 from typing import Callable
+from unittest.mock import patch

+import pandas as pd
 import torch

-from vllm import _custom_ops as ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.triton_utils import triton
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+
+
+def with_triton_mode(fn):
+    """Temporarily force the Triton fallback path"""
+
+    def wrapped(*args, **kwargs):
+        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+            return fn(*args, **kwargs)
+
+    return wrapped
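+# Note: wrapping a bound method (e.g. with_triton_mode(quant_fp8.forward_cuda))
+# makes current_platform.is_cuda() report False inside the call, so forward_cuda
+# takes its non-CUDA (Triton) path; this is what the "triton" provider below uses.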


 # TODO(luka): use standalone_compile utility
@@ -21,78 +32,236 @@ def inner(*args):
     return inner


-torch._dynamo.config.recompile_limit = 8888
-compilation_config = CompilationConfig(custom_ops=["none"])
-with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
-    torch_per_token_quant_fp8 = torch.compile(
-        QuantFP8(False, GroupShape.PER_TOKEN),
-        fullgraph=True,
-        dynamic=False,  # recompile for different shapes
-    )
+def bench_compile(fn: Callable):
+    # recompile for different shapes
+    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

     # First dim is explicitly dynamic to simulate vLLM usage
-    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
+    return with_dyn_arg(fwd, 0, 0)
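+# Note: bench_compile(quant_fp8.forward_native) returns an Inductor-compiled
+# callable whose first dimension is marked dynamic; it is used by both
+# calculate_diff() and the "torch" provider in benchmark_quantization().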


-def cuda_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input)
+torch._dynamo.config.recompile_limit = 8888


-def calculate_diff(batch_size: int, seq_len: int):
-    """Calculate difference between Triton and CUDA implementations."""
+def calculate_diff(
+    batch_size: int,
+    hidden_size: int,
+    group_shape: GroupShape,
+    dtype: torch.dtype,
+):
+    """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
+    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

-    torch_out, torch_scale = torch_per_token_quant_fp8(x)
-    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
+    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
+    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
+    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
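+    # Compare the Inductor-compiled and eager native paths against the CUDA op.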

-    if torch.allclose(
-        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
+    out_allclose = lambda o1, o2: torch.allclose(
+        o1.to(torch.float32),
+        o2.to(torch.float32),
+        rtol=1e-3,
+        atol=1e-5,
+    )
+    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
+
+    if (
+        out_allclose(cuda_out, torch_out)
+        and scale_allclose(cuda_scale, torch_scale)
+        and out_allclose(cuda_out, torch_eager_out)
+        and scale_allclose(cuda_scale, torch_eager_scale)
+    ):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")


-batch_size_range = [1, 16, 32, 64, 128]
-seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-
-configs = list(itertools.product(batch_size_range, seq_len_range))
+configs = []


-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["torch", "cuda"],
-        line_names=["Torch", "CUDA"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="us",
-        plot_name="per-token-dynamic-quant-fp8-performance",
-        args={},
-    )
-)
-def benchmark_quantization(batch_size, seq_len, provider):
-    dtype = torch.float16
+def benchmark_quantization(
+    batch_size,
+    hidden_size,
+    provider,
+    group_shape: GroupShape,
+    col_major: bool,
+    dtype: torch.dtype,
+):
     device = torch.device("cuda")

-    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

     quantiles = [0.5, 0.2, 0.8]
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

     if provider == "torch":
-        fn = lambda: torch_per_token_quant_fp8(x.clone())
+        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
     elif provider == "cuda":
-        fn = lambda: cuda_per_token_quant_fp8(x.clone())
+        fn = lambda: quant_fp8.forward_cuda(x.clone())
+    elif provider == "triton":
+        if not group_shape.is_per_group():
+            # Triton only supported for per-group
+            return 0, 0, 0
+
+        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())
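+        # with_triton_mode patches current_platform.is_cuda to False, so this
+        # call exercises the Triton kernel rather than the fused CUDA op.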

     ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms


+# TODO(luka) extract to utils
+def compute_geomean_speedups(
+    df: pd.DataFrame,
+    baseline_col: str,
+    speedup_cols: list[str],
+    groupby_cols: list[str] | None = None,
+) -> pd.DataFrame:
+    """
+    Compute geometric mean speedups over a baseline column.
+
+    Args:
+        df: Input dataframe
+        baseline_col: Column to use as baseline
+        speedup_cols: Columns to compute speedups for
+        groupby_cols: Columns to group by. If None, compute over entire df.
+
+    Returns:
+        pd.DataFrame with geometric mean speedups
+    """
+    from scipy.stats import gmean
+
+    def geo_speedup(group: pd.DataFrame) -> pd.Series:
+        ratios = {
+            col: (group[baseline_col] / group[col]).values for col in speedup_cols
+        }
+        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
+
+    if groupby_cols is None:
+        result = geo_speedup(df).to_frame().T
+    else:
+        result = (
+            df.groupby(groupby_cols)
+            .apply(geo_speedup, include_groups=False)
+            .reset_index()
+        )
+
+    return result
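+# Example (matching the call at the bottom of this file): given the benchmark
+# dataframe with "Torch (Compiled)", "CUDA" and "Triton" timing columns,
+#   compute_geomean_speedups(df, "Torch (Compiled)", ["CUDA", "Triton"],
+#                            ["col_major", "group_shape"])
+# returns one geometric-mean speedup row per (col_major, group_shape) group.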
+
+
 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=4096)
-    benchmark_quantization.run(print_data=True)
+    parser = FlexibleArgumentParser(
+        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
+    )
+    parser.add_argument("-c", "--check", action="store_true")
+    parser.add_argument(
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+    )
+    parser.add_argument(
+        "--hidden-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+    )
+    parser.add_argument(
+        "--group-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Group sizes for GroupShape(1,N) to benchmark. "
+        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
+    )
+    parser.add_argument(
+        "--no-column-major",
+        action="store_true",
+        help="Disable column-major scales testing",
+    )
+
+    args = parser.parse_args()
+    assert args
+
+    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
+
+    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+
+    if args.group_sizes is not None:
+        group_shapes = []
+        for size in args.group_sizes:
+            if size == 0:
+                group_shapes.append(GroupShape.PER_TENSOR)
+            elif size == -1:
+                group_shapes.append(GroupShape.PER_TOKEN)
+            else:
+                group_shapes.append(GroupShape(1, size))
+    else:
+        group_shapes = [
+            GroupShape.PER_TENSOR,
+            GroupShape.PER_TOKEN,
+            GroupShape(1, 64),
+            GroupShape(1, 128),
+        ]
+
+    column_major_scales = [False] if args.no_column_major else [True, False]
+
+    config_gen = itertools.product(
+        group_shapes,
+        column_major_scales,
+        batch_sizes,
+        hidden_sizes,
+    )
+
+    # filter out column-major scales for non-group, reverse order
+    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
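+    # Note: product() yields (group_shape, col_major, batch_size, hidden_size);
+    # reversing each tuple matches the x_names order of the Benchmark below, and
+    # column-major scales are only kept for per-group shapes.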
+
+    print(f"Running {len(configs)} configurations:")
+    print(f" Hidden sizes: {hidden_sizes}")
+    print(f" Batch sizes: {batch_sizes}")
+    print(f" Group shapes: {[str(g) for g in group_shapes]}")
+    print(f" Column major scales: {column_major_scales}")
+    print()
+
+    if args.check:
+        for group_shape in group_shapes:
+            group_size = group_shape[1]
+            print(f"{group_size=}")
+            calculate_diff(
+                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
+            )
+
+    benchmark = triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["torch", "cuda", "triton"],
+            line_names=["Torch (Compiled)", "CUDA", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
+            ylabel="us",
+            plot_name="QuantFP8 performance",
+            args={},
+        )
+    )(benchmark_quantization)
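+    # perf_report is applied as a plain call here (rather than as a decorator)
+    # so the Benchmark captures the configs built from the CLI arguments above.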
+
+    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
+
+    # Print geomean speedups
+    geo_table_grouped = compute_geomean_speedups(
+        df,
+        baseline_col="Torch (Compiled)",
+        speedup_cols=["CUDA", "Triton"],
+        groupby_cols=["col_major", "group_shape"],
+    )
+
+    print("Speedup over Torch (Compiled)")
+    print(geo_table_grouped.to_string(index=False))
+ print (geo_table_grouped .to_string (index = False ))
0 commit comments