From e3c146ada665cd9ba5265a742c502a7309ca879e Mon Sep 17 00:00:00 2001
From: shua
Date: Thu, 22 Aug 2024 22:50:42 +0200
Subject: [PATCH] silero-vad v5 example (#2321)

* silero-vad v5 example

This change adds an example of how to run silero-vad v5

* PR: rename 'vad' to 'silero-vad'

* Update README.md

---------

Co-authored-by: Laurent Mazare
---
 candle-examples/Cargo.toml                    |   4 +
 candle-examples/examples/silero-vad/README.md |  12 ++
 candle-examples/examples/silero-vad/main.rs   | 200 ++++++++++++++++++
 3 files changed, 216 insertions(+)
 create mode 100644 candle-examples/examples/silero-vad/README.md
 create mode 100644 candle-examples/examples/silero-vad/main.rs

diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 56e3d535..6879c48b 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -108,3 +108,7 @@ required-features = ["encodec"]
 [[example]]
 name = "depth_anything_v2"
 required-features = ["depth_anything_v2"]
+
+[[example]]
+name = "silero-vad"
+required-features = ["onnx"]
diff --git a/candle-examples/examples/silero-vad/README.md b/candle-examples/examples/silero-vad/README.md
new file mode 100644
index 00000000..14dd8a82
--- /dev/null
+++ b/candle-examples/examples/silero-vad/README.md
@@ -0,0 +1,12 @@
+# silero-vad: Voice Activity Detection
+
+[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.
+
+This example uses the models available in the Hugging Face repository [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad).
+ +## Running the example + +```bash +$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000 +``` + diff --git a/candle-examples/examples/silero-vad/main.rs b/candle-examples/examples/silero-vad/main.rs new file mode 100644 index 00000000..4618ad80 --- /dev/null +++ b/candle-examples/examples/silero-vad/main.rs @@ -0,0 +1,200 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use clap::Parser; + +use candle::{DType, Tensor}; +use candle_onnx; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "silero")] + Silero, +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum SampleRate { + #[value(name = "8000")] + Sr8k, + #[value(name = "16000")] + Sr16k, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + input: Option, + + #[arg(long)] + sample_rate: SampleRate, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + config_file: Option, + + /// The model to use. 
+ #[arg(long, default_value = "silero")] + which: Which, +} + +/// an iterator which reads consecutive frames of le i16 values from a reader +struct I16Frames { + rdr: R, + buf: Box<[u8]>, + len: usize, + eof: bool, +} +impl I16Frames { + fn new(rdr: R, frame_size: usize) -> Self { + I16Frames { + rdr, + buf: vec![0; frame_size * std::mem::size_of::()].into_boxed_slice(), + len: 0, + eof: false, + } + } +} +impl Iterator for I16Frames { + type Item = std::io::Result>; + + fn next(&mut self) -> Option { + if self.eof { + return None; + } + self.len += match self.rdr.read(&mut self.buf[self.len..]) { + Ok(0) => { + self.eof = true; + 0 + } + Ok(n) => n, + Err(e) => return Some(Err(e)), + }; + if self.eof || self.len == self.buf.len() { + let buf = self.buf[..self.len] + .chunks(2) + .map(|bs| match bs { + [a, b] => i16::from_le_bytes([*a, *b]), + _ => unreachable!(), + }) + .map(|i| i as f32 / i16::MAX as f32) + .collect(); + self.len = 0; + Some(Ok(buf)) + } else { + self.next() + } + } +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let start = std::time::Instant::now(); + let model_id = match &args.model_id { + Some(model_id) => std::path::PathBuf::from(model_id), + None => match args.which { + Which::Silero => hf_hub::api::sync::Api::new()? + .model("onnx-community/silero-vad".into()) + .get("onnx/model.onnx")?, + // TODO: candle-onnx doesn't support Int8 dtype + // Which::SileroQuantized => hf_hub::api::sync::Api::new()? 
+ // .model("onnx-community/silero-vad".into()) + // .get("onnx/model_quantized.onnx")?, + }, + }; + let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate { + SampleRate::Sr8k => (8000, 256, 32), + SampleRate::Sr16k => (16000, 512, 64), + }; + println!("retrieved the files in {:?}", start.elapsed()); + + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + let model = candle_onnx::read_file(model_id)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let start = std::time::Instant::now(); + struct State { + frame_size: usize, + sample_rate: Tensor, + state: Tensor, + context: Tensor, + } + + let mut state = State { + frame_size, + sample_rate: Tensor::new(sample_rate, &device)?, + state: Tensor::zeros((2, 1, 128), DType::F32, &device)?, + context: Tensor::zeros((1, context_size), DType::F32, &device)?, + }; + let mut res = vec![]; + for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) { + let chunk = chunk.unwrap(); + if chunk.len() < state.frame_size { + continue; + } + let next_context = Tensor::from_slice( + &chunk[state.frame_size - context_size..], + (1, context_size), + &device, + )?; + let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?; + let chunk = Tensor::cat(&[&state.context, &chunk], 1)?; + let inputs = std::collections::HashMap::from_iter([ + ("input".to_string(), chunk), + ("sr".to_string(), state.sample_rate.clone()), + ("state".to_string(), state.state.clone()), + ]); + let out = candle_onnx::simple_eval(&model, inputs).unwrap(); + let out_names = &model.graph.as_ref().unwrap().output; + let output = out.get(&out_names[0].name).unwrap().clone(); + state.state = out.get(&out_names[1].name).unwrap().clone(); + assert_eq!(state.state.dims(), &[2, 1, 128]); + state.context = next_context; + + let output = output.flatten_all()?.to_vec1::()?; + assert_eq!(output.len(), 1); + let output = output[0]; + println!("vad chunk prediction: 
{output}"); + res.push(output); + } + println!("calculated prediction in {:?}", start.elapsed()); + + let res_len = res.len() as f32; + let prediction = res.iter().sum::() / res_len; + println!("vad average prediction: {prediction}"); + Ok(()) +}