-import math
-from collections import defaultdict
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
 import vllm
+import vllm.model_executor.layers.sampler as sampler
 from pydantic import BaseModel
-from vllm.model_executor.layers.sampler import (
-    _SAMPLING_EPS,
-    _apply_min_p,
-    _apply_penalties,
-    _apply_top_p_top_k,
-    _build_sampler_output,
-    _get_logits,
-    _get_logprobs,
-    _get_penalties,
-    _get_temperatures,
-    _get_top_p_top_k_min_p,
-    _prune_hidden_states,
-    _sample,
-)
-
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
-
-
-def _patched_apply_logits_processors(
-    logits,
-    sampling_metadata,
-):
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(seq_id, token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
-class JSONLogitsProcessor:
-    def __init__(self, pydantic_model, llm):
-        """Compile the FSM that drives the JSON-guided generation.
-
-        Parameters
-        ----------
-        pydantic_model
-            A Pydantic `BaseModel` that encodes the structure we want
-            the model to generate.
-        llm
-            An instance of `vllm.LLM`
-
-        """
-        schema = pydantic_model.schema_json()
-        regex_str = build_regex_from_object(schema)
-        tokenizer = self.adapt_tokenizer(llm.get_tokenizer())
-
-        fsm = RegexFSM(regex_str, tokenizer)
-        self.fsm = fsm
-
-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
-
-        if len(input_ids) == 0:  # Initialize the fsm states
-            self.fsm_state = defaultdict(int)
-        else:
-            last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
-            )
-
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-
-        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
-        biased_scores = scores + mask
-
-        return biased_scores
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-
-class Sampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence, frequency and repetition penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self, vocab_size: int) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ):
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)
-
-        # Apply logits processors (if any).
-        logits = _patched_apply_logits_processors(logits, sampling_metadata)
-        # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
-            sampling_metadata
-        )
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits,
-            sampling_metadata,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-        )
-
-        # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
-
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size
-        )
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
-        if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results
-        )
-        return _build_sampler_output(
-            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
-        )
 
+from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors
 
-vllm.model_executor.layers.sampler.Sampler = Sampler
+# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
+sampler._apply_logits_processors = _patched_apply_logits_processors
 
 
 class User(BaseModel):
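
For context, a minimal sketch (not part of this commit) of how the remainder of the example presumably drives generation after this change: `JSONLogitsProcessor` now comes from `outlines.serve.vllm` and is passed to vLLM through `SamplingParams(logits_processors=...)`, which the patched `_apply_logits_processors` then invokes with per-sequence state. It reuses the imports shown above; the model name, prompt, and the (elided) fields of `User` are placeholders, not taken from the diff.

# Hypothetical usage sketch; identifiers below are illustrative, not from the commit.
llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1")  # placeholder model

# Constrain decoding to JSON matching the User schema.
logits_processor = JSONLogitsProcessor(User, llm)

results = llm.generate(
    "Return a JSON object describing a user: ",
    sampling_params=vllm.SamplingParams(
        max_tokens=100,
        logits_processors=[logits_processor],  # routed through the patched sampler
    ),
)
print(results[0].outputs[0].text)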
|
|