
Commit 5f6166a

tscholak authored and rlouf committed
Update the documentation
1 parent 0032c65 commit 5f6166a

File tree

2 files changed: +32 -212 lines changed


docs/reference/vllm.md

Lines changed: 28 additions & 7 deletions
````diff
@@ -6,29 +6,50 @@ Outlines can be deployed as an LLM service using the vLLM inference engine and a
 pip install outlines[serve]
 ```
 
+Note: only vLLM v0.2.6 with ray 2.9.0 is supported at the moment.
+
 You can then start the server with:
 
-```python
+```bash
 python -m outlines.serve.serve
 ```
 
-This will by default start a server at `http://127.0.0.1:8000` (check what the console says, though) with the OPT-125M model. If you want to specify another model:
+This will by default start a server at `http://127.0.0.1:8000` (check what the console says, though) with the OPT-125M model. If you want to specify another model (e.g. Mistral-7B-Instruct-v0.2), you can do so with the `--model` parameter:
 
-```python
-python -m outlines.serve.serve --model="mistralai/Mistral-7B-v0.1"
+```bash
+python -m outlines.serve.serve --model="mistralai/Mistral-7B-Instruct-v0.2"
 ```
 
-You can then query the model in shell by passing a prompt and a [JSON Schema][jsonschema]{:target="_blank"} specification for the structure of the output:
+You can then query the model in shell by passing a prompt and either
+
+1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
+2. a [Regex][regex]{:target="_blank"} pattern
+
+with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.
+
+For example, to generate a string that matches the schema `{"type": "string"}` (any string):
 
 ```bash
-curl http://0.0.0.1:8000 \
+curl http://127.0.0.1:8000/generate \
     -d '{
         "prompt": "What is the capital of France?",
         "schema": {"type": "string"}
        }'
 ```
 
-Or use the [requests][requests]{:target="_blank"} library from another python program. You can read the [vLLM documentation][vllm]{:target="_blank"} for more details.
+To generate a string that matches the regex `(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?` (a number):
+
+```bash
+curl http://127.0.0.1:8000/generate \
+    -d '{
+        "prompt": "What is Pi? Give me the first 15 digits: ",
+        "regex": "(-)?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-][0-9]+)?"
+       }'
+```
+
+Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program.
+
+Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters.
 
 You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.
 
````

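For reference, here is a minimal sketch of the `requests` alternative mentioned in the updated docs above. It assumes the server started with `python -m outlines.serve.serve` is reachable at `http://127.0.0.1:8000` and that the `/generate` endpoint accepts the same JSON body as the curl examples; the prompt and regex below are illustrative.

```python
# Minimal sketch: query the outlines serve endpoint from Python instead of curl.
# Assumes the server from `python -m outlines.serve.serve` listens on 127.0.0.1:8000
# and accepts the same JSON body as the curl examples in the docs above.
import requests

response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "What is Pi? Give me the first 15 digits: ",
        # Constrain the output to a number, as in the regex example above.
        "regex": r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?",
    },
)
response.raise_for_status()
print(response.json())
```
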
examples/vllm_integration.py

Lines changed: 4 additions & 205 deletions
```diff
@@ -1,212 +1,11 @@
-import math
-from collections import defaultdict
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
 import vllm
+import vllm.model_executor.layers.sampler as sampler
 from pydantic import BaseModel
-from vllm.model_executor.layers.sampler import (
-    _SAMPLING_EPS,
-    _apply_min_p,
-    _apply_penalties,
-    _apply_top_p_top_k,
-    _build_sampler_output,
-    _get_logits,
-    _get_logprobs,
-    _get_penalties,
-    _get_temperatures,
-    _get_top_p_top_k_min_p,
-    _prune_hidden_states,
-    _sample,
-)
-
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
-
-
-def _patched_apply_logits_processors(
-    logits,
-    sampling_metadata,
-):
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(seq_id, token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
-class JSONLogitsProcessor:
-    def __init__(self, pydantic_model, llm):
-        """Compile the FSM that drives the JSON-guided generation.
-
-        Parameters
-        ----------
-        pydantic_model
-            A Pydantic `BaseModel` that encodes the structure we want
-            the model to generate.
-        llm
-            An instance of `vllm.LLM`
-
-        """
-        schema = pydantic_model.schema_json()
-        regex_str = build_regex_from_object(schema)
-        tokenizer = self.adapt_tokenizer(llm.get_tokenizer())
-
-        fsm = RegexFSM(regex_str, tokenizer)
-        self.fsm = fsm
-
-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
-
-        if len(input_ids) == 0:  # Initialize the fsm states
-            self.fsm_state = defaultdict(int)
-        else:
-            last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
-            )
-
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-
-        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
-        biased_scores = scores + mask
-
-        return biased_scores
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-
-class Sampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence, frequency and repetition penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self, vocab_size: int) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ):
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)
-
-        # Apply logits processors (if any).
-        logits = _patched_apply_logits_processors(logits, sampling_metadata)
-        # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
-            sampling_metadata
-        )
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits,
-            sampling_metadata,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-        )
-
-        # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
-
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size
-        )
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
-        if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results
-        )
-        return _build_sampler_output(
-            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
-        )
 
+from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors
 
-vllm.model_executor.layers.sampler.Sampler = Sampler
+# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
+sampler._apply_logits_processors = _patched_apply_logits_processors
 
 
 class User(BaseModel):
```

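The example now delegates to `outlines.serve.vllm` and only patches vLLM's sampler. Below is a rough usage sketch of that setup; the `User` fields, the model name, and passing the processor through `SamplingParams(logits_processors=[...])` (the attribute the patched `_apply_logits_processors` reads) are illustrative assumptions, not part of the commit.

```python
# Rough usage sketch built on the patched setup above. Assumptions (not in the
# commit): the User fields, the model name, and wiring the processor through
# SamplingParams(logits_processors=[...]), which is the attribute the patched
# _apply_logits_processors reads.
import vllm
import vllm.model_executor.layers.sampler as sampler
from pydantic import BaseModel

from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

# Patch vLLM so logits processors receive (seq_id, token_ids, logits).
sampler._apply_logits_processors = _patched_apply_logits_processors


class User(BaseModel):
    name: str
    age: int


llm = vllm.LLM(model="mistralai/Mistral-7B-Instruct-v0.2")
# (pydantic_model, llm) signature, as documented in the removed JSONLogitsProcessor class.
logits_processor = JSONLogitsProcessor(User, llm)
params = vllm.SamplingParams(max_tokens=100, logits_processors=[logits_processor])
outputs = llm.generate("Give me a JSON user: ", sampling_params=params)
print(outputs[0].outputs[0].text)
```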