Commit d938678

Fix vLLM integration (#711)
When integrating Outlines with vLLM I faced the following issues, which are fixed in this PR:

1. When calling `vllm.LLM.generate`, vLLM internally makes a `copy.deepcopy` of the `SamplingParams`, which includes the logits processor from Outlines (`RegexLogitsProcessor`, say). This requires everything to be pickleable, and `RegexLogitsProcessor.fsm.vocabulary` is a `dict_values` object, which isn't. The fix is easy: just convert it to a list. This doesn't affect how the `vocabulary` variable is used in the code.
2. The `RegexLogitsProcessor` takes an `llm` argument, which the docstring states should be a `vllm.LLM` object, but it then attempts to extract the underlying tokenizer via `llm.tokenizer.tokenizer`. The tokenizer of `vllm.LLM` currently lives in the `vllm.LLM.llm_engine.tokenizer.tokenizer` attribute, but that path is a mess and isn't backwards compatible with previous vLLM versions. Instead, vLLM offers a convenience method, `vllm.LLM.get_tokenizer`, which fetches the tokenizer. To stay backwards compatible, in case people have supplied `vllm.LLM.llm_engine` directly to `RegexLogitsProcessor`, it falls back to a `tokenizer` or `tokenizer.tokenizer` attribute.

I also updated the vLLM example script, which was outdated as well (it used the previous `_patched_apply_logits_processors`).

Closes #704
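For reference, issue 1 can be reproduced in a few lines, independently of vLLM (a minimal sketch; the `vocabulary` dict is a stand-in for the real tokenizer vocabulary):

```python
import copy

# Stand-in for `RegexLogitsProcessor.fsm.vocabulary`.
vocabulary = {"hello": 0, "world": 1}

# `dict_values` views don't support pickling, so vLLM's internal
# `copy.deepcopy` of `SamplingParams` blows up on them.
try:
    copy.deepcopy(vocabulary.values())
except TypeError as exc:
    print(exc)  # cannot pickle 'dict_values' object

# Materializing the view as a list makes it deep-copyable (and pickleable).
copy.deepcopy(list(vocabulary.values()))
```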
1 parent d85e67f commit d938678

File tree

5 files changed: +19 −10 lines


.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -5,3 +5,4 @@ docs/build
 .coverage
 .idea/
 *.gguf
+.venv
```

examples/vllm_integration.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -1,20 +1,16 @@
 import vllm
-import vllm.model_executor.layers.sampler as sampler
 from pydantic import BaseModel
 
-from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors
-
-
-# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
-sampler._apply_logits_processors = _patched_apply_logits_processors
+from outlines.serve.vllm import JSONLogitsProcessor
 
 
 class User(BaseModel):
     id: int
     name: str
 
 
-llm = vllm.LLM(model="gpt2")
-logits_processor = JSONLogitsProcessor(User, llm)
+llm = vllm.LLM(model="openai-community/gpt2")
+logits_processor = JSONLogitsProcessor(schema=User, llm=llm)
 result = llm.generate(
     ["A prompt", "Another prompt"],
     sampling_params=vllm.SamplingParams(
```
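The hunk ends mid-call; for completeness, a plausible rest of the `SamplingParams` call, just to show how the processor gets attached (the `max_tokens` value is illustrative, not taken from the diff):

```python
# Hypothetical completion of the truncated call above; the key detail is
# passing the Outlines processor via `logits_processors`.
sampling_params = vllm.SamplingParams(
    max_tokens=100,  # illustrative value, not part of the diff
    logits_processors=[logits_processor],
)
result = llm.generate(["A prompt", "Another prompt"], sampling_params=sampling_params)
```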

outlines/fsm/fsm.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -121,7 +121,7 @@ def create_states_mapping(
         self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = list(tokenizer.vocabulary.values())
         self.eos_token_id = tokenizer.eos_token_id
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
@@ -218,7 +218,7 @@ def create_states_mapping_from_interegular_fsm(
         ) = create_states_mapping_from_interegular_fsm(
             interegular_fsm, tuple(sorted(tokenizer.vocabulary.items()))
         )
-        from_interegular_instance.vocabulary = tokenizer.vocabulary.values()
+        from_interegular_instance.vocabulary = list(tokenizer.vocabulary.values())
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
 
```
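Both hunks are the same one-line fix: materialize the `dict_values` view into a list when the FSM is built. As the commit message notes, this is behavior-preserving for existing uses of `vocabulary`; a quick sanity check (a sketch, not code from the repo):

```python
import pickle

vocabulary = {"a": 0, "b": 1, "c": 2}

view = vocabulary.values()
materialized = list(vocabulary.values())

# Iteration order and contents are unchanged by the conversion...
assert list(view) == materialized
# ...but only the list survives pickling (and hence `copy.deepcopy`).
pickle.dumps(materialized)
```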

outlines/serve/vllm.py

Lines changed: 13 additions & 1 deletion

```diff
@@ -47,7 +47,19 @@ def __init__(self, regex_string, llm):
             An instance of `vllm.LLM`
 
         """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
+        if hasattr(llm, "get_tokenizer"):
+            tokenizer = llm.get_tokenizer()
+        elif hasattr(llm, "tokenizer"):
+            if hasattr(llm.tokenizer, "tokenizer"):
+                tokenizer = llm.tokenizer.tokenizer
+            else:
+                tokenizer = llm.tokenizer
+        else:
+            raise ValueError(
+                "The provided LLM instance in `RegexLogitsProcessor` neither has a "
+                "`tokenizer` attribute or a `get_tokenizer` method."
+            )
+        tokenizer = self.adapt_tokenizer(tokenizer=tokenizer)
 
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
```
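With the new lookup, both of the following now resolve a tokenizer (a sketch, assuming the `openai-community/gpt2` model from the updated example script; the regex pattern is just an example):

```python
import vllm

from outlines.serve.vllm import RegexLogitsProcessor

llm = vllm.LLM(model="openai-community/gpt2")

# Preferred path: `vllm.LLM` exposes `get_tokenizer()`.
processor = RegexLogitsProcessor(r"[0-9]{4}", llm)

# Backwards-compatible path: an engine passed directly is handled by the
# `tokenizer` / `tokenizer.tokenizer` fallback.
processor = RegexLogitsProcessor(r"[0-9]{4}", llm.llm_engine)
```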
