Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions apps/stage-tamagotchi/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,12 @@ async fn stop_click_through(window: tauri::Window) -> Result<(), String> {
Ok(())
}

/// Warm up the Whisper Tiny model so later transcription starts instantly.
///
/// Download/load progress is emitted to the frontend through `window`.
/// The constructed processor is dropped immediately — only the side effect
/// (populating the model cache) matters here.
fn load_whisper_model(window: tauri::Window) -> anyhow::Result<()> {
    let progress_manager = crate::whisper::progress::ModelLoadProgressEmitterManager::new(window);

    // Probe accelerators in order of preference: CUDA, then Metal, then CPU.
    let device = match () {
        _ if candle_core::utils::cuda_is_available() => candle_core::Device::new_cuda(0)?,
        _ if candle_core::utils::metal_is_available() => candle_core::Device::new_metal(0)?,
        _ => candle_core::Device::Cpu,
    };

    let _ = whisper::whisper::WhisperProcessor::new(
        whisper::whisper::WhichWhisperModel::Tiny,
        device,
        progress_manager,
    )?;

    Ok(())
}

#[tauri::command]
async fn load_models(window: tauri::Window) -> Result<(), String> {
let _ = load_whisper_model(window);
let device = whisper::model_manager::load_device().unwrap();

whisper::model_manager::load_whisper_model(window.clone(), device.clone()).unwrap();
whisper::model_manager::load_vad_model(window.clone(), candle_core::Device::Cpu).unwrap();
Ok(())
}

Expand Down
2 changes: 2 additions & 0 deletions apps/stage-tamagotchi/src-tauri/src/whisper/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
// Speech pipeline: model lifecycle, progress reporting, transcription, and
// voice-activity detection.
pub mod model_manager;
pub mod progress;
pub mod vad;
pub mod whisper;
43 changes: 43 additions & 0 deletions apps/stage-tamagotchi/src-tauri/src/whisper/model_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use log::info;
use anyhow::Ok;

use crate::whisper::whisper::{WhichWhisperModel, WhisperProcessor};
use crate::whisper::progress::ModelLoadProgressEmitterManager;
use crate::whisper::vad::VADProcessor;

/// Select the best available compute device: CUDA, then Metal, then CPU.
///
/// Logs the chosen device and returns it; fails only if initializing an
/// available accelerator fails.
pub fn load_device() -> anyhow::Result<candle_core::Device> {
    // Guards short-circuit, so Metal is only probed when CUDA is absent.
    let device = match () {
        _ if candle_core::utils::cuda_is_available() => candle_core::Device::new_cuda(0)?,
        _ if candle_core::utils::metal_is_available() => candle_core::Device::new_metal(0)?,
        _ => candle_core::Device::Cpu,
    };

    info!("Using device: {device:?}");
    Ok(device)
}


/// Preload the Whisper Tiny model on `device`, streaming download/load
/// progress to the frontend through `window`.
///
/// The processor instance is discarded — this call exists purely to warm
/// the model cache ahead of the first transcription.
pub fn load_whisper_model(window: tauri::Window, device: candle_core::Device) -> anyhow::Result<()> {
    let model = WhichWhisperModel::Tiny;
    info!("Loading whisper model: {model:?}");

    let emitter = ModelLoadProgressEmitterManager::new(window);
    let _ = WhisperProcessor::new(model, device, emitter)?;
    Ok(())
}

/// Preload the Silero VAD model on `device`, streaming download progress to
/// the frontend through `window`.
///
/// The processor is discarded — this only warms the model cache. The 0.3
/// speech-probability threshold is irrelevant for preloading but required by
/// the constructor.
pub fn load_vad_model(window: tauri::Window, device: candle_core::Device) -> anyhow::Result<()> {
    let progress_manager = ModelLoadProgressEmitterManager::new(window);

    // Previously this declared an unused `WhichWhisperModel::Tiny` and logged
    // it here, which falsely suggested the VAD loader used a Whisper model.
    info!("Loading VAD model");

    let _ = VADProcessor::new(device, 0.3, progress_manager)?;
    Ok(())
}
96 changes: 96 additions & 0 deletions apps/stage-tamagotchi/src-tauri/src/whisper/vad.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::collections::HashMap;

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_onnx::simple_eval;
use hf_hub::{Repo, RepoType};

use crate::whisper::progress;

/// Streaming voice-activity detector backed by the Silero VAD ONNX model
/// (fetched from the `onnx-community/silero-vad` Hugging Face repository).
pub struct VADProcessor {
    // Parsed ONNX graph; re-evaluated with `simple_eval` for every chunk.
    model: candle_onnx::onnx::ModelProto,
    // Samples expected per `process_chunk` call (512 — set in `new`).
    frame_size: usize,
    // Trailing samples of each chunk carried over as context (64 — set in `new`).
    context_size: usize,
    // Scalar tensor holding the sample rate fed to the model (16 kHz).
    sample_rate: Tensor,
    // Recurrent model state, shape (2, 1, 128); replaced after each evaluation.
    state: Tensor,
    // Context window, shape (1, context_size); prepended to the next chunk.
    context: Tensor,
    // Device on which all input tensors are created.
    device: Device,
    // Minimum speech probability for `is_speech` to return true.
    threshold: f32,
}

impl VADProcessor {
    /// Download (if not cached) and parse the Silero VAD ONNX model, and
    /// initialize zeroed recurrent state and context on `device`.
    ///
    /// `threshold` is the speech-probability cutoff used by [`Self::is_speech`];
    /// `manager` receives download-progress events forwarded to the frontend.
    pub fn new(
        device: Device,
        threshold: f32,
        manager: progress::ModelLoadProgressEmitterManager,
    ) -> Result<Self> {
        let api = hf_hub::api::sync::Api::new()?;
        let repo = api.repo(Repo::with_revision(
            "onnx-community/silero-vad".into(),
            RepoType::Model,
            "main".into(),
        ));

        // Fetch the model file, reporting progress through a per-file emitter.
        let model_path = repo.download_with_progress(
            "onnx/model.onnx",
            manager.clone().new_for("onnx/model.onnx"),
        )?;

        let model = candle_onnx::read_file(model_path)?;

        // 16 kHz audio, 512-sample frames, 64 samples of carried-over context.
        let sample_rate_value = 16000i64;
        let (frame_size, context_size) = (512, 64);

        Ok(Self {
            model,
            frame_size,
            context_size,
            sample_rate: Tensor::new(sample_rate_value, &device)?,
            // Recurrent state starts zeroed; updated after every chunk.
            state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
            context: Tensor::zeros((1, context_size), DType::F32, &device)?,
            device,
            threshold,
        })
    }

    /// Run one `frame_size`-sample chunk through the model and return the
    /// speech probability.
    ///
    /// Chunks whose length is not exactly `frame_size` are treated as silence
    /// (returns `Ok(0.0)`) without touching model state or context.
    pub fn process_chunk(
        &mut self,
        chunk: &[f32],
    ) -> Result<f32> {
        if chunk.len() != self.frame_size {
            return Ok(0.0);
        }

        // The last `context_size` samples of this chunk become the context
        // prepended to the NEXT chunk (committed only after a successful run).
        let next_context = Tensor::from_slice(&chunk[self.frame_size - self.context_size..], (1, self.context_size), &self.device)?;
        let chunk_tensor = Tensor::from_vec(chunk.to_vec(), (1, self.frame_size), &self.device)?;

        // Model input is [previous context | current chunk] along dim 1.
        let input = Tensor::cat(&[&self.context, &chunk_tensor], 1)?;
        let inputs: HashMap<String, Tensor> = HashMap::from_iter([("input".to_string(), input), ("sr".to_string(), self.sample_rate.clone()), ("state".to_string(), self.state.clone())]);

        let outputs = simple_eval(&self.model, inputs)?;
        let graph = self.model.graph.as_ref().unwrap();
        let out_names = &graph.output;

        // NOTE(review): assumes the graph's first declared output is the
        // speech probability and the second is the updated recurrent state —
        // this matches the Silero ONNX export, but confirm against the model.
        let output = outputs
            .get(&out_names[0].name)
            .ok_or_else(|| anyhow::anyhow!("Missing VAD output tensor: {}", &out_names[0].name))?
            .clone();

        self.state = outputs
            .get(&out_names[1].name)
            .ok_or_else(|| anyhow::anyhow!("Missing VAD state tensor: {}", &out_names[1].name))?
            .clone();

        self.context = next_context;

        // Probability is the first (only expected) element of the output tensor.
        let speech_prob = output.flatten_all()?.to_vec1::<f32>()?[0];
        Ok(speech_prob)
    }

    /// True when `prob` meets or exceeds the configured threshold.
    pub fn is_speech(
        &self,
        prob: f32,
    ) -> bool {
        prob >= self.threshold
    }
}
Empty file.