"""Make vLLM compatible with Outlines' guided generation."""
import json
import math
from collections import defaultdict
from typing import DefaultDict, List, Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import (
    _SAMPLING_EPS,
    _apply_min_p,
    _apply_penalties,
    _apply_top_p_top_k,
    _build_sampler_output,
    _get_logits,
    _get_logprobs,
    _get_penalties,
    _get_temperatures,
    _get_top_p_top_k_min_p,
    _prune_hidden_states,
    _sample,
)

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def _patched_apply_logits_processors(
    logits,
    sampling_metadata,
):
    """Patch vLLM's application of logits processors.

    We need to patch the logits processor call to pass the `seq_id` so we
    can handle several sequences in `JSONLogitsProcessor`.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
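                    # Unlike vLLM's stock logits-processor application, we also
                    # pass `seq_id` so that stateful processors such as
                    # `JSONLogitsProcessor` can keep per-sequence state.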
                    logits_row = logits_processor(seq_id, token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


class PatchedSampler(nn.Module):
    """This code is copied from vLLM and uses the patched logits processor.

    Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ):
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)

        # Apply logits processors (if any).
        logits = _patched_apply_logits_processors(logits, sampling_metadata)
        # Apply presence, frequency and repetition penalties.
        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
            sampling_metadata
        )
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        assert len(repetition_penalties) == logits.shape[0]
        logits = _apply_penalties(
            logits,
            sampling_metadata,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
        )

        # Apply temperature scaling.
        temperatures = _get_temperatures(sampling_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
            sampling_metadata, self.vocab_size
        )
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
        if do_min_p:
            logits = _apply_min_p(logits, min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results
        )
        return _build_sampler_output(
            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
        )


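# How one might wire the patch in (an illustrative sketch, not part of this
# module): replace the sampler of a model loaded by vLLM with `PatchedSampler`
# so that `_patched_apply_logits_processors` runs instead of vLLM's stock
# logits-processor application. The exact attribute path to the model and its
# vocabulary size depends on the vLLM version, so treat the lines below as an
# assumption:
#
#     llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1")
#     model = llm.llm_engine.workers[0].model  # version-dependent path
#     model.sampler = PatchedSampler(model.config.vocab_size)

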
class JSONLogitsProcessor:
    def __init__(self, schema, llm):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema, given as a dict or a string, or a Pydantic
            `BaseModel` that encodes the structure we want the model to
            generate.
        llm
            An instance of `vllm.LLM`.

        """
        if isinstance(schema, dict):
            schema = json.dumps(schema)
        regex_str = build_regex_from_object(schema)
        tokenizer = self.adapt_tokenizer(llm.tokenizer)

        fsm = RegexFSM(regex_str, tokenizer)
        self.fsm = fsm

    def __call__(
        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
    ) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""

        if len(input_ids) == 0:  # Initialize the FSM states
            self.fsm_state: DefaultDict[int, int] = defaultdict(int)
        else:
            last_token = input_ids[-1]
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], last_token
            )

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        biased_scores = scores + mask

        return biased_scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the spaces that are
        missing from Llama's tokenizer output to be able to compile FSMs
        for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle the spaces missing from HF's Llama tokenizers.
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer
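

# A minimal usage sketch (illustrative, not part of this module). It assumes
# `import vllm`, that the installed vLLM exposes `logits_processors` on
# `SamplingParams`, and that the patched sampler above has been installed so
# processors are called as `processor(seq_id, token_ids, logits)`:
#
#     schema = {
#         "type": "object",
#         "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
#         "required": ["name", "age"],
#     }
#     llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1")
#     params = vllm.SamplingParams(
#         max_tokens=100, logits_processors=[JSONLogitsProcessor(schema, llm)]
#     )
#     outputs = llm.generate(["Describe a user as JSON:"], sampling_params=params)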