/// Returns `true`.
///
/// NOTE(review): the name and the nearby `#[derive(Deserialize)]` strongly
/// suggest this is a serde field-default provider (e.g.
/// `#[serde(default = "default_true")]`) — confirm against the full config
/// struct, which is not fully visible here.
const fn default_true() -> bool {
    true
}
3232
/// Task token ids used by the decoder prompt.
/// These are the special token ids added to the tokenizer / model vocabulary.
const TASK_TRANSLATE_ID: i64 = 50358;
const TASK_TRANSCRIBE_ID: i64 = 50359;
3338#[ derive( Deserialize , Debug ) ]
3439pub struct WhisperConfig {
3540 pub num_mel_bins : i64 ,
@@ -189,6 +194,7 @@ static WHISPER_TO_LANGUAGE_CODE: std::sync::LazyLock<HashMap<&'static str, &'sta
189194 } ) ;
190195
191196pub fn whisper_language_to_code ( language : & str ) -> Result < String > {
197+ // Accept either two-letter code (en) or full name (english).
192198 let lower_lang = language. to_lowercase ( ) ;
193199 if let Some ( & code) = WHISPER_TO_LANGUAGE_CODE . get ( lower_lang. as_str ( ) ) {
194200 return Ok ( code. to_string ( ) ) ;
@@ -205,13 +211,15 @@ fn get_or_download_file<R: Runtime>(
205211 window : tauri:: WebviewWindow < R > ,
206212 file_sub_path : & str ,
207213 event_name : & str ,
208- ) -> Result < PathBuf , hf_hub :: api :: sync :: ApiError > {
214+ ) -> Result < PathBuf > {
209215 match cache_repo. get ( file_sub_path) {
210216 Some ( p) => Ok ( p) ,
211- None => repo. download_with_progress (
212- file_sub_path,
213- create_progress_emitter ( window. clone ( ) , event_name, file_sub_path. to_string ( ) ) ,
214- ) ,
217+ None => repo
218+ . download_with_progress (
219+ file_sub_path,
220+ create_progress_emitter ( window. clone ( ) , event_name, file_sub_path. to_string ( ) ) ,
221+ )
222+ . map_err ( |e| anyhow ! ( "failed to download {}: {}" , file_sub_path, e) ) ,
215223 }
216224}
217225
@@ -339,9 +347,9 @@ impl Whisper {
339347 ) -> Result < Vec < i64 > > {
340348 let mut init_tokens = vec ! [ self . config. decoder_start_token_id] ;
341349 let task_id = if gen_config. task == "translate" {
342- 50358
350+ TASK_TRANSLATE_ID
343351 } else {
344- 50359
352+ TASK_TRANSCRIBE_ID
345353 } ;
346354
347355 if self . config . is_multilingual {
@@ -388,11 +396,13 @@ impl Whisper {
388396 let owned_input = input_features. to_owned ( ) ;
389397 let inputs = vec ! [ ( "input_features" , Value :: from_array( owned_input) ?) ] ;
390398 let encoder_outputs = self . encoder_session . run ( inputs) ?;
391- let encoder_hidden_states = encoder_outputs. get ( "last_hidden_state" ) . unwrap ( ) ;
399+ let encoder_hidden_states = encoder_outputs
400+ . get ( "last_hidden_state" )
401+ . ok_or_else ( || anyhow ! ( "encoder output did not contain 'last_hidden_state'" ) ) ?;
392402
393403 let mut generated_tokens = Vec :: new ( ) ;
394404
395- // KV Cache
405+ // KV Cache (commented scaffold remains)
396406 // let num_decoder_layers = self.config.decoder_layers as usize;
397407 // let head_dim = self.config.d_model / self.config.decoder_attention_heads;
398408 // let mut past_key_values: Vec<Array4<f32>> = (0..num_decoder_layers * 2)
@@ -434,7 +444,8 @@ impl Whisper {
434444 }
435445
436446 generated_tokens. push ( next_token) ;
437- decoder_input_ids = vec ! [ next_token] ;
447+ // keep the full prefix for the next iteration (no KV cache): APPEND rather than replace.
448+ decoder_input_ids. push ( next_token) ;
438449 }
439450
440451 Ok ( generated_tokens)
@@ -520,22 +531,25 @@ impl WhisperPipeline {
520531 audio : & [ f32 ] ,
521532 gen_config : & GenerationConfig ,
522533 ) -> Result < String > {
523- // 1. Process the raw audio into a mel spectrogram with the correct shape [80, 3000] for normal, and [128, 3000] for large-v3
534+ // Process the raw audio into a mel spectrogram with the correct shape [80, 3000] for normal, and [128, 3000] for large-v3
524535 let input_features = self . processor . process ( audio) ;
525536
526- // 2. Add the batch dimension, making the shape [1, 80, 3000] for normal, and [1, 128, 3000] for large-v3
537+ // Add the batch dimension, making the shape [1, 80, 3000] for normal, and [1, 128, 3000] for large-v3
527538 let input_features = input_features. insert_axis ( Axis ( 0 ) ) ;
528539
529- // 3. Generate tokens. This will now work without a shape error.
540+ // Generate tokens. This will now work without a shape error.
530541 let generated_tokens = self
531542 . model
532543 . generate ( input_features. view ( ) , gen_config) ?;
533544
534- // The rest of the function remains the same...
545+ // Convert tokens safely to u32 for tokenizer.decode
535546 let generated_tokens_u32: Vec < u32 > = generated_tokens
536547 . iter ( )
537- . map ( |& x| u32:: try_from ( x) . unwrap ( ) )
538- . collect ( ) ;
548+ . map ( |& tok| {
549+ u32:: try_from ( tok)
550+ . map_err ( |e| anyhow ! ( "token id out of range when converting to u32: {} ({})" , tok, e) )
551+ } )
552+ . collect :: < Result < _ , _ > > ( ) ?;
539553
540554 let transcript = self
541555 . tokenizer
0 commit comments