From 36c0c25720286cf82fad4cffcc627dc7b98c7efc Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 2 Jun 2025 12:26:22 +0000 Subject: [PATCH 01/14] fix get_plain() with FMA mode --- torchao/dtypes/uintx/gemlite_layout.py | 41 ++++++++++++++++++++------ 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 1c840f7ec4..56b55c72a4 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -201,10 +201,13 @@ def from_plain( int_data, scale, zero_point, bit_width, group_size, bias=None ) + meta_args = gemlite_linear.get_meta_args() gemlite_kwargs = { "in_features": in_features, "out_features": out_features, - "meta_args": gemlite_linear.get_meta_args(), + "data_contiguous": meta_args[-1], + "W_group_mode": meta_args[10], + "meta_args": meta_args, } packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() @@ -235,18 +238,38 @@ def _apply_fn_to_data(self, fn): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = self.packed_weight.device int_data = ( - gemlite.bitpack.unpack_over_rows( - self.packed_weight.cuda(), - W_nbits=self._layout.bit_width, - num_output_rows=self.gemlite_kwargs["out_features"], - dtype=torch.uint8, + ( + gemlite.bitpack.unpack_over_rows( + self.packed_weight.cuda(), + W_nbits=self._layout.bit_width, + num_output_rows=self.gemlite_kwargs["in_features"], + dtype=torch.uint8, + ) ) + .to(device) .t() - .contiguous() - ).to(device) + ) + + if self.gemlite_kwargs["data_contiguous"]: + int_data = int_data.contiguous() + + # Handle FMA mode: W_q * s + z -> (W_q - z) * s + if self.gemlite_kwargs["W_group_mode"] == 4: + scale_min_val = 1e-8 + scale = self.scale.float() + scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( + scale_min_val + ) + scale[ + torch.logical_and(scale < 0, scale.abs() <= scale_min_val) + ] = -scale_min_val + zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100) + zero_point = zero_point.to(self.scale.dtype) + else: + zero_point = self.zero_point scale = self.scale.t().contiguous() - zero_point = self.zero_point.t().contiguous() + zero_point = zero_point.t().contiguous() return int_data, scale, zero_point From 5cc70e14e308b0bfe8885169df05f83bfd2a42fc Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 2 Jun 2025 14:36:53 +0000 Subject: [PATCH 02/14] update --- torchao/dtypes/uintx/gemlite_layout.py | 36 +++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 56b55c72a4..07278c61e1 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -206,6 +206,7 @@ def from_plain( "in_features": in_features, "out_features": out_features, "data_contiguous": meta_args[-1], + "elements_per_sample": meta_args[4], "W_group_mode": meta_args[10], "meta_args": meta_args, } @@ -250,13 +251,14 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: .t() ) + # Preserve col-row major layout if self.gemlite_kwargs["data_contiguous"]: int_data = int_data.contiguous() # Handle FMA mode: W_q * s + z -> (W_q - z) * s if self.gemlite_kwargs["W_group_mode"] == 4: scale_min_val = 1e-8 - scale = self.scale.float() + scale = self.scale.clone().float() scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( scale_min_val ) @@ -297,14 +299,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs): assert step == 1, "Only step == 1 is supported in slicing right now" if dim in [0, 1]: - int_data, scale, zero_point = self.get_plain() - data_len = int_data.shape[dim] + # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T + dim = 1 - dim + packed_weight = self.packed_weight + scale = self.scale + zero_point = self.zero_point + + orig_shape = [ + self.gemlite_kwargs["in_features"], + self.gemlite_kwargs["out_features"], + ] + elements_per_sample = self.gemlite_kwargs["elements_per_sample"] + data_len = orig_shape[dim] scale_len = scale.shape[dim] ratio = data_len / scale_len start_scale = int(start / ratio) end_scale = int(end / ratio) - int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # For packing only the K dimension. This should be flipped for N-dim packing. + div = elements_per_sample if dim == 0 else 1 + packed_weight = aten.slice.Tensor( + packed_weight, dim, start // div, end // div, step + ) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) if zero_point is not None and zero_point.numel() > 0: zero_point = aten.slice.Tensor( @@ -312,15 +329,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) else: zero_point = None - # this is to handle padding - int_data, scale, zero_point = self._layout.post_process( - int_data, scale, zero_point, self.block_size - ) - - sliced = self.from_plain( - int_data, scale, zero_point, self._layout - ) # Will be transposed again + sliced = GemliteAQTTensorImpl( + packed_weight, scale, zero_point, self.gemlite_kwargs, self._layout + ) return return_and_correct_aliasing(func, args, kwargs, sliced) else: From 9ac689e7a70ce55f58f6cedc12bc4165f5305225 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Jun 2025 10:40:34 +0000 Subject: [PATCH 03/14] fix in_features/out_feature meta-data mismatch --- torchao/dtypes/uintx/gemlite_layout.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 07278c61e1..6b38e57fb1 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -25,7 +25,6 @@ except: gemlite = None - aten = torch.ops.aten @@ -35,8 +34,7 @@ def _same_metadata( ) -> bool: kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) for k, v in self.gemlite_kwargs.items(): - if k != "scale_activations": - kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) + kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) return ( isinstance(self, GemliteAQTTensorImpl) @@ -305,11 +303,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): scale = self.scale zero_point = self.zero_point + gemlite_kwargs = self.gemlite_kwargs.copy() orig_shape = [ - self.gemlite_kwargs["in_features"], - self.gemlite_kwargs["out_features"], + gemlite_kwargs["in_features"], + gemlite_kwargs["out_features"], ] - elements_per_sample = self.gemlite_kwargs["elements_per_sample"] + elements_per_sample = gemlite_kwargs["elements_per_sample"] data_len = orig_shape[dim] scale_len = scale.shape[dim] ratio = data_len / scale_len @@ -322,6 +321,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): packed_weight, dim, start // div, end // div, step ) + # Update in_features/out_features + gemlite_kwargs["in_features"] = ( + packed_weight.shape[0] * elements_per_sample + ) + gemlite_kwargs["out_features"] = packed_weight.shape[1] + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) if zero_point is not None and zero_point.numel() > 0: zero_point = aten.slice.Tensor( @@ -331,7 +336,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): zero_point = None sliced = GemliteAQTTensorImpl( - packed_weight, scale, zero_point, self.gemlite_kwargs, self._layout + packed_weight, scale, zero_point, gemlite_kwargs, self._layout ) return return_and_correct_aliasing(func, args, kwargs, sliced) From bece806cdf057f111bf6d73cdddb5b999b406c4a Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Jun 2025 10:50:46 +0000 Subject: [PATCH 04/14] update gemlite slice test --- test/dtypes/test_affine_quantized.py | 73 +++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index b74c5d2ecf..a5c40760a4 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -364,12 +364,81 @@ def test_slice_gemlite(self, device, dtype): # in_feature not divisible by 1024 # out_feature not divisible by 8 # to test slice + padding for int4 weight only quantization - dummy = nn.Linear(256, 512, dtype=dtype, device=device) - quantize_(dummy, GemliteUIntXWeightOnlyConfig()) + in_features, out_features, group_size, bit_width = 256, 512, 64, 4 + orig_shape = [out_features, in_features] + dummy = nn.Linear( + in_features, out_features, bias=False, dtype=dtype, device=device + ) + quantize_( + dummy, + GemliteUIntXWeightOnlyConfig(bit_width=bit_width, group_size=group_size), + ) + W_group_mode = dummy.weight.tensor_impl.gemlite_kwargs["meta_args"][10] + # make sure these run without error _ = dummy.weight.narrow(0, 0, 64) _ = dummy.weight.narrow(1, 0, 128) + # Dequant op + import gemlite + + def dequant(input_layer, in_features, orig_shape): + int_data = input_layer.tensor_impl.packed_weight + scale = input_layer.tensor_impl.scale + zero_point = input_layer.tensor_impl.zero_point + + W_q = ( + gemlite.bitpack.unpack_over_rows( + int_data, + W_nbits=bit_width, + num_output_rows=in_features, + dtype=torch.uint8, + ) + .T.contiguous() + .view([-1, group_size]) + ) + + s = scale.t().contiguous().view(-1, 1) + z = zero_point.t().contiguous().view(-1, 1) + + if W_group_mode == 4: # FMA + W_deq = (W_q * s + z).view(orig_shape) + else: + W_deq = ((W_q - z) * s).view(orig_shape) + + return W_deq + + W_r = dequant(dummy.weight, dummy.in_features, orig_shape) + + # Slicing in half + for slice_axis, start, end in [ + (0, 0, 256), + (0, 256, 256), + (1, 0, 128), + (1, 128, 128), + ]: + layer_sliced = dummy.weight.narrow(slice_axis, start, end) + + if slice_axis == 0: + num_rows, out_shape = ( + dummy.in_features, + (orig_shape[0] // 2, orig_shape[1]), + ) + else: + num_rows, out_shape = ( + dummy.in_features // 2, + (orig_shape[0], orig_shape[1] // 2), + ) + + W_slice = dequant(layer_sliced, num_rows, out_shape) + + W_slice_ref = ( + W_r[start : start + end, :] + if slice_axis == 0 + else W_r[:, start : start + end] + ) + self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0) + @common_utils.parametrize("device", ["cuda"]) @common_utils.parametrize("dtype", [torch.bfloat16]) def test_matmul(self, device, dtype): From ba7b4f1c05ae69f3485f3d650bb5bba07715a401 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Jun 2025 17:55:25 +0000 Subject: [PATCH 05/14] add packing_bitwidth support --- torchao/dtypes/uintx/gemlite_layout.py | 52 ++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 6b38e57fb1..73f2ca424b 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -28,13 +28,47 @@ aten = torch.ops.aten +import logging + +logger = logging.getLogger(__name__) +logger.error("************* INIT ************") + + def _same_metadata( self: "GemliteAQTTensorImpl", src: "GemliteAQTTensorImpl", ) -> bool: + # return True + kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) for k, v in self.gemlite_kwargs.items(): kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) + logger.error(str(k) + " | " + str(v) + " vs. " + str(src.gemlite_kwargs[k])) + + logger.error( + "self.packed_weight.shape" + + " | " + + str(self.packed_weight.shape) + + " vs. " + + str(src.packed_weight.shape) + ) + logger.error( + "self.scale.shape" + + " | " + + str(self.scale.shape) + + " vs. " + + str(src.scale.shape) + ) + logger.error( + "self.zero_point.shape" + + " | " + + str(self.zero_point.shape) + + " vs. " + + str(src.zero_point.shape) + ) + logger.error( + "----------------------------------------------------------------------------------------------------------" + ) return ( isinstance(self, GemliteAQTTensorImpl) @@ -78,6 +112,7 @@ def get_gemlite_aqt_kwargs( weight, group_size=64, bit_width=4, + packing_bitwidth=None, use_hqq=True, ): if gemlite is None: @@ -97,6 +132,9 @@ def get_gemlite_aqt_kwargs( assert group_size is None or bit_width != 8, ( "gemlite only works with group_size=None for bit_width=8" ) + assert packing_bitwidth in [8, 16, 32, None], ( + f"Invalid packing bitwidth, got {packing_bitwidth}" + ) out_features, in_features = weight.shape group_size = in_features if group_size is None else group_size @@ -105,6 +143,7 @@ def get_gemlite_aqt_kwargs( aqt_kwargs["_layout"] = GemlitePackedLayout( group_size=group_size, bit_width=bit_width, + packing_bitwidth=packing_bitwidth, ) aqt_kwargs["use_hqq"] = use_hqq return aqt_kwargs @@ -114,6 +153,7 @@ def get_gemlite_aqt_kwargs( class GemlitePackedLayout(Layout): group_size: Optional[int] = 64 bit_width: int = 4 + packing_bitwidth: Optional[int] = 32 @register_layout(GemlitePackedLayout) @@ -189,13 +229,16 @@ def from_plain( group_size, bit_width = _layout.group_size, _layout.bit_width out_features, in_features = int_data.shape + packing_bitwidth = _layout.packing_bitwidth if bit_width == 8 and group_size == in_features: - gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights( - int_data, scales=scale, bias=None - ) + gemlite_linear = gemlite.helper.A16W8( + device=int_data.device, packing_bitwidth=packing_bitwidth + ).from_weights(int_data, scales=scale, bias=None) else: - gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights( + gemlite_linear = gemlite.helper.A16Wn( + device=int_data.device, packing_bitwidth=packing_bitwidth + ).from_weights( int_data, scale, zero_point, bit_width, group_size, bias=None ) @@ -203,6 +246,7 @@ def from_plain( gemlite_kwargs = { "in_features": in_features, "out_features": out_features, + "packing_bitwidth": packing_bitwidth, "data_contiguous": meta_args[-1], "elements_per_sample": meta_args[4], "W_group_mode": meta_args[10], From 33e2bf667630a49e7236175b796b8bea0c1b60c0 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 3 Jun 2025 18:00:31 +0000 Subject: [PATCH 06/14] add packing_bitwidth support and cleanup --- torchao/dtypes/uintx/gemlite_layout.py | 34 -------------------------- torchao/quantization/autoquant.py | 8 +++--- torchao/quantization/quant_api.py | 6 ++++- 3 files changed, 9 insertions(+), 39 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 73f2ca424b..a0641c031d 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -28,47 +28,13 @@ aten = torch.ops.aten -import logging - -logger = logging.getLogger(__name__) -logger.error("************* INIT ************") - - def _same_metadata( self: "GemliteAQTTensorImpl", src: "GemliteAQTTensorImpl", ) -> bool: - # return True - kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) for k, v in self.gemlite_kwargs.items(): kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) - logger.error(str(k) + " | " + str(v) + " vs. " + str(src.gemlite_kwargs[k])) - - logger.error( - "self.packed_weight.shape" - + " | " - + str(self.packed_weight.shape) - + " vs. " - + str(src.packed_weight.shape) - ) - logger.error( - "self.scale.shape" - + " | " - + str(self.scale.shape) - + " vs. " - + str(src.scale.shape) - ) - logger.error( - "self.zero_point.shape" - + " | " - + str(self.zero_point.shape) - + " vs. " - + str(src.zero_point.shape) - ) - logger.error( - "----------------------------------------------------------------------------------------------------------" - ) return ( isinstance(self, GemliteAQTTensorImpl) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 41ea588231..6daeb60f13 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -740,12 +740,12 @@ def from_float(cls, weight): if weight.dtype != torch.float16: weight = weight.to(torch.float16) - bit_width = 4 - packing_bitwidth = 32 - contiguous = None + bit_width = (4,) + packing_bitwidth = None use_hqq = True + aqt_kwargs = get_gemlite_aqt_kwargs( - weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq + weight, cls.group_size, bit_width, packing_bitwidth, use_hqq ) weight = to_affine_quantized_intx(weight, **aqt_kwargs) input_quant_func = _to_float16 diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a7eec7e1df..adf4ae014c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -988,6 +988,7 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): group_size: Optional[int] = 64 bit_width: int = 4 + packing_bitwidth: Optional[int] = 32 set_inductor_config: bool = True @@ -1001,6 +1002,7 @@ def _gemlite_uintx_weight_only_transform( ): group_size = config.group_size bit_width = config.bit_width + packing_bitwidth = config.packing_bitwidth if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1011,7 +1013,9 @@ def _gemlite_uintx_weight_only_transform( use_hqq = True if bit_width == 4 else False new_weight = to_affine_quantized_intx( weight, - **get_gemlite_aqt_kwargs(weight, group_size, bit_width, use_hqq), + **get_gemlite_aqt_kwargs( + weight, group_size, bit_width, packing_bitwidth, use_hqq + ), ) module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) From 587ab1026ec0b4b1a753eef33f1efa6749b2163a Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 4 Jun 2025 09:32:52 +0000 Subject: [PATCH 07/14] update default gemlite layout --- torchao/dtypes/uintx/gemlite_layout.py | 4 ++-- torchao/quantization/quant_api.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index a0641c031d..22f44a30ff 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -117,9 +117,9 @@ def get_gemlite_aqt_kwargs( @dataclass(frozen=True) class GemlitePackedLayout(Layout): - group_size: Optional[int] = 64 + group_size: Optional[int] = 128 bit_width: int = 4 - packing_bitwidth: Optional[int] = 32 + packing_bitwidth: Optional[int] = None @register_layout(GemlitePackedLayout) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index adf4ae014c..8fba0ed262 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -986,9 +986,9 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. """ - group_size: Optional[int] = 64 + group_size: Optional[int] = 128 bit_width: int = 4 - packing_bitwidth: Optional[int] = 32 + packing_bitwidth: Optional[int] = None set_inductor_config: bool = True From 1cb77949443ef4d73a05e0d7e4020a0263d21719 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 4 Jun 2025 10:21:00 +0000 Subject: [PATCH 08/14] cleanup --- torchao/dtypes/uintx/gemlite_layout.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 22f44a30ff..0b645c29cd 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -198,9 +198,9 @@ def from_plain( packing_bitwidth = _layout.packing_bitwidth if bit_width == 8 and group_size == in_features: - gemlite_linear = gemlite.helper.A16W8( - device=int_data.device, packing_bitwidth=packing_bitwidth - ).from_weights(int_data, scales=scale, bias=None) + gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights( + int_data, scales=scale, bias=None + ) else: gemlite_linear = gemlite.helper.A16Wn( device=int_data.device, packing_bitwidth=packing_bitwidth @@ -213,9 +213,9 @@ def from_plain( "in_features": in_features, "out_features": out_features, "packing_bitwidth": packing_bitwidth, - "data_contiguous": meta_args[-1], - "elements_per_sample": meta_args[4], - "W_group_mode": meta_args[10], + "data_contiguous": gemlite_linear.data_contiguous, + "elements_per_sample": gemlite_linear.elements_per_sample, + "W_group_mode": gemlite_linear.W_group_mode, "meta_args": meta_args, } From fc7ff5060f493153d09d1b07173cde50d63a2a1c Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 4 Jun 2025 14:25:55 +0000 Subject: [PATCH 09/14] fix symmetric use-case and relax _same_meta_data --- torchao/dtypes/uintx/gemlite_layout.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 0b645c29cd..d8d30e56a2 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -34,7 +34,13 @@ def _same_metadata( ) -> bool: kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) for k, v in self.gemlite_kwargs.items(): - kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) + if k in [ + "in_features", + "out_features", + "packing_bitwidth", + "elements_per_sample", + ]: + kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) return ( isinstance(self, GemliteAQTTensorImpl) @@ -221,6 +227,10 @@ def from_plain( packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() packed_weight = packed_weight.to(device) + if zero_point is None: + zero_point = torch.tensor( + [[]], device=packed_weight.device, dtype=torch.int32 + ) return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) @@ -358,6 +368,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.copy_.default: self = args[0] src = args[1] + + # Handle zero_point = None with symmetric quant + if self.zero_point is None: + self.zero_point = torch.tensor( + [[]], device=self.packed_weight.device, dtype=torch.int32 + ) + + if src.zero_point is None: + src.zero_point = torch.tensor( + [[]], device=src.packed_weight.device, dtype=torch.int32 + ) + if _same_metadata(self, src): self_tensors = self.__tensor_flatten__()[0] for tensor_name in self_tensors: From 2d66fb4c0e70682e4e487d9fd8834089d1d9c093 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 4 Jun 2025 15:28:02 +0000 Subject: [PATCH 10/14] _copy() meta data --- torchao/dtypes/uintx/gemlite_layout.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index d8d30e56a2..eb06cf2a96 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -384,6 +384,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs): self_tensors = self.__tensor_flatten__()[0] for tensor_name in self_tensors: getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + for key in self.gemlite_kwargs: + self.gemlite_kwargs[key] = src.gemlite_kwargs[key] return raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" From eba10ad9272814b883141a8c92092f0568c5ec8a Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 4 Jun 2025 15:39:43 +0000 Subject: [PATCH 11/14] fix (4,) in autoquant --- torchao/quantization/autoquant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 6daeb60f13..998204c8fe 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -740,7 +740,7 @@ def from_float(cls, weight): if weight.dtype != torch.float16: weight = weight.to(torch.float16) - bit_width = (4,) + bit_width = 4 packing_bitwidth = None use_hqq = True From 9c7d41d54e7b37b8fa1e08141d1c7012d09b6d63 Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Jun 2025 12:43:17 +0000 Subject: [PATCH 12/14] Add dynamic mode in gemlite layout --- torchao/dtypes/uintx/gemlite_layout.py | 22 ++++++++++++++++++++-- torchao/quantization/autoquant.py | 8 +++++++- torchao/quantization/quant_api.py | 11 +++++++++-- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index eb06cf2a96..5ed3e5da8a 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -85,6 +85,7 @@ def get_gemlite_aqt_kwargs( group_size=64, bit_width=4, packing_bitwidth=None, + mode="static", use_hqq=True, ): if gemlite is None: @@ -108,6 +109,10 @@ def get_gemlite_aqt_kwargs( f"Invalid packing bitwidth, got {packing_bitwidth}" ) + assert mode in ["static", "dynamic"], ( + f"Invalid mode: should be either static or dynamic, got {mode}" + ) + out_features, in_features = weight.shape group_size = in_features if group_size is None else group_size @@ -116,6 +121,7 @@ def get_gemlite_aqt_kwargs( group_size=group_size, bit_width=bit_width, packing_bitwidth=packing_bitwidth, + mode=mode, ) aqt_kwargs["use_hqq"] = use_hqq return aqt_kwargs @@ -126,6 +132,7 @@ class GemlitePackedLayout(Layout): group_size: Optional[int] = 128 bit_width: int = 4 packing_bitwidth: Optional[int] = None + mode: Optional[str] = "static" @register_layout(GemlitePackedLayout) @@ -202,13 +209,24 @@ def from_plain( group_size, bit_width = _layout.group_size, _layout.bit_width out_features, in_features = int_data.shape packing_bitwidth = _layout.packing_bitwidth + mode = _layout.mode if bit_width == 8 and group_size == in_features: - gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights( + processor = ( + gemlite.helper.A8W8_int8_dynamic + if mode == "dynamic" + else gemlite.helper.A16W8 + ) + gemlite_linear = processor(device=int_data.device).from_weights( int_data, scales=scale, bias=None ) else: - gemlite_linear = gemlite.helper.A16Wn( + processor = ( + gemlite.helper.A8Wn_dynamic + if mode == "dynamic" + else gemlite.helper.A16Wn + ) + gemlite_linear = processor( device=int_data.device, packing_bitwidth=packing_bitwidth ).from_weights( int_data, scale, zero_point, bit_width, group_size, bias=None diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 998204c8fe..a258d7a5e4 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -742,10 +742,16 @@ def from_float(cls, weight): bit_width = 4 packing_bitwidth = None + mode = "static" use_hqq = True aqt_kwargs = get_gemlite_aqt_kwargs( - weight, cls.group_size, bit_width, packing_bitwidth, use_hqq + weight, + group_size=cls.group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + mode=mode, + use_hqq=use_hqq, ) weight = to_affine_quantized_intx(weight, **aqt_kwargs) input_quant_func = _to_float16 diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index be25b144a6..50c12acc47 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -986,13 +986,14 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): size is more fine grained `bit_width`: bit width of the quantized weight. `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. - `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice. + `mode`: if set to "dyanmic", the activations will be dynamically quantized. `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. """ group_size: Optional[int] = 128 bit_width: int = 4 packing_bitwidth: Optional[int] = None + mode: Optional[str] = "static" set_inductor_config: bool = True @@ -1007,6 +1008,7 @@ def _gemlite_uintx_weight_only_transform( group_size = config.group_size bit_width = config.bit_width packing_bitwidth = config.packing_bitwidth + mode = config.mode if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1018,7 +1020,12 @@ def _gemlite_uintx_weight_only_transform( new_weight = to_affine_quantized_intx( weight, **get_gemlite_aqt_kwargs( - weight, group_size, bit_width, packing_bitwidth, use_hqq + weight, + group_size=group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + mode=mode, + use_hqq=use_hqq, ), ) module.weight = torch.nn.Parameter(new_weight, requires_grad=False) From 6c7537b93e2ea21bf66552b1732f0a873bb4d58b Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Jun 2025 14:48:18 +0000 Subject: [PATCH 13/14] mode explanation Signed-off-by: mobicham --- torchao/quantization/quant_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 50c12acc47..dd9c60bf67 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -986,7 +986,7 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): size is more fine grained `bit_width`: bit width of the quantized weight. `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. - `mode`: if set to "dyanmic", the activations will be dynamically quantized. + `mode`: if set to "dynamic", activations are quantized at runtime; default is "static" (weight-only quantization). `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. """ From d11f3e2a982b04077d534a90b267c0b0540a279c Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 6 Jun 2025 15:03:40 +0000 Subject: [PATCH 14/14] use weights_only instead of static --- torchao/dtypes/uintx/gemlite_layout.py | 8 ++++---- torchao/quantization/autoquant.py | 2 +- torchao/quantization/quant_api.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 5ed3e5da8a..51b453de8a 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -85,7 +85,7 @@ def get_gemlite_aqt_kwargs( group_size=64, bit_width=4, packing_bitwidth=None, - mode="static", + mode="weight_only", use_hqq=True, ): if gemlite is None: @@ -109,8 +109,8 @@ def get_gemlite_aqt_kwargs( f"Invalid packing bitwidth, got {packing_bitwidth}" ) - assert mode in ["static", "dynamic"], ( - f"Invalid mode: should be either static or dynamic, got {mode}" + assert mode in ["weight_only", "dynamic"], ( + f"Invalid mode: should be either weight_only or dynamic, got {mode}" ) out_features, in_features = weight.shape @@ -132,7 +132,7 @@ class GemlitePackedLayout(Layout): group_size: Optional[int] = 128 bit_width: int = 4 packing_bitwidth: Optional[int] = None - mode: Optional[str] = "static" + mode: Optional[str] = "weight_only" @register_layout(GemlitePackedLayout) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index a258d7a5e4..6f0aac947a 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -742,7 +742,7 @@ def from_float(cls, weight): bit_width = 4 packing_bitwidth = None - mode = "static" + mode = "weight_only" use_hqq = True aqt_kwargs = get_gemlite_aqt_kwargs( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index dd9c60bf67..65c0814cf2 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -986,14 +986,14 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): size is more fine grained `bit_width`: bit width of the quantized weight. `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. - `mode`: if set to "dynamic", activations are quantized at runtime; default is "static" (weight-only quantization). + `mode`: if set to "dynamic", activations are quantized at runtime; default is "weight_only" (weight-only quantization). `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. """ group_size: Optional[int] = 128 bit_width: int = 4 packing_bitwidth: Optional[int] = None - mode: Optional[str] = "static" + mode: Optional[str] = "weight_only" set_inductor_config: bool = True