
Commit 5dd1025

Fix plugins (#3017)
Plugins were compatible only with `text-generation-inference`-based workers and therefore worked on Dragan's machines but did not work on OA prod. This resolves the incompatibility.

Co-authored-by: draganjovanovich <[email protected]>
1 parent 9bcc916 commit 5dd1025

File tree

9 files changed (+173, -106 lines)

inference/worker/basic_hf_server.py

Lines changed: 19 additions & 3 deletions
@@ -11,6 +11,7 @@
 import transformers
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
+from hf_stopping import SequenceStoppingCriteria
 from loguru import logger
 from oasst_shared import model_configs
 from settings import settings
@@ -60,8 +61,12 @@ def model_thread():
         prompt = request.inputs
         params = request.parameters.dict()
         seed = params.pop("seed")
-        params.pop("stop")
+        stop_sequences = params.pop("stop")
         params.pop("details")
+        params.pop("plugins")
+
+        if seed is not None:
+            torch.manual_seed(seed)

         last_token_id = None  # need to delay by 1 to simulate tgi

@@ -79,7 +84,18 @@ def print_text(token_id: int):
         ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
         streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text)
         ids = ids.to(model.device)
-        output = model.generate(ids, **params, streamer=streamer, eos_token_id=eos_token_id)
+        stopping_criteria = (
+            transformers.StoppingCriteriaList([SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)])
+            if stop_sequences
+            else None
+        )
+        output = model.generate(
+            ids,
+            **params,
+            streamer=streamer,
+            eos_token_id=eos_token_id,
+            stopping_criteria=stopping_criteria,
+        )
         output = output.cpu()
         output_ids = output[0][len(ids[0]) :]
         decoded = tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -130,7 +146,7 @@ def decode_token(token_id):
         return result[special_decode_token_length:]

     config_dtype = hf_config.torch_dtype if hasattr(hf_config, "torch_dtype") else torch.float32
-    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else config_dtype
+    dtype = torch.bfloat16 if torch.has_cuda and torch.cuda.is_bf16_supported() else config_dtype

     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_config.model_id,
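Note: a minimal sketch of the wiring this hunk introduces, assuming a toy checkpoint (sshleifer/tiny-gpt2) and a hand-written stop sequence purely for illustration; the production code takes both from the request parameters.

# Sketch only: shows how stop sequences popped from the request parameters feed
# model.generate() through a StoppingCriteriaList. The checkpoint name and the
# stop sequence are assumptions for illustration.
import transformers

from hf_stopping import SequenceStoppingCriteria

model_name = "sshleifer/tiny-gpt2"  # assumed toy checkpoint
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Hello"
stop_sequences = ["\nUser:"]  # assumed stop sequence

ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
stopping_criteria = (
    transformers.StoppingCriteriaList([SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)])
    if stop_sequences
    else None
)
output = model.generate(ids, max_new_tokens=16, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0][ids.shape[1] :], skip_special_tokens=True))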

inference/worker/chat_chain.py

Lines changed: 6 additions & 1 deletion
@@ -2,6 +2,7 @@

 import interface
 import transformers
+import utils
 from chat_chain_prompts import (
     ASSISTANT_PREFIX,
     HUMAN_PREFIX,
@@ -65,6 +66,7 @@ def handle_plugin_usage(
     plugin: inference.PluginEntry | None,
     worker_config: inference.WorkerConfig,
     tokenizer: transformers.PreTrainedTokenizer,
+    parameters: interface.GenerateStreamParameters,
 ) -> tuple[str, inference.PluginUsed]:
     execution_details = inference.PluginExecutionDetails(
         inner_monologue=[],
@@ -115,6 +117,8 @@ def handle_plugin_usage(
     # NOTE: Do not strip() any of the outputs ever, as it will degrade the
     # instruction following performance, at least with
     # `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model`
+
+    init_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, init_prompt, True)
     chain_response = (
         llm.generate(prompts=[init_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"])
         .generations[0][0]
@@ -159,6 +163,7 @@ def handle_plugin_usage(
     # NOTE: Do not strip() any of the outputs ever, as it will degrade the
     # instruction following performance, at least with
     # `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model`
+    new_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, new_prompt, True)
     chain_response = (
         llm.generate(prompts=[new_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"])
         .generations[0][0]
@@ -311,7 +316,7 @@ def handle_conversation(
     # using sampling settings derived from frontend UI
     if plugin_enabled:
         return handle_plugin_usage(
-            original_prompt, prompt_template, language, tools, memory, plugin, worker_config, tokenizer
+            original_prompt, prompt_template, language, tools, memory, plugin, worker_config, tokenizer, parameters
         )

     # Just regular prompt template without plugin chain.

inference/worker/chat_chain_utils.py

Lines changed: 4 additions & 6 deletions
@@ -1,6 +1,5 @@
 import json
 import re
-import threading

 import requests
 import transformers
@@ -11,10 +10,9 @@
 from langchain.prompts import PromptTemplate
 from loguru import logger
 from oasst_shared.schemas import inference
-from opeanapi_parser import prepare_plugin_for_llm
+from openapi_parser import prepare_plugin_for_llm
 from settings import settings
-
-tokenizer_lock = threading.Lock()
+from utils import shared_tokenizer_lock

 RESPONSE_MAX_LENGTH = 2048

@@ -343,7 +341,7 @@ def prepare_prompt(

     out_prompt = prompt_template.format(**args)

-    with tokenizer_lock:
+    with shared_tokenizer_lock:
         ids = tokenizer.encode(out_prompt)

     # soft truncation
@@ -362,7 +360,7 @@ def prepare_prompt(

     out_prompt = prompt_template.format(**args)

-    with tokenizer_lock:
+    with shared_tokenizer_lock:
         ids = tokenizer.encode(out_prompt)
     logger.warning(f"Prompt too long, deleting chat history. New length: {len(ids)}")

inference/worker/hf_langchain_inference.py

Lines changed: 25 additions & 16 deletions
@@ -1,5 +1,6 @@
+import interface
+import utils
 from langchain.llms.base import LLM
-from text_generation import Client


 class HFInference(LLM):
@@ -23,22 +24,30 @@ def _call(self, prompt: str, stop: list[str] | None = None) -> str:
         else:
             stop += self.stop_sequences

-        print(stop)
-        client = Client(self.inference_server_url, timeout=1000)
-        res = client.generate(
-            prompt,
-            stop_sequences=stop,
-            max_new_tokens=self.max_new_tokens,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            typical_p=self.typical_p,
-            temperature=self.temperature,
-            repetition_penalty=self.repetition_penalty,
-            seed=self.seed,
+        request = interface.GenerateStreamRequest(
+            inputs=prompt,
+            parameters=interface.GenerateStreamParameters(
+                stop=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+            ),
         )
+
+        for event in utils.get_inference_server_stream_events(request):
+            stream_response = event
+
+        generated_text = stream_response.generated_text
+        if generated_text is None:
+            generated_text = ""
+
         # remove stop sequences from the end of the generated text
         for stop_seq in stop:
-            if stop_seq in res.generated_text:
-                res.generated_text = res.generated_text[: res.generated_text.index(stop_seq)]
+            if stop_seq in generated_text:
+                generated_text = generated_text[: generated_text.index(stop_seq)]

-        return res.generated_text
+        return generated_text
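Usage note: the wrapper is still consumed through the standard langchain LLM interface, as in chat_chain.py above. A hedged sketch follows, where the module name hf_langchain_inference and the exact constructor fields are assumptions based on the surrounding worker code.

# Sketch only: module name and field set are assumptions; see chat_chain.py for
# the real call sites.
from hf_langchain_inference import HFInference

llm = HFInference(
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.2,
    stop_sequences=["</s>"],
)

# generate() returns an LLMResult; the first generation's text is the reply,
# with stop sequences already trimmed by _call() above.
result = llm.generate(prompts=["What plugins can you use?"], stop=["Observation:"])
print(result.generations[0][0].text)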

inference/worker/hf_stopping.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import torch
+from tokenizers import Tokenizer
+from transformers import StoppingCriteria
+
+
+class SequenceStoppingCriteria(StoppingCriteria):
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        stop_texts: list[str],
+        input_prompt: str,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.stop_texts = stop_texts
+        self.tokenizer = tokenizer
+        self.input_length = len(tokenizer.encode(input_prompt))
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+        **kwargs,
+    ) -> bool:
+        # Assumes batch size 1, sufficient for our use case
+        generated_ids = input_ids[0, self.input_length :].tolist()
+        # TODO: optimise this. Inefficient to decode whole sequence every time
+        # but can't encode stop sequences as they don't always tokenize the same
+        generated_text = self.tokenizer.decode(generated_ids)
+        return any(text in generated_text for text in self.stop_texts)
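A quick, self-contained check of the criteria in isolation, assuming the GPT-2 tokenizer and a made-up stop text purely for illustration:

# Sketch only: exercises SequenceStoppingCriteria without running a model.
import torch
import transformers

from hf_stopping import SequenceStoppingCriteria

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer

prompt = "User: hi\nAssistant:"
criteria = SequenceStoppingCriteria(tokenizer, stop_texts=["\nUser:"], input_prompt=prompt)

# Pretend the model has produced " sure!\nUser:" after the prompt; the criteria
# decodes only the generated part and looks for any stop text in it.
input_ids = torch.tensor([tokenizer.encode(prompt + " sure!\nUser:")])
print(criteria(input_ids, scores=None))  # True -> generation stops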

File renamed without changes (opeanapi_parser.py → openapi_parser.py).

inference/worker/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -10,6 +10,5 @@ pydantic
 requests
 sentencepiece
 sseclient-py
-text-generation
 git+https://github.com/huggingface/transformers@main#egg=transformers
 websocket-client

inference/worker/utils.py

Lines changed: 73 additions & 0 deletions
@@ -8,9 +8,15 @@
 import lorem
 import pydantic
 import requests
+import sseclient
+import transformers
 import websocket
+from chat_chain_prompts import V2_PROMPTER_PREFIX
 from loguru import logger
 from oasst_shared.schemas import inference
+from settings import settings
+
+shared_tokenizer_lock = threading.Lock()


 class TokenBuffer:
@@ -58,6 +64,42 @@ def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Ite
         yield from self.tokens


+def truncate_prompt(
+    tokenizer: transformers.PreTrainedTokenizer,
+    worker_config: inference.WorkerConfig,
+    parameters: interface.GenerateStreamParameters,
+    prompt: str,
+    plugin_used: bool,
+):
+    with shared_tokenizer_lock:
+        ids = tokenizer.encode(prompt)
+
+    max_input_length = worker_config.model_config.max_input_length
+
+    # make room for prompter prefix
+    if plugin_used:
+        max_input_length = max_input_length - 1
+
+    max_total_tokens = worker_config.model_config.max_total_length
+    if len(ids) > max_input_length:
+        logger.warning(f"Prompt too long, left-truncating to {max_input_length} tokens")
+        ids = ids[-(max_input_length - 1) :]
+        with shared_tokenizer_lock:
+            prompt = tokenizer.decode(ids)
+        # If there is no prompter prefix, due to truncation, add it back.
+        if V2_PROMPTER_PREFIX not in prompt:
+            prompt = V2_PROMPTER_PREFIX + prompt
+
+    input_length = len(ids)
+    spare = max_total_tokens - input_length - 1
+    if not parameters.max_new_tokens:
+        parameters.max_new_tokens = spare
+    elif parameters.max_new_tokens > spare:
+        logger.warning(f"Max new tokens too high, reducing to {spare}")
+        parameters.max_new_tokens = spare
+    return prompt
+
+
 def wait_for_inference_server(http: "HttpClient", timeout: int = 600):
     time_limit = time.time() + timeout
     while True:
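Before the next hunk, a worked example of the token budget above, with assumed limits (max_input_length=1024, max_total_length=2048) and an assumed 1,500-token prompt with plugins enabled:

# Sketch only: numbers are assumptions chosen to illustrate truncate_prompt().
max_input_length = 1024
max_total_tokens = 2048
prompt_tokens = 1500
requested_max_new_tokens = 2000

max_input_length -= 1                      # plugin_used=True reserves room for the prompter prefix -> 1023
if prompt_tokens > max_input_length:
    prompt_tokens = max_input_length - 1   # left-truncated to the last 1022 tokens

spare = max_total_tokens - prompt_tokens - 1           # 2048 - 1022 - 1 = 1025
max_new_tokens = min(requested_max_new_tokens, spare)  # capped from 2000 to 1025
print(prompt_tokens, spare, max_new_tokens)            # 1022 1025 1025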
@@ -136,3 +178,34 @@ def get(self, path: str, **kwargs):

     def post(self, path: str, **kwargs):
         return requests.post(self.base_url + path, auth=self.auth, **kwargs)
+
+
+def get_inference_server_stream_events(request: interface.GenerateStreamRequest):
+    http = HttpClient(
+        base_url=settings.inference_server_url,
+        basic_auth_username=settings.basic_auth_username,
+        basic_auth_password=settings.basic_auth_password,
+    )
+    response = http.post(
+        "/generate_stream",
+        json=request.dict(),
+        stream=True,
+        headers={"Accept": "text/event-stream"},
+    )
+    try:
+        response.raise_for_status()
+    except requests.HTTPError:
+        logger.exception("Failed to get response from inference server")
+        logger.error(f"Response: {response.text}")
+        raise
+
+    client = sseclient.SSEClient(response)
+    for event in client.events():
+        if event.event == "error":
+            logger.error(f"Error from inference server: {event.data}")
+            yield interface.GenerateStreamResponse(error=event.data)
+            raise RuntimeError(f"Error from inference server: {event.data}")
+        if event.event == "ping":
+            continue
+        stream_response = interface.GenerateStreamResponse.parse_raw(event.data)
+        yield stream_response
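The langchain wrapper above drains this generator and keeps only the last event, which is expected to carry the final generated_text (falling back to an empty string when it does not); error events are yielded and then raised. A hedged sketch of that consumption pattern, where the request construction mirrors the HFInference._call change and the parameter defaults are assumptions:

# Sketch only: request fields mirror the HFInference._call change above; other
# GenerateStreamParameters fields are assumed to have defaults.
import interface
import requests
from loguru import logger
from utils import get_inference_server_stream_events

request = interface.GenerateStreamRequest(
    inputs="What is the weather like today?",
    parameters=interface.GenerateStreamParameters(stop=["</s>"], max_new_tokens=256),
)

stream_response = None
try:
    for event in get_inference_server_stream_events(request):
        stream_response = event  # intermediate events carry tokens; keep the latest one
except (requests.HTTPError, RuntimeError) as e:
    logger.error(f"Generation failed: {e}")
    raise

generated_text = stream_response.generated_text if stream_response else None
print(generated_text or "")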
