Skip to content

Commit 5112758

Browse files
add tokenize test script
1 parent 19d60ed commit 5112758

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

triton_tokenizer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import tritonclient.http
2+
import numpy as np
3+
from utils import print_timings, setup_logging, track_infer_time
4+
5+
# --- Benchmark configuration -------------------------------------------------
model_name = 'tokenize'          # Triton model to query (tokenization model)
url = '127.0.0.1:8000'           # Triton HTTP endpoint (default HTTP port)
model_version = '1'
text = "SOME TEXT"  # edit to check longer sequence length
batch_size = 1                   # number of copies of `text` sent per request

setup_logging()
# Synchronous HTTP client; verbose=False keeps per-request logging quiet
# so it does not distort the timing measurements below.
triton_client = tritonclient.http.InferenceServerClient(url=url, verbose=False)
time_buffer = []                 # per-request latencies, filled by track_infer_time

# Input tensor: one UTF-8 string per batch element, sent as Triton BYTES.
query = tritonclient.http.InferInput(name='TEXT', shape=(batch_size,), datatype="BYTES")
# Requested outputs; binary_data=False returns them JSON-encoded in the
# HTTP response body instead of the binary extension.
input_ids = tritonclient.http.InferRequestedOutput('INPUT_IDS', binary_data=False)
attention = tritonclient.http.InferRequestedOutput('ATTENTION', binary_data=False)
18+
19+
20+
def perform_inference():
    """Send one tokenization request for the module-level `text` to Triton.

    Re-fills the shared `query` input with `batch_size` copies of the text,
    then issues a blocking infer call asking for INPUT_IDS and ATTENTION.
    """
    batch = np.asarray([text] * batch_size, dtype=object)
    query.set_data_from_numpy(batch)
    triton_client.infer(
        model_name,
        model_version=model_version,
        inputs=[query],
        outputs=[input_ids, attention],
    )
23+
24+
25+
# Warm up client, connection and server caches so the timed runs below are
# not skewed by cold-start costs.
for _ in range(10_000):
    perform_inference()

# Timed runs: track_infer_time appends each request's latency to time_buffer.
for _ in range(1_000):
    with track_infer_time(time_buffer):
        perform_inference()

# Report aggregate latency statistics for the collected timings.
print_timings(name=f"tokenize, # text len: {len(text)}", timings=time_buffer)

0 commit comments

Comments (0)