Skip to content

Commit fe87be3

Browse files
authored
fix(whisper): improve robustness and fix decoder prefix handling (#550)
1 parent: 6b55601 · commit: fe87be3

File tree

1 file changed

+30
-16
lines changed
  • crates/tauri-plugin-ipc-audio-transcription-ort/src/models/whisper/whisper.rs

1 file changed

+30
-16
lines changed

crates/tauri-plugin-ipc-audio-transcription-ort/src/models/whisper/whisper.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ const fn default_true() -> bool {
3030
true
3131
}
3232

33+
// Task token ids used by the decoder prompt.
34+
// These are the special token ids added to the tokenizer / model vocabulary.
35+
const TASK_TRANSLATE_ID: i64 = 50358;
36+
const TASK_TRANSCRIBE_ID: i64 = 50359;
37+
3338
#[derive(Deserialize, Debug)]
3439
pub struct WhisperConfig {
3540
pub num_mel_bins: i64,
@@ -189,6 +194,7 @@ static WHISPER_TO_LANGUAGE_CODE: std::sync::LazyLock<HashMap<&'static str, &'sta
189194
});
190195

191196
pub fn whisper_language_to_code(language: &str) -> Result<String> {
197+
// Accept either two-letter code (en) or full name (english).
192198
let lower_lang = language.to_lowercase();
193199
if let Some(&code) = WHISPER_TO_LANGUAGE_CODE.get(lower_lang.as_str()) {
194200
return Ok(code.to_string());
@@ -205,13 +211,15 @@ fn get_or_download_file<R: Runtime>(
205211
window: tauri::WebviewWindow<R>,
206212
file_sub_path: &str,
207213
event_name: &str,
208-
) -> Result<PathBuf, hf_hub::api::sync::ApiError> {
214+
) -> Result<PathBuf> {
209215
match cache_repo.get(file_sub_path) {
210216
Some(p) => Ok(p),
211-
None => repo.download_with_progress(
212-
file_sub_path,
213-
create_progress_emitter(window.clone(), event_name, file_sub_path.to_string()),
214-
),
217+
None => repo
218+
.download_with_progress(
219+
file_sub_path,
220+
create_progress_emitter(window.clone(), event_name, file_sub_path.to_string()),
221+
)
222+
.map_err(|e| anyhow!("failed to download {}: {}", file_sub_path, e)),
215223
}
216224
}
217225

@@ -339,9 +347,9 @@ impl Whisper {
339347
) -> Result<Vec<i64>> {
340348
let mut init_tokens = vec![self.config.decoder_start_token_id];
341349
let task_id = if gen_config.task == "translate" {
342-
50358
350+
TASK_TRANSLATE_ID
343351
} else {
344-
50359
352+
TASK_TRANSCRIBE_ID
345353
};
346354

347355
if self.config.is_multilingual {
@@ -388,11 +396,13 @@ impl Whisper {
388396
let owned_input = input_features.to_owned();
389397
let inputs = vec![("input_features", Value::from_array(owned_input)?)];
390398
let encoder_outputs = self.encoder_session.run(inputs)?;
391-
let encoder_hidden_states = encoder_outputs.get("last_hidden_state").unwrap();
399+
let encoder_hidden_states = encoder_outputs
400+
.get("last_hidden_state")
401+
.ok_or_else(|| anyhow!("encoder output did not contain 'last_hidden_state'"))?;
392402

393403
let mut generated_tokens = Vec::new();
394404

395-
// KV Cache
405+
// KV Cache (commented scaffold remains)
396406
// let num_decoder_layers = self.config.decoder_layers as usize;
397407
// let head_dim = self.config.d_model / self.config.decoder_attention_heads;
398408
// let mut past_key_values: Vec<Array4<f32>> = (0..num_decoder_layers * 2)
@@ -434,7 +444,8 @@ impl Whisper {
434444
}
435445

436446
generated_tokens.push(next_token);
437-
decoder_input_ids = vec![next_token];
447+
// keep the full prefix for the next iteration (no KV cache): APPEND rather than replace.
448+
decoder_input_ids.push(next_token);
438449
}
439450

440451
Ok(generated_tokens)
@@ -520,22 +531,25 @@ impl WhisperPipeline {
520531
audio: &[f32],
521532
gen_config: &GenerationConfig,
522533
) -> Result<String> {
523-
// 1. Process the raw audio into a mel spectrogram with the correct shape [80, 3000] for normal, and [128, 3000] for large-v3
534+
// Process the raw audio into a mel spectrogram with the correct shape [80, 3000] for normal, and [128, 3000] for large-v3
524535
let input_features = self.processor.process(audio);
525536

526-
// 2. Add the batch dimension, making the shape [1, 80, 3000] for normal, and [1, 128, 3000] for large-v3
537+
// Add the batch dimension, making the shape [1, 80, 3000] for normal, and [1, 128, 3000] for large-v3
527538
let input_features = input_features.insert_axis(Axis(0));
528539

529-
// 3. Generate tokens. This will now work without a shape error.
540+
// Generate tokens. This will now work without a shape error.
530541
let generated_tokens = self
531542
.model
532543
.generate(input_features.view(), gen_config)?;
533544

534-
// The rest of the function remains the same...
545+
// Convert tokens safely to u32 for tokenizer.decode
535546
let generated_tokens_u32: Vec<u32> = generated_tokens
536547
.iter()
537-
.map(|&x| u32::try_from(x).unwrap())
538-
.collect();
548+
.map(|&tok| {
549+
u32::try_from(tok)
550+
.map_err(|e| anyhow!("token id out of range when converting to u32: {} ({})", tok, e))
551+
})
552+
.collect::<Result<_, _>>()?;
539553

540554
let transcript = self
541555
.tokenizer

0 commit comments

Comments (0)