17
17
18
18
import uvicorn
19
19
import vllm .model_executor .layers .sampler as sampler
20
-
21
- from .vllm import JSONLogitsProcessor , _patched_apply_logits_processors
22
-
23
- # Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
24
- sampler ._apply_logits_processors = _patched_apply_logits_processors
25
-
26
20
from fastapi import FastAPI , Request
27
21
from fastapi .responses import JSONResponse , Response , StreamingResponse
28
22
from vllm .engine .arg_utils import AsyncEngineArgs
29
23
from vllm .engine .async_llm_engine import AsyncLLMEngine
30
24
from vllm .sampling_params import SamplingParams
31
25
from vllm .utils import random_uuid
32
26
27
+ from .vllm import (
28
+ JSONLogitsProcessor ,
29
+ RegexLogitsProcessor ,
30
+ _patched_apply_logits_processors ,
31
+ )
32
+
33
+ # Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
34
+ sampler ._apply_logits_processors = _patched_apply_logits_processors
35
+
36
+
33
37
TIMEOUT_KEEP_ALIVE = 5 # seconds.
34
38
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
35
39
app = FastAPI ()
@@ -48,22 +52,28 @@ async def generate(request: Request) -> Response:
48
52
49
53
The request should be a JSON object with the following fields:
50
54
- prompt: the prompt to use for the generation.
51
- - schema: the JSON schema to use for the generation
55
+ - schema: the JSON schema to use for the generation (if regex is not provided).
56
+ - regex: the regex to use for the generation (if schema is not provided).
52
57
- stream: whether to stream the results or not.
53
58
- other fields: the sampling parameters (See `SamplingParams` for details).
54
59
"""
60
+ assert engine is not None
61
+
55
62
request_dict = await request .json ()
56
63
prompt = request_dict .pop ("prompt" )
57
64
stream = request_dict .pop ("stream" , False )
58
65
59
66
json_schema = request_dict .pop ("schema" , None )
67
+ regex_string = request_dict .pop ("regex" , None )
60
68
if json_schema is not None :
61
- logits_processors = [JSONLogitsProcessor (json_schema , engine .engine )] # type: ignore
69
+ logits_processors = [JSONLogitsProcessor (json_schema , engine .engine )]
70
+ elif regex_string is not None :
71
+ logits_processors = [RegexLogitsProcessor (regex_string , engine .engine )]
62
72
else :
63
73
logits_processors = []
64
74
65
75
sampling_params = SamplingParams (
66
- ** request_dict , logits_processors = logits_processors
76
+ ** request_dict , logits_processors = logits_processors # type: ignore
67
77
)
68
78
request_id = random_uuid ()
69
79
@@ -107,7 +117,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
107
117
108
118
# Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len`
109
119
# arguments
110
- engine_args = AsyncEngineArgs .from_cli_args (args )
120
+ engine_args : AsyncEngineArgs = AsyncEngineArgs .from_cli_args (args ) # type: ignore
111
121
112
122
# Sets default for the model (`facebook/opt-125m`)
113
123
engine = AsyncLLMEngine .from_engine_args (engine_args )
0 commit comments