diff --git a/test/prototype/test_gguf_quant.py b/test/prototype/test_gguf_quant.py
new file mode 100644
index 0000000000..b68d84b101
--- /dev/null
+++ b/test/prototype/test_gguf_quant.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from torchao.prototype.quantization.gguf import (
+    GGUFQuantizedTensor,
+    GGUFWeightOnlyConfig,
+)
+from torchao.quantization import quantize_
+from torchao.quantization.quant_primitives import choose_qparams_gguf
+from torchao.quantization.utils import compute_error
+
+
+class TestGGUFQuantization(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(123)
+        self.input = torch.randn(2, 256, dtype=torch.float32)
+        self.n_blocks_per_superblock = 8
+        self.block_size = (1, 32)
+        self.dtype = torch.uint4
+
+    def test_choose_qparams_gguf(self):
+        (
+            super_block_scale_scale,
+            super_block_min_scale,
+            quantized_block_scale,
+            quantized_block_min,
+        ) = choose_qparams_gguf(self.input, self.block_size, self.dtype)
+
+        # (2, 256) input with (1, 32) blocks: 8 blocks per row, one superblock per row
+        self.assertEqual(super_block_scale_scale.shape, (2,))
+        self.assertEqual(super_block_min_scale.shape, (2,))
+        self.assertEqual(quantized_block_scale.shape, (2, 8))
+
+    def test_gguf_quantized_tensor_from_float(self):
+        gqt = GGUFQuantizedTensor.from_float(
+            self.input,
+            self.n_blocks_per_superblock,
+            self.dtype,
+        )
+
+        dequant = gqt.dequantize()
+
+        sqnr = compute_error(dequant, self.input)
+        self.assertGreater(sqnr, 30)
+
+    def test_quantize_api(self):
+        m = torch.nn.Sequential(torch.nn.Linear(256, 64))
+        quantize_(m, GGUFWeightOnlyConfig())
+        assert type(m[0].weight) == GGUFQuantizedTensor
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/core/config.py b/torchao/core/config.py
index 4a5a4c5720..fe03ac225b 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -171,7 +171,11 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     return json.loads(json.dumps(config, cls=ConfigJSONEncoder))
 
 
-ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"}
+ALLOWED_AO_MODULES = {
+    "torchao.quantization",
+    "torchao.sparsity.sparse_api",
+    "torchao.prototype.quantization",
+}
 
 
 def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
diff --git a/torchao/prototype/quantization/__init__.py b/torchao/prototype/quantization/__init__.py
index e69de29bb2..bf49e2717b 100644
--- a/torchao/prototype/quantization/__init__.py
+++ b/torchao/prototype/quantization/__init__.py
@@ -0,0 +1,5 @@
+from .gguf import GGUFWeightOnlyConfig
+
+__all__ = [
+    "GGUFWeightOnlyConfig",
+]
diff --git a/torchao/prototype/quantization/gguf/__init__.py b/torchao/prototype/quantization/gguf/__init__.py
new file mode 100644
index 0000000000..3e43e1f3dc
--- /dev/null
+++ b/torchao/prototype/quantization/gguf/__init__.py
@@ -0,0 +1,9 @@
+from .api import GGUFWeightOnlyConfig
+from .gguf_quantized_tensor import (
+    GGUFQuantizedTensor,
+)
+
+__all__ = [
+    "GGUFQuantizedTensor",
+    "GGUFWeightOnlyConfig",
+]
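The ALLOWED_AO_MODULES addition above is what lets a serialized GGUFWeightOnlyConfig be reconstructed. A minimal round-trip sketch, assuming the config_to_dict/config_from_dict helpers visible in the hunk context serialize a config to a plain dict and back:

    from torchao.core.config import config_from_dict, config_to_dict
    from torchao.prototype.quantization import GGUFWeightOnlyConfig

    cfg = GGUFWeightOnlyConfig(n_blocks_per_superblock=8)
    data = config_to_dict(cfg)  # JSON-serializable dict
    # reconstruction only works for configs whose module is allowlisted,
    # which is presumably why "torchao.prototype.quantization" is added above
    restored = config_from_dict(data)
    assert isinstance(restored, GGUFWeightOnlyConfig)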
diff --git a/torchao/prototype/quantization/gguf/api.py b/torchao/prototype/quantization/gguf/api.py
new file mode 100644
index 0000000000..bc4b46992a
--- /dev/null
+++ b/torchao/prototype/quantization/gguf/api.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+import torch
+
+from torchao.core.config import AOBaseConfig
+from torchao.quantization.transform_module import register_quantize_module_handler
+
+from .gguf_quantized_tensor import GGUFQuantizedTensor
+
+__all__ = [
+    "GGUFWeightOnlyConfig",
+]
+
+
+@dataclass
+class GGUFWeightOnlyConfig(AOBaseConfig):
+    """Config for GGUF-style weight-only quantization.
+
+    Args:
+        dtype: target dtype; currently only torch.uint4 is supported.
+        n_blocks_per_superblock: number of blocks in each 256-element
+            superblock, e.g. 8 means 8 blocks of 32 elements each.
+    """
+
+    dtype: torch.dtype = torch.uint4
+    n_blocks_per_superblock: int = 8
+
+
+@register_quantize_module_handler(GGUFWeightOnlyConfig)
+def _gguf_weight_only_transform(
+    module: torch.nn.Module,
+    config: GGUFWeightOnlyConfig,
+):
+    """
+    Applies GGUF weight-only quantization to a linear module's weight.
+
+    Returns:
+        The module with its weight replaced by a GGUFQuantizedTensor, or the
+        module unchanged if the weight is not 2D with a multiple of 256 columns.
+    """
+    weight = module.weight
+    if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
+        return module
+
+    quantized_weight = GGUFQuantizedTensor.from_float(
+        weight,
+        n_blocks_per_superblock=config.n_blocks_per_superblock,
+        target_dtype=config.dtype,
+    )
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    return module
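Module-level usage, sketched from test_quantize_api in the new test file above (the transform only fires for 2D weights whose last dimension is a multiple of 256; other modules are left untouched):

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.quantization import GGUFWeightOnlyConfig

    model = torch.nn.Sequential(torch.nn.Linear(256, 64))
    quantize_(model, GGUFWeightOnlyConfig(n_blocks_per_superblock=8))
    # the Linear's weight is now a GGUFQuantizedTensor; the linear override
    # registered in gguf_quantized_tensor.py dequantizes it on the fly
    out = model(torch.randn(2, 256))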
+ """ + + @staticmethod + def __new__( + cls, + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + kwargs["device"] = kwargs.get("device", super_block_scale_scale.device) + kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape, + **kwargs, + ): + self.n_blocks_per_superblock = n_blocks_per_superblock + self.super_block_scale_scale = super_block_scale_scale + self.super_block_min_scale = super_block_min_scale + self.quantized_block_scale = quantized_block_scale + self.quantized_block_min = quantized_block_min + self.int_data = int_data + + def _apply_fn_to_data(self, fn): + return self.__class__( + self.n_blocks_per_superblock, + fn(self.super_block_scale_scale), + fn(self.super_block_min_sclae), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), + self.shape, + dtype=self.dtype, + ) + + def __tensor_flatten__(self): + return [ + "super_block_scale_scale", + "super_block_min_scale", + "quantized_block_scale", + "quantized_block_min", + "int_data", + ], ( + self.n_blocks_per_superblock, + self.dtype, + self.shape, + ) + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None + ): + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + ) = ( + tensor_data_dict["super_block_scale_scale"], + tensor_data_dict["super_block_min_scale"], + tensor_data_dict["quantized_block_scale"], + tensor_data_dict["quantized_block_min"], + tensor_data_dict["int_data"], + ) + n_blocks_per_superblock, dtype, shape = attributes + return cls( + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + shape if outer_size is None else outer_size, + dtype=dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + block_size = tuple( + [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock] + ) + return dequantize_gguf( + self.int_data, + block_size, + self.dtype, + self.super_block_scale_scale, + self.super_block_min_scale, + self.quantized_block_scale, + self.quantized_block_min, + output_dtype=output_dtype, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.n_blocks_per_superblock, + self.super_block_scale_scale.to(device), + self.super_block_min_scale.to(device), + self.quantized_block_scale.to(device), + self.quantized_block_min.to(device), + self.int_data.to(device), + self.shape, + **kwargs, + ) + + def _apply_fn_to_data(self, fn): + """ + Returns a new `CodebookQuantizedTensor`. + """ + return self.__class__( + self.n_blocks_per_superblock, + fn(self.super_block_scale_scale), + fn(self.super_block_min_scale), + fn(self.quantized_block_scale), + fn(self.quantized_block_min), + fn(self.int_data), + self.shape, + dtype=self.dtype, + ) + + def requires_grad_(self, requires_grad=False): + """ + Modifies the tensor's `requires_grad` status in-place. 
+ """ + assert not requires_grad, "Only requires_grad == False is supported" + return self + + @classmethod + def from_float(cls, input_float, n_blocks_per_superblock, target_dtype): + """ + Method used to convert a linear weight tensor to an instance of the + GGMLInt4LinearWeight subclass. + + Example usage:: + + model.lin_mod.weight = ( + GGMLInt4LinearWeight.from_float(model.lin_mod.weight) + ) + """ + assert ( + target_dtype == torch.uint4 + ), "only uint4 quantization is supported right now" + block_size = (1, _QK_K // n_blocks_per_superblock) + ( + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) = choose_qparams_gguf(input_float, block_size, target_dtype) + + int_data = quantize_gguf( + input_float, + block_size, + target_dtype, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + ) + return cls( + n_blocks_per_superblock, + super_block_scale_scale, + super_block_min_scale, + quantized_block_scale, + quantized_block_min, + int_data, + input_float.shape, + ) + + +implements = GGUFQuantizedTensor.implements + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + dtype = input_tensor.dtype + + if hasattr(weight_tensor, "dequantize"): + weight_tensor = weight_tensor.dequantize(output_dtype=dtype) + + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([GGUFQuantizedTensor]) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 05be8c5c30..bc176c9d17 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -42,6 +42,9 @@ "choose_qparams_affine_float8", "quantize_affine_float8", "dequantize_affine_float8", + "choose_qparams_gguf", + "quantize_gguf", + "dequantize_gguf", ] @@ -195,6 +198,8 @@ class TorchAODType(Enum): _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys() +_GGUF_QK_K = 256 + _ONES_TABLE = [_n_ones(i) for i in range(8)] quant_lib = torch.library.Library("quant", "FRAGMENT") @@ -1039,6 +1044,214 @@ def reshape_w(w): return q_w, s_group, s_channel, w_ref +def choose_qparams_gguf( + input: Optional[torch.Tensor], + block_size: List[int], + target_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + There are two sets of qparams: quantized_block_scale, quantized_block_min and super_block_scale_scale and super_block_min_scale + the 
+    There are two sets of qparams: (quantized_block_scale, quantized_block_min)
+    and (super_block_scale_scale, super_block_min_scale). They are related as
+    follows:
+        block_scale = quantized_block_scale * super_block_scale_scale
+        block_min = quantized_block_min * super_block_min_scale
+        quantized_val = (float_val - block_min) / block_scale + quant_min
+    First we calculate block_scale and block_min, then super_block_scale_scale
+    and super_block_min_scale, and from those the quantized_block_scale and
+    quantized_block_min.
+
+    Returns:
+        super_block_scale_scale, super_block_min_scale, quantized_block_scale,
+        quantized_block_min
+    """
+    dtype = input.dtype
+
+    # 1. get block_scale and block_min
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input.size()
+    )
+    input = input.view(shape_for_reduction)
+    min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+    max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    quant_max = 15
+    quant_min = 0
+    # asymmetric quant to fully utilize the range
+    block_scale = (max_val - min_val) / float(quant_max - quant_min)
+    block_min = min_val
+
+    # 2. get super_block_scale_scale and super_block_min_scale
+    assert _GGUF_QK_K % block_size[-1] == 0
+    super_block_size = (1, _GGUF_QK_K // block_size[-1])
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        super_block_size, block_scale.size()
+    )
+    block_scale = block_scale.view(shape_for_reduction)
+    block_min = block_min.view(shape_for_reduction)
+
+    shape_after_reduction = shape_for_reduction.copy()
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+
+    block_scale_absmax = torch.amax(
+        torch.abs(block_scale), dim=reduction_dims, keepdim=False
+    )
+    block_min_absmax = torch.amax(
+        torch.abs(block_min), dim=reduction_dims, keepdim=False
+    )
+
+    # the per-block qparams (scale and min) are themselves quantized to 6 bits,
+    # following GGUF Q4_K
+    # TODO: make this configurable
+    qparam_quant_max = 2**6 - 1
+    qparam_quant_min = 0
+    super_block_scale_scale = block_scale_absmax / float(
+        qparam_quant_max - qparam_quant_min
+    )
+    super_block_min_scale = block_min_absmax / float(
+        qparam_quant_max - qparam_quant_min
+    )
+    super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction)
+    super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction)
+
+    # 3. quantize block_scale and block_min to 6 bits using
+    # super_block_scale_scale and super_block_min_scale
+    quantized_block_scale = torch.clamp(
+        block_scale / super_block_scale_scale_view, qparam_quant_min, qparam_quant_max
+    )
+    quantized_block_min = torch.clamp(
+        block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max
+    )
+    return (
+        super_block_scale_scale.to(dtype),
+        super_block_min_scale.to(dtype),
+        quantized_block_scale.to(dtype),
+        quantized_block_min.to(dtype),
+    )
+
+
+def quantize_gguf(
+    input: torch.Tensor,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    super_block_scale_scale: torch.Tensor,
+    super_block_min_scale: torch.Tensor,
+    quantized_block_scale: torch.Tensor,
+    quantized_block_min: torch.Tensor,
+) -> torch.Tensor:
+    """Quantize `input` with the GGUF double-quantization scheme, using the
+    qparams produced by `choose_qparams_gguf`."""
+    assert target_dtype == torch.uint4
+
+    # step 1: shape bookkeeping for the first order (block) quantization;
+    # compute the shape that block_scale and block_min broadcast against
+    input_shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input.size()
+    )
+    block_qparam_shape_after_reduction = input_shape_for_reduction.copy()
+    for i in reduction_dims:
+        block_qparam_shape_after_reduction[i] = 1
+    original_shape = input.shape
+    input = input.view(input_shape_for_reduction)
+    quantized_block_scale = quantized_block_scale.view(
+        block_qparam_shape_after_reduction
+    )
+    quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction)
+
+    # step 2: second order quantization, recover the unquantized block_scale
+    # and block_min from the quantized block qparams and superblock scales
+    super_block_size = (1, _GGUF_QK_K // block_size[-1], 1)
+    super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params(
+        super_block_size, quantized_block_scale.size()
+    )
+    super_block_qparam_shape_after_reduction = (
+        super_block_input_shape_for_reduction.copy()
+    )
+    for i in reduction_dims:
+        super_block_qparam_shape_after_reduction[i] = 1
+
+    quantized_block_scale = quantized_block_scale.view(
+        super_block_input_shape_for_reduction
+    )
+    quantized_block_min = quantized_block_min.view(
+        super_block_input_shape_for_reduction
+    )
+    super_block_scale_scale = super_block_scale_scale.view(
+        super_block_qparam_shape_after_reduction
+    )
+    super_block_min_scale = super_block_min_scale.view(
+        super_block_qparam_shape_after_reduction
+    )
+
+    block_scale = super_block_scale_scale * quantized_block_scale
+    block_min = super_block_min_scale * quantized_block_min
+
+    # step 3: quantization with the recovered block_scale and block_min
+    # note: the result stays in floating point; rounding and packing to the
+    # uint4 storage format are not done here
+    block_scale = block_scale.view(block_qparam_shape_after_reduction)
+    block_min = block_min.view(block_qparam_shape_after_reduction)
+    int_data = (input - block_min) / block_scale
+    int_data = int_data.view(original_shape)
+
+    return int_data
+
+
+def dequantize_gguf(
+    input: torch.Tensor,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    super_block_scale_scale: torch.Tensor,
+    super_block_min_scale: torch.Tensor,
+    quantized_block_scale: torch.Tensor,
+    quantized_block_min: torch.Tensor,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Dequantize GGUF-quantized `input` using the qparams produced by
+    `choose_qparams_gguf`."""
+    # step 1. reshape input and the quantized block scale and min to the
+    # shapes used after the first quantization step
+    input_shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input.size()
+    )
+    block_qparam_shape_after_reduction = input_shape_for_reduction.copy()
+    for i in reduction_dims:
+        block_qparam_shape_after_reduction[i] = 1
+
+    original_shape = input.shape
+    input = input.view(input_shape_for_reduction)
+    quantized_block_scale = quantized_block_scale.view(
+        block_qparam_shape_after_reduction
+    )
+    quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction)
+
+    # step 2. calculate and reshape block_qparams for the second quantization step
+    super_block_size = (1, _GGUF_QK_K // block_size[-1], 1)
+    super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params(
+        super_block_size, quantized_block_scale.size()
+    )
+    super_block_qparam_shape_after_reduction = (
+        super_block_input_shape_for_reduction.copy()
+    )
+    for i in reduction_dims:
+        super_block_qparam_shape_after_reduction[i] = 1
+    quantized_block_scale = quantized_block_scale.view(
+        super_block_input_shape_for_reduction
+    )
+    quantized_block_min = quantized_block_min.view(
+        super_block_input_shape_for_reduction
+    )
+    super_block_scale_scale = super_block_scale_scale.view(
+        super_block_qparam_shape_after_reduction
+    )
+    super_block_min_scale = super_block_min_scale.view(
+        super_block_qparam_shape_after_reduction
+    )
+
+    block_scale = super_block_scale_scale * quantized_block_scale
+    block_min = super_block_min_scale * quantized_block_min
+
+    # step 3. dequantize with block_scale and block_min
+    block_scale = block_scale.view(block_qparam_shape_after_reduction)
+    block_min = block_min.view(block_qparam_shape_after_reduction)
+    dequant = input * block_scale + block_min
+    dequant = dequant.view(original_shape)
+    if output_dtype is not None:
+        dequant = dequant.to(output_dtype)
+
+    return dequant
+
+
 def dequantize_affine_qqq(
     w: torch.Tensor,
     s_group: torch.Tensor,
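For reference, the tensor-level flow exercised by the new tests, written out as a small sketch using only the APIs added in this patch (weights are quantized per 32-element block, with per-block scales/mins shared through one superblock per 256 elements):

    import torch
    from torchao.prototype.quantization.gguf import GGUFQuantizedTensor
    from torchao.quantization.utils import compute_error

    w = torch.randn(64, 256, dtype=torch.float32)
    gqt = GGUFQuantizedTensor.from_float(
        w, n_blocks_per_superblock=8, target_dtype=torch.uint4
    )
    w_dq = gqt.dequantize(output_dtype=torch.float32)
    print(compute_error(w_dq, w))  # SQNR in dB; the test expects > 30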