8 | 8 | import lorem
9 | 9 | import pydantic
10 | 10 | import requests
| 11 | +import sseclient
| 12 | +import transformers
11 | 13 | import websocket
| 14 | +from chat_chain_prompts import V2_PROMPTER_PREFIX
12 | 15 | from loguru import logger
13 | 16 | from oasst_shared.schemas import inference
| 17 | +from settings import settings
| 18 | +
| 19 | +shared_tokenizer_lock = threading.Lock()
14 | 20 |
15 | 21 |
16 | 22 | class TokenBuffer:
@@ -58,6 +64,42 @@ def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Ite
58 | 64 |         yield from self.tokens
59 | 65 |
60 | 66 |
| 67 | +def truncate_prompt(
| 68 | +    tokenizer: transformers.PreTrainedTokenizer,
| 69 | +    worker_config: inference.WorkerConfig,
| 70 | +    parameters: interface.GenerateStreamParameters,
| 71 | +    prompt: str,
| 72 | +    plugin_used: bool,
| 73 | +):
| 74 | +    with shared_tokenizer_lock:
| 75 | +        ids = tokenizer.encode(prompt)
| 76 | +
| 77 | +    max_input_length = worker_config.model_config.max_input_length
| 78 | +
| 79 | +    # make room for prompter prefix
| 80 | +    if plugin_used:
| 81 | +        max_input_length = max_input_length - 1
| 82 | +
| 83 | +    max_total_tokens = worker_config.model_config.max_total_length
| 84 | +    if len(ids) > max_input_length:
| 85 | +        logger.warning(f"Prompt too long, left-truncating to {max_input_length} tokens")
| 86 | +        ids = ids[-(max_input_length - 1) :]
| 87 | +        with shared_tokenizer_lock:
| 88 | +            prompt = tokenizer.decode(ids)
| 89 | +            # If truncation removed the prompter prefix, add it back.
| 90 | +            if V2_PROMPTER_PREFIX not in prompt:
| 91 | +                prompt = V2_PROMPTER_PREFIX + prompt
| 92 | +
| 93 | +    input_length = len(ids)
| 94 | +    spare = max_total_tokens - input_length - 1
| 95 | +    if not parameters.max_new_tokens:
| 96 | +        parameters.max_new_tokens = spare
| 97 | +    elif parameters.max_new_tokens > spare:
| 98 | +        logger.warning(f"Max new tokens too high, reducing to {spare}")
| 99 | +        parameters.max_new_tokens = spare
| 100 | +    return prompt
| 101 | +
| 102 | +
61 | 103 | def wait_for_inference_server(http: "HttpClient", timeout: int = 600):
62 | 104 |     time_limit = time.time() + timeout
63 | 105 |     while True:
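A quick way to sanity-check the new `truncate_prompt` helper in isolation is sketched below. The tokenizer name and the `SimpleNamespace` stand-ins for `worker_config` and `parameters` are placeholder assumptions for illustration only; the worker passes its real pydantic config and request objects.

```python
# Hypothetical sketch for reviewers; the stand-in objects below are NOT the worker's real types.
from types import SimpleNamespace

import transformers

from chat_chain_prompts import V2_PROMPTER_PREFIX
from utils import truncate_prompt  # assumed module path for this file

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # placeholder model

worker_config = SimpleNamespace(  # stand-in for inference.WorkerConfig
    model_config=SimpleNamespace(max_input_length=1024, max_total_length=2048)
)
parameters = SimpleNamespace(max_new_tokens=None)  # stand-in for interface.GenerateStreamParameters

long_prompt = V2_PROMPTER_PREFIX + "hello " * 5000
prompt = truncate_prompt(tokenizer, worker_config, parameters, long_prompt, plugin_used=False)

# The prompt is left-truncated to fit max_input_length (keeping the prompter prefix),
# and max_new_tokens is filled in with the spare room under max_total_length.
print(len(tokenizer.encode(prompt)), parameters.max_new_tokens)
```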
@@ -136,3 +178,34 @@ def get(self, path: str, **kwargs):
136 | 178 |
137 | 179 |     def post(self, path: str, **kwargs):
138 | 180 |         return requests.post(self.base_url + path, auth=self.auth, **kwargs)
| 181 | +
| 182 | +
| 183 | +def get_inference_server_stream_events(request: interface.GenerateStreamRequest):
| 184 | +    http = HttpClient(
| 185 | +        base_url=settings.inference_server_url,
| 186 | +        basic_auth_username=settings.basic_auth_username,
| 187 | +        basic_auth_password=settings.basic_auth_password,
| 188 | +    )
| 189 | +    response = http.post(
| 190 | +        "/generate_stream",
| 191 | +        json=request.dict(),
| 192 | +        stream=True,
| 193 | +        headers={"Accept": "text/event-stream"},
| 194 | +    )
| 195 | +    try:
| 196 | +        response.raise_for_status()
| 197 | +    except requests.HTTPError:
| 198 | +        logger.exception("Failed to get response from inference server")
| 199 | +        logger.error(f"Response: {response.text}")
| 200 | +        raise
| 201 | +
| 202 | +    client = sseclient.SSEClient(response)
| 203 | +    for event in client.events():
| 204 | +        if event.event == "error":
| 205 | +            logger.error(f"Error from inference server: {event.data}")
| 206 | +            yield interface.GenerateStreamResponse(error=event.data)
| 207 | +            raise RuntimeError(f"Error from inference server: {event.data}")
| 208 | +        if event.event == "ping":
| 209 | +            continue
| 210 | +        stream_response = interface.GenerateStreamResponse.parse_raw(event.data)
| 211 | +        yield stream_response
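And a minimal sketch of how a caller might drain the new SSE generator. It assumes `settings.inference_server_url` points at a running inference server and that `GenerateStreamResponse` exposes an `error` field and a `token` object with a `text` field, in line with the text-generation-inference stream schema this code mirrors; treat those field names as assumptions, not a definitive client.

```python
# Hypothetical consumer sketch; `request` is assumed to be an already-built
# interface.GenerateStreamRequest (its exact fields are defined elsewhere in the worker).
import interface
from utils import get_inference_server_stream_events  # assumed module path for this file


def collect_generated_text(request: interface.GenerateStreamRequest) -> str:
    chunks: list[str] = []
    for stream_response in get_inference_server_stream_events(request):
        if stream_response.error:
            # The generator yields one error response before raising; stop consuming here.
            break
        # Assumption: non-error responses carry a token object with a `text` field.
        if stream_response.token is not None:
            chunks.append(stream_response.token.text)
    return "".join(chunks)
```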