@@ -133,41 +133,56 @@ def __init__(self, model: "Whisper", initial_token_length: int):
133
133
self .model : "Whisper" = model
134
134
self .initial_token_length = initial_token_length
135
135
self .kv_cache = None
136
- self .export_onnx = False
136
+ if model .type == "tiny.en" :
137
+ self .kv_cache_size = lambda x , y : [8 , x , y , 384 ]
138
+ elif model .type == "base.en" :
139
+ self .kv_cache_size = lambda x , y : [12 , x , y , 512 ]
140
+ elif model .type == "small.en" :
141
+ self .kv_cache_size = lambda x , y : [24 , x , y , 768 ]
142
+ elif model .type == "medium.en" :
143
+ self .kv_cache_size = lambda x , y : [48 , x , y , 1024 ]
144
+ else :
145
+ raise ValueError (f"Unsupported model type: { model .type } " )
137
146
138
147
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    """Run one decoder step and return the decoder output for ``tokens``.

    A NumPy key/value cache is kept on ``self.kv_cache``.  On the first call
    it is allocated with ``self.initial_token_length`` positions; on every
    later call it is grown by one position along axis 2 before being handed
    to the decoder, and is then replaced by the cache the decoder returns.

    Args:
        tokens: token ids; ``tokens.shape[0]`` is used as the group (batch)
            size of the cache and the last axis as the sequence length.
        audio_features: encoder output, forwarded unchanged to the decoder.

    Returns:
        The first element of the pair returned by ``self.model.decoder``.

    Raises:
        SystemExit: after writing ``decoder.onnx`` when the optional
            ``self.export_onnx`` attribute is truthy and the cache has grown
            past the initial prompt length.
    """
    n_group = tokens.shape[0]

    if self.kv_cache is None:
        # First forward pass: the cache covers exactly the initial prompt.
        offset = 0
        self.kv_cache = np.zeros(
            self.kv_cache_size(n_group, self.initial_token_length),
            dtype=np.float32,
        )
    else:
        previous = self.kv_cache
        # The decoder may hand the cache back as a torch tensor; convert it
        # explicitly so the copy below also works for tensors that live on
        # another device or track gradients.
        if isinstance(previous, torch.Tensor):
            previous = previous.detach().cpu().numpy()
        offset = previous.shape[2]
        # Grow the cache by one (zero-initialised) position along axis 2.
        grown = np.zeros(
            self.kv_cache_size(n_group, offset + 1), dtype=np.float32
        )
        grown[:, :, :-1, :] = previous
        self.kv_cache = grown

    if tokens.shape[-1] > self.initial_token_length:
        # only need to use the last token except in the first forward pass
        tokens = tokens[:, -1:]

    # Optional one-shot export of the decoder as ONNX.  Disabled unless the
    # instance carries a truthy ``export_onnx`` attribute, so the default
    # behaviour is plain inference.
    export_requested = getattr(self, "export_onnx", False)
    if export_requested and self.kv_cache.shape[2] > self.initial_token_length:
        print(f"tokens: {tokens.shape}")
        print(f"audio_features: {audio_features.shape}")
        print(f"kv_cache: {self.kv_cache.shape}")
        torch.onnx.export(
            self.model.decoder,
            (
                tokens,
                audio_features,
                torch.from_numpy(self.kv_cache),
                torch.tensor(offset),
            ),
            "decoder.onnx",
            verbose=False,
            opset_version=13,
            input_names=["tokens", "audio_features", "kv_cache", "offset"],
            output_names=["logits", "output_kv_cache"],
            dynamic_axes={
                "tokens": [0, 1],
                "audio_features": [0],
                "kv_cache": [1, 2],
                # Keep the output cache consistent with the input cache:
                # both the group axis and the sequence axis vary.
                "output_kv_cache": [1, 2],
            },
        )
        # Stop the process once the model has been written, without relying
        # on the ``site``-provided ``exit`` helper.
        raise SystemExit(0)

    output, self.kv_cache = self.model.decoder(
        tokens,
        audio_features,
        kv_cache=torch.from_numpy(self.kv_cache),
        offset=torch.tensor(offset),
    )
    return output
172
187
173
188
def cleanup_caching (self ):
@@ -578,6 +593,7 @@ def _get_audio_features(self, mel: Tensor):
578
593
# encoded audio features are given; skip audio encoding
579
594
audio_features = mel
580
595
else :
596
+ # # export encoder as onnx
581
597
# torch.onnx.export(
582
598
# self.model.encoder,
583
599
# (mel),
@@ -615,6 +631,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
615
631
try :
616
632
for i in range (self .sample_len ):
617
633
logits = self .inference .logits (tokens , audio_features )
634
+ print (f"step: { i } , logits: { logits } " , flush = True )
618
635
619
636
if i == 0 and self .tokenizer .no_speech is not None : # save no_speech_probs
620
637
probs_at_sot = logits [:, self .sot_index ].float ().softmax (dim = - 1 )
0 commit comments