Skip to content

Add gguf q4_k quantization #2001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions test/prototype/test_gguf_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchao.prototype.quantization.gguf import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate this btw by actually creating a gguf for a model and then run the resulting gguf file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haven't explored how to export yet, will do in next PR

GGUFQuantizedTensor,
GGUFWeightOnlyConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import choose_qparams_gguf
from torchao.quantization.utils import compute_error


class TestGGUFQuantization(unittest.TestCase):
def setUp(self):
torch.manual_seed(123)
self.input = torch.randn(2, 256, dtype=torch.float32)
self.n_blocks_per_superblock = 8
self.block_size = (1, 32)
self.dtype = torch.uint4

def test_choose_qparams_gguf(self):
(
super_block_scale_scale,
super_block_min_scale,
quantized_block_scale,
quantized_block_min,
) = choose_qparams_gguf(self.input, self.block_size, self.dtype)

assert super_block_scale_scale.shape, (2, 8)
assert super_block_min_scale.shape, (2, 8)
assert quantized_block_scale.shape, (2, 32)

def test_gguf_quantized_tensor_from_float(self):
gqt = GGUFQuantizedTensor.from_float(
self.input,
self.n_blocks_per_superblock,
self.dtype,
)

dequant = gqt.dequantize()

sqnr = compute_error(dequant, self.input)
self.assertGreater(sqnr, 30)

def test_quantize_api(self):
m = torch.nn.Sequential(torch.nn.Linear(256, 64))
quantize_(m, GGUFWeightOnlyConfig())
assert type(m[0].weight) == GGUFQuantizedTensor


if __name__ == "__main__":
unittest.main()
6 changes: 5 additions & 1 deletion torchao/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
return json.loads(json.dumps(config, cls=ConfigJSONEncoder))


ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"}
ALLOWED_AO_MODULES = {
"torchao.quantization",
"torchao.sparsity.sparse_api",
"torchao.prototype.quantization",
}


def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
Expand Down
5 changes: 5 additions & 0 deletions torchao/prototype/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .gguf import GGUFWeightOnlyConfig

__all__ = [
"GGUFWeightOnlyConfig",
]
9 changes: 9 additions & 0 deletions torchao/prototype/quantization/gguf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .api import GGUFWeightOnlyConfig
from .gguf_quantized_tensor import (
GGUFQuantizedTensor,
)

__all__ = [
"GGUFQuantizedTensor",
"GGUFWeightOnlyConfig",
]
52 changes: 52 additions & 0 deletions torchao/prototype/quantization/gguf/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler

from .gguf_quantized_tensor import GGUFQuantizedTensor

__all__ = [
"GGUFWeightOnlyConfig",
]


@dataclass
class GGUFWeightOnlyConfig(AOBaseConfig):
dtype: torch.dtype = torch.uint4
n_blocks_per_superblock: int = 8


@register_quantize_module_handler(GGUFWeightOnlyConfig)
def _gguf_weight_only_transform(
module: torch.nn.Module,
config: GGUFWeightOnlyConfig,
):
"""
Applies gguf weight-only quantization to linear layers.

Args:
dtype: torch.uint1 to torch.uint8, torch.int32 supported.
n_blocks_per_superblock: the number of super blocks in a 256 element block for gguf, e.g. when it is 8
it means we have blocks of 32 and 8 blocks in a superblock of 256 elements.
Returns:
Callable for quantization transformation.
"""
weight = module.weight
if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
return module

quantized_weight = GGUFQuantizedTensor.from_float(
weight,
n_blocks_per_superblock=config.n_blocks_per_superblock,
target_dtype=config.dtype,
)
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module
Loading
Loading