20
20
(8 , 513 , 64 ), # Non-divisible (native only)
21
21
])
22
22
@pytest .mark .parametrize ("seed" , [42 ])
23
+ @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
23
24
@torch .inference_mode ()
24
25
def test_quantfp8_group_functionality (batch_size : int , hidden_dim : int ,
25
- group_size : int , seed : int ) -> None :
26
+ group_size : int , seed : int ,
27
+ use_ue8m0 : bool ) -> None :
26
28
"""Test QuantFP8 group quantization with various configurations.
27
29
28
30
Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
38
40
group_shape = GroupShape (1 , group_size )
39
41
quant_op = QuantFP8 (static = False ,
40
42
group_shape = group_shape ,
41
- column_major_scales = False )
43
+ column_major_scales = False ,
44
+ use_ue8m0 = use_ue8m0 )
42
45
43
46
# 1. Test native implementation (always available)
44
47
x_quant_native , scales_native = quant_op .forward_native (x .clone ())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
48
51
# 2. Test column-major scales configuration
49
52
quant_op_col = QuantFP8 (static = False ,
50
53
group_shape = group_shape ,
51
- column_major_scales = True )
54
+ column_major_scales = True ,
55
+ use_ue8m0 = use_ue8m0 )
52
56
_ , scales_col = quant_op_col .forward_native (x .clone ())
53
- assert scales_col .shape == (expected_num_groups , batch_size )
57
+ assert scales_col .shape == (batch_size , expected_num_groups )
58
+ assert scales_col .stride (0 ) == 1
59
+ assert scales_col .stride (1 ) == batch_size
60
+
61
+ # Test column-major scales consistency
62
+ assert torch .allclose (scales_col , scales_native , rtol = 1e-9 , atol = 1e-8 )
54
63
55
64
# 3. Test CUDA implementation (only for divisible dimensions)
56
65
if is_divisible :
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
68
77
69
78
70
79
@pytest .mark .parametrize ("seed" , [42 ])
80
+ @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
71
81
@torch .inference_mode ()
72
- def test_quantfp8_group_multidimensional (seed : int ) -> None :
82
+ def test_quantfp8_group_multidimensional (seed : int , use_ue8m0 : bool ) -> None :
73
83
current_platform .seed_everything (seed )
74
84
75
85
group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
82
92
group_shape = GroupShape (1 , group_size )
83
93
quant_op = QuantFP8 (static = False ,
84
94
group_shape = group_shape ,
85
- column_major_scales = False )
95
+ column_major_scales = False ,
96
+ use_ue8m0 = use_ue8m0 )
86
97
87
98
x_quant , scales = quant_op .forward_native (x_3d .clone ())
88
99
assert x_quant .shape == x_3d .shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
91
102
# Test column_major_scales with multi-dim
92
103
quant_op_col = QuantFP8 (static = False ,
93
104
group_shape = group_shape ,
94
- column_major_scales = True )
105
+ column_major_scales = True ,
106
+ use_ue8m0 = use_ue8m0 )
95
107
_ , scales_col = quant_op_col .forward_native (x_3d .clone ())
96
108
assert scales_col .shape == (batch1 , hidden_dim // group_size , batch2 )
97
109
0 commit comments