Commit 3e75078

tscholak authored and rlouf committed

Update vllm patch to v0.2.6

1 parent 298a080 · commit 3e75078

4 files changed: +13 −116 lines


outlines/base.py

Lines changed: 4 additions & 1 deletion
@@ -13,7 +13,10 @@
 )

 # Allow nested loops, useful to run in notebooks
-nest_asyncio.apply()
+try:
+    nest_asyncio.apply()
+except ValueError as e:
+    print("Could not apply nest_asyncio:", e)


 class vectorize:
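
For context on the guard: nest_asyncio.apply() only supports the stock asyncio event loop and raises ValueError when a different loop implementation (such as uvloop) is installed, so the patch now degrades gracefully instead of failing at import time. A minimal standalone sketch of what nest_asyncio enables in a notebook (illustrative, not part of the commit):

    # Illustration of nest_asyncio's purpose (assumes a Jupyter-style
    # environment where an event loop is already running).
    import asyncio

    import nest_asyncio

    async def answer() -> int:
        return 42

    # Without nest_asyncio, calling asyncio.run() from inside a running
    # loop raises RuntimeError; apply() makes the loop re-entrant.
    nest_asyncio.apply()
    print(asyncio.run(answer()))  # -> 42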

outlines/serve/serve.py

Lines changed: 7 additions & 6 deletions
@@ -16,24 +16,25 @@
 from typing import AsyncGenerator

 import uvicorn
-import vllm
+import vllm.model_executor.layers.sampler as sampler
+
+from .vllm import JSONLogitsProcessor, _patched_apply_logits_processors
+
+# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
+sampler._apply_logits_processors = _patched_apply_logits_processors
+
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-from .vllm import JSONLogitsProcessor, PatchedSampler
-
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None

-# Patch the sampler so it is compatible with `JSONLogitsProcessor`
-vllm.model_executor.layers.sampler.Sampler = PatchedSampler
-

 @app.get("/health")
 async def health() -> Response:
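
The key change here is patching the single module-level function _apply_logits_processors rather than replacing the whole Sampler class: vLLM's sampler resolves that name through its own module globals at call time, so rebinding the attribute on the imported module is enough, and the copied sampler code removed in outlines/serve/vllm.py below becomes unnecessary. For orientation, a logits processor at this layer receives the token ids generated so far and the logits row for the next token, and returns adjusted logits; a simplified sketch of that calling convention (names and shapes are illustrative, not vLLM's exact internals):

    from typing import Callable, List

    import torch

    # One processor call per sequence: (generated_token_ids, logits_row) -> logits_row.
    LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]

    def apply_logits_processors(
        logits: torch.Tensor,           # shape (num_seqs, vocab_size)
        token_ids: List[List[int]],     # tokens generated so far, per sequence
        processors: List[LogitsProcessor],
    ) -> torch.Tensor:
        # Let each processor mask or bias the next-token distribution.
        for row, ids in enumerate(token_ids):
            for processor in processors:
                logits[row] = processor(ids, logits[row])
        return logits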

outlines/serve/vllm.py

Lines changed: 1 addition & 108 deletions
@@ -2,24 +2,9 @@
 import json
 import math
 from collections import defaultdict
-from typing import DefaultDict, List, Optional
+from typing import DefaultDict, List

 import torch
-import torch.nn as nn
-from vllm.model_executor.layers.sampler import (
-    _SAMPLING_EPS,
-    _apply_min_p,
-    _apply_penalties,
-    _apply_top_p_top_k,
-    _build_sampler_output,
-    _get_logits,
-    _get_logprobs,
-    _get_penalties,
-    _get_temperatures,
-    _get_top_p_top_k_min_p,
-    _prune_hidden_states,
-    _sample,
-)

 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -54,98 +39,6 @@ def _patched_apply_logits_processors(
     return logits


-class PatchedSampler(nn.Module):
-    """This code is copied from vLLM and uses the patched logits processor.
-
-    Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-       tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence, frequency and repetition penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self, vocab_size: int) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ):
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)
-
-        # Apply logits processors (if any).
-        logits = _patched_apply_logits_processors(logits, sampling_metadata)
-        # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
-            sampling_metadata
-        )
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits,
-            sampling_metadata,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-        )
-
-        # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
-
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size
-        )
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
-        if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results
-        )
-        return _build_sampler_output(
-            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
-        )
-
-
 class JSONLogitsProcessor:
     def __init__(self, schema, llm):
         """Compile the FSM that drives the JSON-guided generation.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ test = [
     "datasets",
     "responses",
 ]
-serve = ["vllm==0.2.5"]
+serve = ["vllm==0.2.6"]

 [project.urls]
 homepage = "https://github.com/outlines-dev/outlines"
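
With the pin updated, installing the serve extra (pip install outlines[serve]) now pulls in vLLM 0.2.6, the release whose _apply_logits_processors the patch in outlines/serve/serve.py targets.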
