@@ -2,24 +2,9 @@
 import json
 import math
 from collections import defaultdict
-from typing import DefaultDict, List, Optional
+from typing import DefaultDict, List

 import torch
-import torch.nn as nn
-from vllm.model_executor.layers.sampler import (
-    _SAMPLING_EPS,
-    _apply_min_p,
-    _apply_penalties,
-    _apply_top_p_top_k,
-    _build_sampler_output,
-    _get_logits,
-    _get_logprobs,
-    _get_penalties,
-    _get_temperatures,
-    _get_top_p_top_k_min_p,
-    _prune_hidden_states,
-    _sample,
-)

 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -54,98 +39,6 @@ def _patched_apply_logits_processors(
     return logits


-class PatchedSampler(nn.Module):
-    """This code is copied from vLLM and uses the patched logits processor.
-
-    Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence, frequency and repetition penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self, vocab_size: int) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ):
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)
-
-        # Apply logits processors (if any).
-        logits = _patched_apply_logits_processors(logits, sampling_metadata)
-        # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
-            sampling_metadata
-        )
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits,
-            sampling_metadata,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-        )
-
-        # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
-
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size
-        )
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
-        if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results
-        )
-        return _build_sampler_output(
-            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
-        )
-
-
 class JSONLogitsProcessor:
     def __init__(self, schema, llm):
         """Compile the FSM that drives the JSON-guided generation.
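For context (not part of the commit): the guided-generation machinery that remains after this change, `_patched_apply_logits_processors` together with `JSONLogitsProcessor` and `RegexFSM`, works by masking the logits so that only tokens the FSM can accept from its current state stay sampleable. The snippet below is a minimal illustrative sketch of that masking step in isolation; `allowed_token_ids` and `mask_logits_for_fsm` are hypothetical names introduced here, not the actual outlines or vLLM API.

import math
import torch

def mask_logits_for_fsm(logits: torch.Tensor, allowed_token_ids: list) -> torch.Tensor:
    # Illustrative sketch: `logits` is the 1-D logits vector for one sequence and
    # `allowed_token_ids` is the (hypothetical) set of vocabulary ids the FSM allows
    # in its current state. Disallowed tokens are pushed to -inf so they get zero
    # probability after the softmax, while allowed tokens keep their original scores.
    mask = torch.full_like(logits, -math.inf)
    mask[allowed_token_ids] = 0.0
    return logits + mask

Adding the mask, rather than overwriting the logits, preserves the relative scores of the allowed tokens, so temperature scaling, top-p/top-k truncation, and the other steps listed in the removed PatchedSampler docstring still operate normally on the surviving candidates.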