@@ -85,11 +85,12 @@ def forward(
85
85
v = self .value (x if xa is None else xa )
86
86
if kv_cache is not None and k .shape [1 ] <= self .n_ctx :
87
87
# here is hard coded
88
- # tiny.en: 4
89
- # base.en: 6
90
- # small.en: 12
91
- # medium.en: 24
92
- key_id = self .layer_id - 24
88
+ # tiny: 4
89
+ # base: 6
90
+ # small: 12
91
+ # medium: 24
92
+ # large: 32
93
+ key_id = self .layer_id - 4
93
94
value_id = key_id + 1
94
95
size = k .shape [1 ]
95
96
kv_cache [key_id , :, - size :, :] = k
@@ -215,10 +216,8 @@ def __init__(self, model: str):
215
216
216
217
self .core = Core ()
217
218
self ._model = self .core .read_model (
218
- # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
219
- # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
220
- "encoder.xml" ,
221
- "encoder.bin" ,
219
+ hf_hub_download (repo_id = f"zhuzilin/whisper-openvino-{ model } " , filename = "encoder.xml" ),
220
+ hf_hub_download (repo_id = f"zhuzilin/whisper-openvino-{ model } " , filename = "encoder.bin" ),
222
221
)
223
222
self .model = self .core .compile_model (self ._model , "CPU" )
224
223
@@ -233,10 +232,8 @@ def __init__(self, model: str):
233
232
234
233
self .core = Core ()
235
234
self ._model = self .core .read_model (
236
- # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
237
- # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
238
- "decoder.xml" ,
239
- "decoder.bin" ,
235
+ hf_hub_download (repo_id = f"zhuzilin/whisper-openvino-{ model } " , filename = "decoder.xml" ),
236
+ hf_hub_download (repo_id = f"zhuzilin/whisper-openvino-{ model } " , filename = "decoder.bin" ),
240
237
)
241
238
self .model = self .core .compile_model (self ._model , "CPU" )
242
239
@@ -278,10 +275,16 @@ def embed_audio(self, mel: torch.Tensor):
278
275
return self .encoder .forward (mel )
279
276
280
277
def logits (self , tokens : torch .Tensor , audio_features : torch .Tensor ):
281
- return self .decoder .forward (tokens , audio_features )
278
+ kv_cache = self .new_kv_cache (tokens .shape [0 ], tokens .shape [- 1 ])
279
+ output , _ = self .decoder .forward (tokens , audio_features , kv_cache = torch .from_numpy (kv_cache ), offset = 0 )
280
+ # output, _ = self.decoder.forward(tokens, audio_features, kv_cache=kv_cache, offset=0)
281
+ return output
282
282
283
283
def forward (self , mel : torch .Tensor , tokens : torch .Tensor ) -> Dict [str , torch .Tensor ]:
284
- return self .decoder (tokens , self .encoder (mel ))
284
+ kv_cache = self .new_kv_cache (tokens .shape [0 ], tokens .shape [- 1 ])
285
+ output , _ = self .decoder (tokens , self .encoder (mel ), kv_cache = torch .from_numpy (kv_cache ), offset = 0 )
286
+ # output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=kv_cache, offset=0)
287
+ return output
285
288
286
289
@property
287
290
def device (self ):
@@ -291,6 +294,21 @@ def device(self):
291
294
def is_multilingual (self ):
292
295
return self .dims .n_vocab == 51865
293
296
297
+ def new_kv_cache (self , n_group : int , length : int ):
298
+ if self .type == "tiny.en" or self .type == "tiny" :
299
+ size = [8 , n_group , length , 384 ]
300
+ elif self .type == "base.en" or self .type == "base" :
301
+ size = [12 , n_group , length , 512 ]
302
+ elif self .type == "small.en" or self .type == "small" :
303
+ size = [24 , n_group , length , 768 ]
304
+ elif self .type == "medium.en" or self .type == "medium" :
305
+ size = [48 , n_group , length , 1024 ]
306
+ elif self .type == "large" :
307
+ size = [64 , n_group , length , 1280 ]
308
+ else :
309
+ raise ValueError (f"Unsupported model type: { self .type } " )
310
+ return np .zeros (size , dtype = np .float32 )
311
+
294
312
detect_language = detect_language_function
295
313
transcribe = transcribe_function
296
314
decode = decode_function
0 commit comments