
Commit 8eb7ac0

Serve JSON with vLLM using FastAPI and uvicorn
1 parent 67e524d


4 files changed: +342 -1 lines changed


outlines/serve/__init__.py

Whitespace-only changes.

outlines/serve/serve.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# Copyright 2023 the vLLM developers

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from typing import AsyncGenerator

import uvicorn
import vllm
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from .vllm import JSONLogitsProcessor, PatchedSampler

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
app = FastAPI()
engine = None

# Patch the sampler so it is compatible with `JSONLogitsProcessor`
vllm.model_executor.layers.sampler.Sampler = PatchedSampler


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - schema: the JSON schema to use for the generation
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)

    json_schema = request_dict.pop("schema", None)
    if json_schema is not None:
        logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]  # type: ignore
    else:
        logits_processors = []

    sampling_params = SamplingParams(
        **request_dict, logits_processors=logits_processors
    )
    request_id = random_uuid()

    results_generator = engine.generate(prompt, sampling_params, request_id)  # type: ignore

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [prompt + output.text for output in request_output.outputs]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)  # type: ignore
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len`
    # arguments
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Sets default for the model (`facebook/opt-125m`)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )
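A quick usage sketch, not part of this commit: with the new optional dependency group installed (for example `pip install outlines[serve]`), the server above can be started with `python -m outlines.serve.serve` plus the usual vLLM engine arguments, and queried over HTTP. The client below is illustrative only; the model name, host, port and the `requests` package are assumptions, and any extra fields in the payload are forwarded to vLLM's `SamplingParams`.

# Illustrative client for the /generate endpoint defined in serve.py above.
# Assumes the server was started with something like:
#     python -m outlines.serve.serve --model mistralai/Mistral-7B-v0.1
# and listens on the default localhost:8000.
import requests  # assumed third-party dependency, not added by this commit

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Describe a character in JSON: ",
        "schema": schema,      # compiled into a JSONLogitsProcessor server-side
        "max_tokens": 100,     # remaining fields go to vLLM's SamplingParams
        "temperature": 0.7,
    },
)
print(response.json()["text"])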

outlines/serve/vllm.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
"""Make vLLM compatible with Outlines' guided generation."""
import json
import math
from collections import defaultdict
from typing import DefaultDict, List, Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import (
    _SAMPLING_EPS,
    _apply_min_p,
    _apply_penalties,
    _apply_top_p_top_k,
    _build_sampler_output,
    _get_logits,
    _get_logprobs,
    _get_penalties,
    _get_temperatures,
    _get_top_p_top_k_min_p,
    _prune_hidden_states,
    _sample,
)

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def _patched_apply_logits_processors(
    logits,
    sampling_metadata,
):
    """Patch vLLM's logits processor.

    We need to patch the logits processor to pass the `seq_id` so we can
    handle several sequences in `JSONLogitsProcessor`.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(seq_id, token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


class PatchedSampler(nn.Module):
    """This code is copied from vLLM and uses the patched logits processor.

    Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ):
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)

        # Apply logits processors (if any).
        logits = _patched_apply_logits_processors(logits, sampling_metadata)
        # Apply presence and frequency penalties.
        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
            sampling_metadata
        )
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        assert len(repetition_penalties) == logits.shape[0]
        logits = _apply_penalties(
            logits,
            sampling_metadata,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
        )

        # Apply temperature scaling.
        temperatures = _get_temperatures(sampling_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
            sampling_metadata, self.vocab_size
        )
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
        if do_min_p:
            logits = _apply_min_p(logits, min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results
        )
        return _build_sampler_output(
            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
        )

class JSONLogitsProcessor:
    def __init__(self, schema, llm):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema, passed as a dict or a string, that encodes the
            structure we want the model to generate.
        llm
            An instance of `vllm.LLM`.

        """
        if isinstance(schema, dict):
            schema = json.dumps(schema)
        regex_str = build_regex_from_object(schema)
        tokenizer = self.adapt_tokenizer(llm.tokenizer)

        fsm = RegexFSM(regex_str, tokenizer)
        self.fsm = fsm

    def __call__(
        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
    ) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""

        if len(input_ids) == 0:  # Initialize the FSM states
            self.fsm_state: DefaultDict[int, int] = defaultdict(int)
        else:
            last_token = input_ids[-1]
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], last_token
            )

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        biased_scores = scores + mask

        return biased_scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the missing spaces in
        Llama's tokenizer to be able to compile FSMs for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer
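To make the biasing step in `JSONLogitsProcessor.__call__` concrete, here is a small, self-contained sketch that is not part of the commit: the vocabulary size and the set of allowed token ids are made up, but the mask construction mirrors the code above (build a `-inf` mask, zero out the allowed positions, add it to the scores). The schema-to-mask pipeline itself is handled by `build_regex_from_object` and `RegexFSM` in the file above.

# Standalone sketch of the logits-masking step used by JSONLogitsProcessor.
# The vocabulary size and allowed token ids below are illustrative.
import math

import torch

vocab_size = 10
scores = torch.randn(vocab_size)   # raw logits for the next token
allowed_tokens = [2, 5, 7]         # ids the FSM allows in the current state

mask = torch.full((vocab_size,), -math.inf)
mask[allowed_tokens] = 0           # allowed ids keep their original score
biased_scores = scores + mask      # every other id is pushed to -inf

probs = torch.softmax(biased_scores, dim=-1)
# Only the allowed ids can be sampled; the rest have probability exactly 0.
assert all(probs[i] == 0 for i in range(vocab_size) if i not in allowed_tokens)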

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -53,6 +53,7 @@ test = [
     "datasets",
     "responses",
 ]
+serve = ["vllm"]

 [project.urls]
 homepage = "https://github.com/outlines-dev/outlines"
@@ -102,14 +103,17 @@ module = [
     "referencing.*",
     "scipy.*",
     "tiktoken.*",
-    "torch",
+    "torch.*",
     "transformers.*",
     "lark.*",
     "interegular.*",
     "datasets.*",
     "numba.*",
     "requests.*",
     "responses.*",
+    "vllm.*",
+    "uvicorn.*",
+    "fastapi.*",
 ]
 ignore_missing_imports = true
