Skip to content

Commit d23807b

Browse files
committed
Add LlamaSequenceGenerator
We currently store the logits processor in the `LlamaCpp` instance. This causes issues when running successive generations with different generators, because they all share the same model instance. In this PR we create a new `LlamaSequenceGenerator` instance every time we create a new generator, and store the logits processor in that generator instance instead, which solves the issue. Fixes #700.
1 parent 10871cf commit d23807b

File tree

5 files changed

+31
-48
lines changed

5 files changed

+31
-48
lines changed

outlines/generate/cfg.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from outlines.fsm.fsm import CFGFSM
44
from outlines.generate.api import SequenceGenerator
55
from outlines.models import OpenAI
6-
from outlines.models.llamacpp import CFGLogitsProcessor, LlamaCpp
6+
from outlines.models.llamacpp import (
7+
CFGLogitsProcessor,
8+
LlamaCpp,
9+
LlamaSequenceGenerator,
10+
)
711
from outlines.samplers import Sampler, multinomial
812

913

@@ -45,9 +49,9 @@ def cfg_llamacpp(
4549
)
4650

4751
logits_processor = CFGLogitsProcessor(cfg_str, model.tokenizer)
48-
model.logits_processor = logits_processor
52+
generator = LlamaSequenceGenerator(logits_processor, model)
4953

50-
return model
54+
return generator
5155

5256

5357
@cfg.register(OpenAI)

outlines/generate/regex.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from outlines.fsm.fsm import RegexFSM
44
from outlines.generate.api import SequenceGenerator
55
from outlines.models import OpenAI
6-
from outlines.models.llamacpp import LlamaCpp, RegexLogitsProcessor
6+
from outlines.models.llamacpp import (
7+
LlamaCpp,
8+
LlamaSequenceGenerator,
9+
RegexLogitsProcessor,
10+
)
711
from outlines.samplers import Sampler, multinomial
812

913

@@ -49,9 +53,9 @@ def regex_llamacpp(
4953
)
5054

5155
logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
52-
model.logits_processor = logits_processor
56+
generator = LlamaSequenceGenerator(logits_processor, model)
5357

54-
return model
58+
return generator
5559

5660

5761
@regex.register(OpenAI)

outlines/generate/text.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from outlines.fsm.fsm import StopAtEosFSM
44
from outlines.generate import SequenceGenerator
55
from outlines.models import LlamaCpp, OpenAI
6+
from outlines.models.llamacpp import LlamaSequenceGenerator
67
from outlines.samplers import Sampler, multinomial
78

89

@@ -44,7 +45,9 @@ def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
4445
+ "than the multinomial sampler."
4546
)
4647

47-
return model
48+
generator = LlamaSequenceGenerator(None, model)
49+
50+
return generator
4851

4952

5053
@text.register(OpenAI)

outlines/models/llamacpp.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,12 @@
88
from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
99

1010

11-
class LlamaCpp:
12-
"""Represents a `llama_cpp` model."""
13-
11+
class LlamaSequenceGenerator:
1412
def __init__(
15-
self, model_path, logits_processor: Optional["LogitsProcessor"] = None, **kwargs
13+
self, logits_processor: Optional["LogitsProcessor"], model: "LlamaCpp"
1614
):
17-
from llama_cpp import Llama
18-
15+
self.model = model.model
1916
self.logits_processor = logits_processor
20-
self.model = Llama(model_path, **kwargs)
21-
self.tokenizer = LlamaCppTokenizer(self)
2217

2318
def __call__(
2419
self,
@@ -89,6 +84,16 @@ def stream(
8984
)
9085

9186

87+
class LlamaCpp:
88+
"""Represents a `llama_cpp` model."""
89+
90+
def __init__(self, model_path, **kwargs):
91+
from llama_cpp import Llama
92+
93+
self.model = Llama(model_path, **kwargs)
94+
self.tokenizer = LlamaCppTokenizer(self)
95+
96+
9297
class LlamaCppTokenizer:
9398
def __init__(self, model, **kwargs):
9499
self.eos_token_id = model.model.token_eos()

tests/models/test_llama_cpp.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

0 commit comments

Comments
 (0)