
Commit 6340025

ElizaWszola, ProExpertProg, and mgoin authored
[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent 8c1c81a commit 6340025

File tree
14 files changed: +345 −205 lines

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.utils import FlexibleArgumentParser, cdiv
 
@@ -158,7 +158,7 @@ def bench_fp8(
         "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
             a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
         ),
-        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
             a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
         ),
         "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
 from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
@@ -59,7 +59,7 @@ def deepgemm_gemm():
 
     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_block_fp8_matmul(A_vllm,
+        return w8a8_triton_block_scaled_mm(A_vllm,
                                      B_vllm,
                                      A_scale_vllm,
                                      B_scale_vllm,
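For reference, the quantized inputs and block scales that feed these GEMMs can be produced with helpers this benchmark already imports; a hedged sketch (variable names follow the benchmark, while the shapes and the single-argument per_block_cast_to_fp8 call are assumptions):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
from vllm.utils.deep_gemm import per_block_cast_to_fp8

M, N, K, block = 64, 4096, 7168, (128, 128)
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Per-token-group quantization for activations, per-block casting for weights.
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block[1])
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)

out = w8a8_triton_block_scaled_mm(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm,
                                  block, torch.bfloat16)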

tests/kernels/quantization/test_block_fp8.py

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
 
     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                        out_dtype)
-    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
+    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
+                                      out_dtype)
 
     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
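The reference native_w8a8_block_matmul dequantizes with the same block structure before a plain matmul; a rough, self-contained sketch of that idea (an illustrative reimplementation, not the test's helper):

import torch

def blockwise_scaled_mm_ref(A_fp8, B_fp8, As, Bs, block_size, out_dtype):
    """Dequantize per-group/per-block, then matmul in float32 (illustration only)."""
    block_n, block_k = block_size
    A = A_fp8.to(torch.float32)
    B = B_fp8.to(torch.float32)
    # Activation scales: one per (token, K-group) -> broadcast along K.
    A = A * As.repeat_interleave(block_k, dim=1)[:, :A.shape[1]]
    # Weight scales: one per (N-block, K-block) -> broadcast along both dims.
    Bs_full = Bs.repeat_interleave(block_n, dim=0)[:B.shape[0]]
    Bs_full = Bs_full.repeat_interleave(block_k, dim=1)[:, :B.shape[1]]
    B = B * Bs_full
    return (A @ B.t()).to(out_dtype)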

tests/kernels/quantization/test_fp8_quant_group.py

Lines changed: 19 additions & 7 deletions
@@ -20,9 +20,11 @@
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int) -> None:
+                                      group_size: int, seed: int,
+                                      use_ue8m0: bool) -> None:
     """Test QuantFP8 group quantization with various configurations.
 
     Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (expected_num_groups, batch_size)
+    assert scales_col.shape == (batch_size, expected_num_groups)
+    assert scales_col.stride(0) == 1
+    assert scales_col.stride(1) == batch_size
+
+    # Test column-major scales consistency
+    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
 
     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
 
 
 @pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int) -> None:
+def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     current_platform.seed_everything(seed)
 
     group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False)
+                        column_major_scales=False,
+                        use_ue8m0=use_ue8m0)
 
     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True)
+                            column_major_scales=True,
+                            use_ue8m0=use_ue8m0)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
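For orientation, a minimal sketch of the QuantFP8 group-quantization API these tests exercise (the constructor arguments are the ones used above; the GroupShape import path and the tensor shape are assumptions):

import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")

quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, 128),  # one scale per 128 columns
                    column_major_scales=True,
                    use_ue8m0=False)

x_q, scales = quant_op.forward_native(x)
# Scales are logically (num_tokens, num_groups) but stored column-major,
# i.e. stride(0) == 1, which is what the stride assertions above check.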

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 0 additions & 30 deletions
@@ -17,8 +17,6 @@
 from vllm.model_executor.layers.layernorm import (RMSNorm,
                                                   dispatch_rocm_rmsnorm_func,
                                                   fused_add_rms_norm, rms_norm)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 
 RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
[email protected](
-    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
-    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
[email protected]("use_cutlass", [True, False])
[email protected]("use_rocm_aiter", ["0", "1"])
[email protected]("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
-def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
-                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
-                                  monkeypatch):
-
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
-                       use_rocm_aiter_gemm_w8a8_blockscale)
-
-    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
-        int(use_rocm_aiter_gemm_w8a8_blockscale)))
-    block_scale_func = dispatch_w8a8_blockscale_func(
-        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
-    if use_cutlass:
-        assert block_scale_func == cutlass_scaled_mm
-    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
-            use_rocm_aiter_gemm_w8a8_blockscale):
-        assert block_scale_func == (
-            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
-    else:
-        assert block_scale_func == w8a8_block_fp8_matmul
-
-
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)

tests/quantization/test_compressed_tensors.py

Lines changed: 35 additions & 0 deletions
@@ -18,6 +18,9 @@
     CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    W8A8BlockFp8LinearOp)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
         perplexity = llm.generate_prompt_perplexity([prompt])[0]
         print(perplexity)
         assert perplexity <= exp_perplexity
+
+
+def test_compressed_tensors_fp8_block_enabled(vllm_runner):
+    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
+    with vllm_runner(model_path) as llm:
+
+        fp8_dtype = current_platform.fp8_dtype()
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
+            assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
+                              W8A8BlockFp8LinearOp)
+
+            assert qkv_proj.weight.dtype is fp8_dtype
+            assert qkv_proj.weight_scale.dtype is torch.float32
+            assert len(qkv_proj.weight.shape) == 2
+            assert len(qkv_proj.weight_scale.shape) == 2
+
+            input_quant_op = \
+                qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
+            assert isinstance(input_quant_op, QuantFP8)
+            assert input_quant_op._forward_method == input_quant_op.forward_cuda
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
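Outside the test harness, the same FP8-block checkpoint can be loaded through the offline API; a brief sketch (model name from the test, sampling parameters are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="RedHatAI/Qwen3-0.6B-FP8-BLOCK")
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)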

vllm/config/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -687,6 +687,23 @@ def __post_init__(self):
                 # local attention.
                 self.scheduler_config.disable_hybrid_kv_cache_manager = True
 
+        def has_blocked_weights():
+            if self.quant_config is not None:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    return self.quant_config.weight_block_size is not None
+                elif hasattr(self.quant_config, "has_blocked_weights"):
+                    return self.quant_config.has_blocked_weights()
+            return False
+
+        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
+        # On H100 the CUDA kernel is faster than
+        # native implementation
+        # https://github.com/vllm-project/vllm/issues/25094
+        if has_blocked_weights():
+            custom_ops = self.compilation_config.custom_ops
+            if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
+                custom_ops.append("+quant_fp8")
+
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
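Because the new branch only appends "+quant_fp8" when the op is not already disabled, users can still force the native path; a hedged sketch of what opting out might look like through the offline API (passing a CompilationConfig this way is an assumption about the public interface):

from vllm import LLM
from vllm.config import CompilationConfig

# An explicit "-quant_fp8" is respected by the new auto-enable logic above.
llm = LLM(model="RedHatAI/Qwen3-0.6B-FP8-BLOCK",
          compilation_config=CompilationConfig(custom_ops=["-quant_fp8"]))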

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 8 additions & 0 deletions
@@ -644,6 +644,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         # If no matches, return None
         return None
 
+    def has_blocked_weights(self) -> bool:
+        for scheme in self.target_scheme_map.values():
+            weight_quant = scheme.get("weights")
+            if (weight_quant is not None
+                    and weight_quant.strategy == QuantizationStrategy.BLOCK):
+                return True
+        return False
+
     @staticmethod
     def supports_cutlass_24(
         weight_quant: Optional[QuantizationArgs],

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 26 additions & 11 deletions
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_block_linear, check_aiter_fp8_linear_support,
+    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,16 +41,30 @@ def __init__(self, weight_quant: QuantizationArgs,
         self.strategy = weight_quant.strategy
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
-        self.act_q_group_shape = GroupShape.PER_TENSOR \
-            if is_static_input_scheme else GroupShape.PER_TOKEN
-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=self.is_static_input_scheme,
-            act_quant_group_shape=self.act_q_group_shape)
 
         self.weight_block_size = self.weight_quant.block_structure
+        if self.weight_block_size is not None:
+            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
+        else:
+            self.act_q_group_shape = GroupShape.PER_TENSOR \
+                if is_static_input_scheme else GroupShape.PER_TOKEN
+
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
 
+        if self.weight_block_size is not None:
+            assert not self.is_static_input_scheme
+            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
+                weight_group_shape=GroupShape(*self.weight_block_size),
+                act_quant_group_shape=self.act_q_group_shape,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )
+        else:
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=self.is_static_input_scheme,
+                act_quant_group_shape=self.act_q_group_shape)
+
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -141,13 +155,14 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        if layer.weight_block_size is not None:
-            return apply_fp8_block_linear(
-                layer,
+        if self.weight_block_size is not None:
+            return self.w8a8_block_fp8_linear.apply(
                 input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                input_scale=layer.input_scale,
                 bias=bias,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported)
+            )
 
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
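Condensed, the new flow builds the op once at scheme construction and routes every forward call through it; a sketch of that pattern (parameter names are taken from this diff, while the GroupShape import path, the 128x128 block size, and the hard-coded capability flags are assumptions):

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Constructed once per layer scheme (assuming 128x128 weight blocks).
block_linear = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(128, 128),
    act_quant_group_shape=GroupShape(1, 128),
    cutlass_block_fp8_supported=True,   # real scheme queries platform capability
    use_aiter_and_is_supported=False,
)

# Called on every forward pass; `layer` holds the FP8 weight and its block scales.
def apply_weights(layer, x, bias=None):
    return block_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
    )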

vllm/model_executor/layers/quantization/deepgemm.py

Lines changed: 5 additions & 5 deletions
@@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs(
     return M, N, K, C
 
 
-def w8a8_block_fp8_matmul_deepgemm(
+def w8a8_deepgemm_block_scaled_mm(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm(
     return C
 
 
-def w8a8_block_fp8_matmul_deepgemm_fake(
+def w8a8_deepgemm_block_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
@@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake(
 
 
 direct_register_custom_op(
-    op_name="w8a8_block_fp8_matmul_deepgemm",
-    op_func=w8a8_block_fp8_matmul_deepgemm,
+    op_name="w8a8_deepgemm_block_scaled_mm",
+    op_func=w8a8_deepgemm_block_scaled_mm,
     mutates_args=[],
-    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
+    fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
     dispatch_key=current_platform.dispatch_key,
 )
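Since the op is exposed through direct_register_custom_op, callers that resolve it via the torch op registry need the new name too; a small hedged sketch (the torch.ops.vllm namespace mirrors how other vLLM custom ops are referenced, e.g. rocm_aiter_gemm_w8a8_blockscale above):

import torch

def resolve_deepgemm_block_mm():
    # Look up the renamed op; returns None if it is not registered in this build.
    return getattr(torch.ops.vllm, "w8a8_deepgemm_block_scaled_mm", None)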
