silero-vad v5 example (#2321)
* silero-vad v5 example This change adds an example of how to run silero-vad v5 * PR: rename 'vad' to 'silero-vad' * Update README.md --------- Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
This commit is contained in:
parent
1e96b8b695
commit
e3c146ada6
|
@ -108,3 +108,7 @@ required-features = ["encodec"]
|
|||
[[example]]
|
||||
name = "depth_anything_v2"
|
||||
required-features = ["depth_anything_v2"]
|
||||
|
||||
[[example]]
|
||||
name = "silero-vad"
|
||||
required-features = ["onnx"]
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# silero-vad: Voice Activity Detection
|
||||
|
||||
[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.
|
||||
|
||||
This example uses the models available in the hugging face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad).
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||
```
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_onnx;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "silero")]
|
||||
Silero,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum SampleRate {
|
||||
#[value(name = "8000")]
|
||||
Sr8k,
|
||||
#[value(name = "16000")]
|
||||
Sr16k,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
input: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
sample_rate: SampleRate,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "silero")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
/// an iterator which reads consecutive frames of le i16 values from a reader
|
||||
struct I16Frames<R> {
|
||||
rdr: R,
|
||||
buf: Box<[u8]>,
|
||||
len: usize,
|
||||
eof: bool,
|
||||
}
|
||||
impl<R> I16Frames<R> {
|
||||
fn new(rdr: R, frame_size: usize) -> Self {
|
||||
I16Frames {
|
||||
rdr,
|
||||
buf: vec![0; frame_size * std::mem::size_of::<i16>()].into_boxed_slice(),
|
||||
len: 0,
|
||||
eof: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<R: std::io::Read> Iterator for I16Frames<R> {
|
||||
type Item = std::io::Result<Vec<f32>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.eof {
|
||||
return None;
|
||||
}
|
||||
self.len += match self.rdr.read(&mut self.buf[self.len..]) {
|
||||
Ok(0) => {
|
||||
self.eof = true;
|
||||
0
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
if self.eof || self.len == self.buf.len() {
|
||||
let buf = self.buf[..self.len]
|
||||
.chunks(2)
|
||||
.map(|bs| match bs {
|
||||
[a, b] => i16::from_le_bytes([*a, *b]),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.map(|i| i as f32 / i16::MAX as f32)
|
||||
.collect();
|
||||
self.len = 0;
|
||||
Some(Ok(buf))
|
||||
} else {
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => std::path::PathBuf::from(model_id),
|
||||
None => match args.which {
|
||||
Which::Silero => hf_hub::api::sync::Api::new()?
|
||||
.model("onnx-community/silero-vad".into())
|
||||
.get("onnx/model.onnx")?,
|
||||
// TODO: candle-onnx doesn't support Int8 dtype
|
||||
// Which::SileroQuantized => hf_hub::api::sync::Api::new()?
|
||||
// .model("onnx-community/silero-vad".into())
|
||||
// .get("onnx/model_quantized.onnx")?,
|
||||
},
|
||||
};
|
||||
let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate {
|
||||
SampleRate::Sr8k => (8000, 256, 32),
|
||||
SampleRate::Sr16k => (16000, 512, 64),
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = candle_onnx::read_file(model_id)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
struct State {
|
||||
frame_size: usize,
|
||||
sample_rate: Tensor,
|
||||
state: Tensor,
|
||||
context: Tensor,
|
||||
}
|
||||
|
||||
let mut state = State {
|
||||
frame_size,
|
||||
sample_rate: Tensor::new(sample_rate, &device)?,
|
||||
state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
|
||||
context: Tensor::zeros((1, context_size), DType::F32, &device)?,
|
||||
};
|
||||
let mut res = vec![];
|
||||
for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) {
|
||||
let chunk = chunk.unwrap();
|
||||
if chunk.len() < state.frame_size {
|
||||
continue;
|
||||
}
|
||||
let next_context = Tensor::from_slice(
|
||||
&chunk[state.frame_size - context_size..],
|
||||
(1, context_size),
|
||||
&device,
|
||||
)?;
|
||||
let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?;
|
||||
let chunk = Tensor::cat(&[&state.context, &chunk], 1)?;
|
||||
let inputs = std::collections::HashMap::from_iter([
|
||||
("input".to_string(), chunk),
|
||||
("sr".to_string(), state.sample_rate.clone()),
|
||||
("state".to_string(), state.state.clone()),
|
||||
]);
|
||||
let out = candle_onnx::simple_eval(&model, inputs).unwrap();
|
||||
let out_names = &model.graph.as_ref().unwrap().output;
|
||||
let output = out.get(&out_names[0].name).unwrap().clone();
|
||||
state.state = out.get(&out_names[1].name).unwrap().clone();
|
||||
assert_eq!(state.state.dims(), &[2, 1, 128]);
|
||||
state.context = next_context;
|
||||
|
||||
let output = output.flatten_all()?.to_vec1::<f32>()?;
|
||||
assert_eq!(output.len(), 1);
|
||||
let output = output[0];
|
||||
println!("vad chunk prediction: {output}");
|
||||
res.push(output);
|
||||
}
|
||||
println!("calculated prediction in {:?}", start.elapsed());
|
||||
|
||||
let res_len = res.len() as f32;
|
||||
let prediction = res.iter().sum::<f32>() / res_len;
|
||||
println!("vad average prediction: {prediction}");
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue