Skip to content

Commit 0032c65

Browse files
tscholak and rlouf
authored and committed
Add regex support to vLLM endpoint
1 parent 3e75078 commit 0032c65

File tree

3 files changed

+50
-21
lines changed

3 files changed

+50
-21
lines changed

outlines/serve/serve.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,23 @@
1717

1818
import uvicorn
1919
import vllm.model_executor.layers.sampler as sampler
20-
21-
from .vllm import JSONLogitsProcessor, _patched_apply_logits_processors
22-
23-
# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
24-
sampler._apply_logits_processors = _patched_apply_logits_processors
25-
2620
from fastapi import FastAPI, Request
2721
from fastapi.responses import JSONResponse, Response, StreamingResponse
2822
from vllm.engine.arg_utils import AsyncEngineArgs
2923
from vllm.engine.async_llm_engine import AsyncLLMEngine
3024
from vllm.sampling_params import SamplingParams
3125
from vllm.utils import random_uuid
3226

27+
from .vllm import (
28+
JSONLogitsProcessor,
29+
RegexLogitsProcessor,
30+
_patched_apply_logits_processors,
31+
)
32+
33+
# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
34+
sampler._apply_logits_processors = _patched_apply_logits_processors
35+
36+
3337
TIMEOUT_KEEP_ALIVE = 5 # seconds.
3438
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
3539
app = FastAPI()
@@ -48,22 +52,28 @@ async def generate(request: Request) -> Response:
4852
4953
The request should be a JSON object with the following fields:
5054
- prompt: the prompt to use for the generation.
51-
- schema: the JSON schema to use for the generation
55+
- schema: the JSON schema to use for the generation (if regex is not provided).
56+
- regex: the regex to use for the generation (if schema is not provided).
5257
- stream: whether to stream the results or not.
5358
- other fields: the sampling parameters (See `SamplingParams` for details).
5459
"""
60+
assert engine is not None
61+
5562
request_dict = await request.json()
5663
prompt = request_dict.pop("prompt")
5764
stream = request_dict.pop("stream", False)
5865

5966
json_schema = request_dict.pop("schema", None)
67+
regex_string = request_dict.pop("regex", None)
6068
if json_schema is not None:
61-
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] # type: ignore
69+
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
70+
elif regex_string is not None:
71+
logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
6272
else:
6373
logits_processors = []
6474

6575
sampling_params = SamplingParams(
66-
**request_dict, logits_processors=logits_processors
76+
**request_dict, logits_processors=logits_processors # type: ignore
6777
)
6878
request_id = random_uuid()
6979

@@ -107,7 +117,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
107117

108118
# Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len`
109119
# arguments
110-
engine_args = AsyncEngineArgs.from_cli_args(args)
120+
engine_args: AsyncEngineArgs = AsyncEngineArgs.from_cli_args(args) # type: ignore
111121

112122
# Sets default for the model (`facebook/opt-125m`)
113123
engine = AsyncLLMEngine.from_engine_args(engine_args)

outlines/serve/vllm.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,21 @@ def _patched_apply_logits_processors(
3939
return logits
4040

4141

42-
class JSONLogitsProcessor:
43-
def __init__(self, schema, llm):
44-
"""Compile the FSM that drives the JSON-guided generation.
42+
class RegexLogitsProcessor:
43+
def __init__(self, regex_string, llm):
44+
"""Compile the FSM that drives the regex-guided generation.
4545
4646
Parameters
4747
----------
48-
pydantic_model
49-
A Pydantic `BaseModel` that encodes the structure we want
50-
the model to generate.
48+
regex_string
49+
A string that represents a regular expression
5150
llm
5251
An instance of `vllm.LLM`
5352
5453
"""
55-
if isinstance(schema, dict):
56-
schema = json.dumps(schema)
57-
regex_str = build_regex_from_object(schema)
5854
tokenizer = self.adapt_tokenizer(llm.tokenizer)
5955

60-
fsm = RegexFSM(regex_str, tokenizer)
56+
fsm = RegexFSM(regex_string, tokenizer)
6157
self.fsm = fsm
6258

6359
def __call__(
@@ -106,3 +102,21 @@ def convert_token_to_string(token: str) -> str:
106102
tokenizer.convert_token_to_string = convert_token_to_string
107103

108104
return tokenizer
105+
106+
107+
class JSONLogitsProcessor(RegexLogitsProcessor):
108+
def __init__(self, schema, llm):
109+
"""Compile the FSM that drives the JSON-guided generation.
110+
111+
Parameters
112+
----------
113+
schema
114+
A JSON schema that encodes the structure we want the model to generate
115+
llm
116+
An instance of `vllm.LLM`
117+
118+
"""
119+
if isinstance(schema, dict):
120+
schema = json.dumps(schema)
121+
regex_string = build_regex_from_object(schema)
122+
super().__init__(regex_string, llm)

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ test = [
5353
"datasets",
5454
"responses",
5555
]
56-
serve = ["vllm==0.2.6"]
56+
serve = [
57+
"vllm==0.2.6",
58+
"ray==2.9.0",
59+
"uvicorn",
60+
"fastapi"
61+
]
5762

5863
[project.urls]
5964
homepage = "https://github.com/outlines-dev/outlines"

0 commit comments

Comments
 (0)