Skip to content

Commit e99d92d

Browse files
dtiarks authored and rlouf committed
Integrate llama.cpp via a logits processor
1 parent bc71b23 commit e99d92d

File tree

18 files changed

+649
-349
lines changed

18 files changed

+649
-349
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,3 +4,4 @@ __pycache__
44
docs/build
55
.coverage
66
.idea/
7+
*.gguf

docs/reference/models/llamacpp.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,5 +11,5 @@ Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the
1111
```python
1212
from outlines import models, generate
1313

14-
model = models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu")
14+
model = models.llamacpp("./phi-2.Q4_K_M.gguf")
1515
```

examples/llamacpp_example.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,8 @@ class Character(BaseModel):
3030

3131

3232
if __name__ == "__main__":
33-
# Download model from https://huggingface.co/TheBloke/phi-2-GGUF
34-
model = outlines.models.llamacpp("./phi-2.Q3_K_M.gguf", device="cpu")
33+
# curl -L -o mistral-7b-instruct-v0.2.Q5_K_M.gguf https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q5_K_M.gguf
34+
model = outlines.models.llamacpp("./mistral-7b-instruct-v0.2.Q5_K_M.gguf")
3535

3636
# Construct structured sequence generator
3737
generator = outlines.generate.json(model, Character)

examples/llamacpp_processor.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,50 @@
1+
from enum import Enum
2+
3+
from llama_cpp import Llama, LogitsProcessorList
4+
from pydantic import BaseModel, constr
5+
6+
from outlines.generate.processors import JSONLogitsProcessor
7+
from outlines.models.llamacpp import LlamaCppTokenizer
8+
9+
10+
class Weapon(str, Enum):
11+
sword = "sword"
12+
axe = "axe"
13+
mace = "mace"
14+
spear = "spear"
15+
bow = "bow"
16+
crossbow = "crossbow"
17+
18+
19+
class Armor(str, Enum):
20+
leather = "leather"
21+
chainmail = "chainmail"
22+
plate = "plate"
23+
24+
25+
class Character(BaseModel):
26+
name: constr(max_length=10)
27+
age: int
28+
armor: Armor
29+
weapon: Weapon
30+
strength: int
31+
32+
33+
if __name__ == "__main__":
34+
llama = Llama("./phi-2.Q4_K_M.gguf")
35+
tokenizer = LlamaCppTokenizer(llama)
36+
37+
prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"
38+
39+
logits_processor = JSONLogitsProcessor(Character, tokenizer)
40+
41+
json_str = llama.create_completion(
42+
prompt,
43+
top_k=40,
44+
top_p=0.95,
45+
temperature=0.7,
46+
max_tokens=100,
47+
logits_processor=LogitsProcessorList([logits_processor]),
48+
)["choices"][0]["text"]
49+
50+
print(json_str)

outlines/fsm/fsm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ def copy(self) -> "StopAtEosFSM":
9191
class RegexFSM(FSM):
9292
"""FSM to generate text that is in the language of a regular expression."""
9393

94-
def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
94+
def __init__(self, regex_string: str, tokenizer):
9595
@cache()
9696
def create_states_mapping(
9797
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
@@ -190,7 +190,7 @@ def copy(self) -> "RegexFSM":
190190
class CFGFSM(FSM):
191191
"""FSM to generate text that is in the language of a context-free grammar."""
192192

193-
def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
193+
def __init__(self, cfg_string: str, tokenizer):
194194
self.cfg_string = cfg_string
195195
self.tokenizer = tokenizer
196196

outlines/fsm/json_schema.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
import inspect
22
import json
33
import re
4-
from typing import Callable, Optional, Union
4+
from typing import Callable, Optional
55

66
from jsonschema.protocols import Validator
7-
from pydantic import BaseModel, create_model
7+
from pydantic import create_model
88
from referencing import Registry, Resource
99
from referencing._core import Resolver
1010
from referencing.jsonschema import DRAFT202012
@@ -38,9 +38,7 @@
3838
}
3939

4040

41-
def build_regex_from_object(
42-
object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None
43-
):
41+
def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):
4442
"""Turn a JSON schema into a regex that matches any JSON object that follows
4543
this schema.
4644
@@ -72,13 +70,7 @@ def build_regex_from_object(
7270
7371
"""
7472

75-
if isinstance(object, type(BaseModel)):
76-
schema = object.model_json_schema()
77-
elif callable(object):
78-
schema = get_schema_from_signature(object)
79-
else:
80-
schema = json.loads(object)
81-
73+
schema = json.loads(schema)
8274
Validator.check_schema(schema)
8375

8476
# Build reference resolver

outlines/generate/cfg.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from outlines.fsm.fsm import CFGFSM
44
from outlines.generate.api import SequenceGenerator
55
from outlines.models import OpenAI
6+
from outlines.models.llamacpp import CFGLogitsProcessor, LlamaCpp
67
from outlines.samplers import Sampler, multinomial
78

89

@@ -31,6 +32,24 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
3132
return generator
3233

3334

35+
@cfg.register(LlamaCpp)
36+
def cfg_llamacpp(
37+
model: LlamaCpp,
38+
cfg_str: str,
39+
sampler: Sampler = multinomial(),
40+
):
41+
if not isinstance(sampler, multinomial):
42+
raise NotImplementedError(
43+
r"The llama.cpp integration does not currently support any other sampling algorithm "
44+
+ "than the multinomial sampler."
45+
)
46+
47+
logits_processor = CFGLogitsProcessor(cfg_str, model.tokenizer)
48+
model.logits_processor = logits_processor
49+
50+
return model
51+
52+
3453
@cfg.register(OpenAI)
3554
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
3655
raise NotImplementedError(

outlines/generate/choice.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,11 @@ def choice(
1313
model, choices: List[str], sampler: Sampler = multinomial()
1414
) -> SequenceGenerator:
1515
regex_str = r"(" + r"|".join(choices) + r")"
16-
return regex(model, regex_str, sampler)
16+
17+
generator = regex(model, regex_str, sampler)
18+
generator.format_sequence = lambda x: x
19+
20+
return generator
1721

1822

1923
@choice.register(OpenAI)

outlines/generate/json.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel
66

7-
from outlines.fsm.json_schema import build_regex_from_object, get_schema_from_signature
7+
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
88
from outlines.generate.api import SequenceGenerator
99
from outlines.models import OpenAI
1010
from outlines.samplers import Sampler, multinomial
@@ -45,17 +45,17 @@ def json(
4545
"""
4646
if isinstance(schema_object, type(BaseModel)):
4747
schema = pyjson.dumps(schema_object.model_json_schema())
48-
regex_str = build_regex_from_object(schema, whitespace_pattern)
48+
regex_str = build_regex_from_schema(schema, whitespace_pattern)
4949
generator = regex(model, regex_str, sampler)
5050
generator.format_sequence = lambda x: schema_object.parse_raw(x)
5151
elif callable(schema_object):
5252
schema = pyjson.dumps(get_schema_from_signature(schema_object))
53-
regex_str = build_regex_from_object(schema, whitespace_pattern)
53+
regex_str = build_regex_from_schema(schema, whitespace_pattern)
5454
generator = regex(model, regex_str, sampler)
5555
generator.format_sequence = lambda x: pyjson.loads(x)
5656
elif isinstance(schema_object, str):
5757
schema = schema_object
58-
regex_str = build_regex_from_object(schema, whitespace_pattern)
58+
regex_str = build_regex_from_schema(schema, whitespace_pattern)
5959
generator = regex(model, regex_str, sampler)
6060
generator.format_sequence = lambda x: pyjson.loads(x)
6161
else:

outlines/generate/regex.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from outlines.fsm.fsm import RegexFSM
44
from outlines.generate.api import SequenceGenerator
55
from outlines.models import OpenAI
6+
from outlines.models.llamacpp import LlamaCpp, RegexLogitsProcessor
67
from outlines.samplers import Sampler, multinomial
78

89

@@ -35,8 +36,30 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
3536
return generator
3637

3738

39+
@regex.register(LlamaCpp)
40+
def regex_llamacpp(
41+
model: LlamaCpp,
42+
regex_str: str,
43+
sampler: Sampler = multinomial(),
44+
):
45+
if not isinstance(sampler, multinomial):
46+
raise NotImplementedError(
47+
r"The llama.cpp integration does not currently support any other sampling algorithm "
48+
+ "than the multinomial sampler."
49+
)
50+
51+
logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
52+
model.logits_processor = logits_processor
53+
54+
return model
55+
56+
3857
@regex.register(OpenAI)
39-
def regex_openai(model, regex_str: str, sampler: Sampler = multinomial()):
58+
def regex_openai(
59+
model: OpenAI,
60+
regex_str: str,
61+
sampler: Sampler = multinomial(),
62+
):
4063
raise NotImplementedError(
4164
"Cannot use regex-structured generation with an OpenAI model"
4265
+ "due to the limitations of the OpenAI API."

0 commit comments

Comments
 (0)