Skip to content

Commit a117d9d

Browse files
committed
Add AutoAWQ integration
1 parent 8804595 commit a117d9d

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

outlines/models/awq.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import TYPE_CHECKING, Optional
2+
3+
from .transformers import Transformer, TransformerTokenizer
4+
5+
if TYPE_CHECKING:
6+
from transformers import PreTrainedModel, PreTrainedTokenizer
7+
8+
9+
class AWQModel(Transformer):
10+
"""Represents a `transformers` model."""
11+
12+
def __init__(
13+
self,
14+
model: "PreTrainedModel",
15+
tokenizer: "PreTrainedTokenizer",
16+
):
17+
self.device = model.model.device
18+
self.model = model
19+
self.tokenizer = tokenizer
20+
21+
22+
def awq(
23+
model_name: str,
24+
fuse_layers: bool = True,
25+
device: Optional[str] = None,
26+
model_kwargs: dict = {},
27+
tokenizer_kwargs: dict = {},
28+
):
29+
try:
30+
from awq import AutoAWQForCausalLM
31+
except ImportError:
32+
raise ImportError(
33+
"The `autoawq` and `transformers` library needs to be installed in order to use `AutoAWQ` models."
34+
)
35+
36+
model_kwargs["fuse_layers"] = fuse_layers
37+
model_kwargs["safetensors"] = True
38+
39+
if device is not None:
40+
model_kwargs["device_map"] = device
41+
42+
model = AutoAWQForCausalLM.from_quantized(model_name, **model_kwargs)
43+
tokenizer = TransformerTokenizer(model_name, trust_remote_code=True)
44+
45+
return AWQModel(model, tokenizer)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ exclude=["examples"]
8787

8888
[[tool.mypy.overrides]]
8989
module = [
90+
"awq.*",
91+
"auto_gptq.*",
9092
"jinja2",
9193
"joblib.*",
9294
"jsonschema.*",

0 commit comments

Comments
 (0)