Skip to content

Commit edb6944

Browse files
committed
support detect language
1 parent 3a0e935 commit edb6944

File tree

2 files changed

+37
-30
lines changed

2 files changed

+37
-30
lines changed

whisper/decoding.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,26 +133,15 @@ def __init__(self, model: "Whisper", initial_token_length: int):
133133
self.model: "Whisper" = model
134134
self.initial_token_length = initial_token_length
135135
self.kv_cache = None
136-
if model.type == "tiny.en":
137-
self.kv_cache_size = lambda x, y: [8, x, y, 384]
138-
elif model.type == "base.en":
139-
self.kv_cache_size = lambda x, y: [12, x, y, 512]
140-
elif model.type == "small.en":
141-
self.kv_cache_size = lambda x, y: [24, x, y, 768]
142-
elif model.type == "medium.en":
143-
self.kv_cache_size = lambda x, y: [48, x, y, 1024]
144-
else:
145-
raise ValueError(f"Unsupported model type: {model.type}")
146136

147137
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
148138
n_group = tokens.shape[0]
149139
if self.kv_cache is None:
150-
self.kv_cache = np.zeros(
151-
self.kv_cache_size(n_group, self.initial_token_length), dtype=np.float32)
140+
self.kv_cache = self.model.new_kv_cache(n_group, self.initial_token_length)
152141
offset = 0
153142
else:
154143
offset = self.kv_cache.shape[2]
155-
new_kv_cache = np.zeros(self.kv_cache_size(n_group, offset + 1), dtype=np.float32)
144+
new_kv_cache = self.model.new_kv_cache(n_group, offset + 1)
156145
new_kv_cache[:, :, :-1, :] = self.kv_cache
157146
self.kv_cache = new_kv_cache
158147

@@ -161,7 +150,7 @@ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
161150
tokens = tokens[:, -1:]
162151

163152
# export decoder as onnx
164-
if False and self.kv_cache.shape[2] > self.initial_token_length:
153+
if True and self.kv_cache.shape[2] > self.initial_token_length:
165154
print(f"tokens: {tokens.shape}")
166155
print(f"audio_features: {audio_features.shape}")
167156
print(f"kv_cache: {self.kv_cache.shape}")
@@ -631,7 +620,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
631620
try:
632621
for i in range(self.sample_len):
633622
logits = self.inference.logits(tokens, audio_features)
634-
print(f"step: {i}, logits: {logits}", flush=True)
623+
print(f"step: {i}", flush=True)
635624

636625
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
637626
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)

whisper/model.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,12 @@ def forward(
8585
v = self.value(x if xa is None else xa)
8686
if kv_cache is not None and k.shape[1] <= self.n_ctx:
8787
# here is hard coded
88-
# tiny.en: 4
89-
# base.en: 6
90-
# small.en: 12
91-
# medium.en: 24
92-
key_id = self.layer_id - 24
88+
# tiny: 4
89+
# base: 6
90+
# small: 12
91+
# medium: 24
92+
# large: 32
93+
key_id = self.layer_id - 4
9394
value_id = key_id + 1
9495
size = k.shape[1]
9596
kv_cache[key_id, :, -size:, :] = k
@@ -215,10 +216,8 @@ def __init__(self, model: str):
215216

216217
self.core = Core()
217218
self._model = self.core.read_model(
218-
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
219-
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
220-
"encoder.xml",
221-
"encoder.bin",
219+
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
220+
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
222221
)
223222
self.model = self.core.compile_model(self._model, "CPU")
224223

@@ -233,10 +232,8 @@ def __init__(self, model: str):
233232

234233
self.core = Core()
235234
self._model = self.core.read_model(
236-
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
237-
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
238-
"decoder.xml",
239-
"decoder.bin",
235+
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
236+
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
240237
)
241238
self.model = self.core.compile_model(self._model, "CPU")
242239

@@ -278,10 +275,16 @@ def embed_audio(self, mel: torch.Tensor):
278275
return self.encoder.forward(mel)
279276

280277
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
281-
return self.decoder.forward(tokens, audio_features)
278+
kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
279+
output, _ = self.decoder.forward(tokens, audio_features, kv_cache=torch.from_numpy(kv_cache), offset=0)
280+
# output, _ = self.decoder.forward(tokens, audio_features, kv_cache=kv_cache, offset=0)
281+
return output
282282

283283
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
284-
return self.decoder(tokens, self.encoder(mel))
284+
kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
285+
output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=torch.from_numpy(kv_cache), offset=0)
286+
# output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=kv_cache, offset=0)
287+
return output
285288

286289
@property
287290
def device(self):
@@ -291,6 +294,21 @@ def device(self):
291294
def is_multilingual(self):
292295
return self.dims.n_vocab == 51865
293296

297+
def new_kv_cache(self, n_group: int, length: int):
298+
if self.type == "tiny.en" or self.type == "tiny":
299+
size = [8, n_group, length, 384]
300+
elif self.type == "base.en" or self.type == "base":
301+
size = [12, n_group, length, 512]
302+
elif self.type == "small.en" or self.type == "small":
303+
size = [24, n_group, length, 768]
304+
elif self.type == "medium.en" or self.type == "medium":
305+
size = [48, n_group, length, 1024]
306+
elif self.type == "large":
307+
size = [64, n_group, length, 1280]
308+
else:
309+
raise ValueError(f"Unsupported model type: {self.type}")
310+
return np.zeros(size, dtype=np.float32)
311+
294312
detect_language = detect_language_function
295313
transcribe = transcribe_function
296314
decode = decode_function

0 commit comments

Comments
 (0)