Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example. * Tweak the whisper microphone example more.
This commit is contained in:
parent
aa35bf2ff5
commit
6110ad8d4f
|
@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
|
|||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
|
|
@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder};
|
|||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use std::iter;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
|
@ -18,7 +17,6 @@ mod multilingual;
|
|||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub enum Model {
|
||||
Normal(m::model::Whisper),
|
||||
|
@ -479,6 +477,10 @@ struct Args {
|
|||
/// Print the full DecodingResult structure rather than just the text.
|
||||
#[arg(long)]
|
||||
verbose: bool,
|
||||
|
||||
/// The input device to use.
|
||||
#[arg(long)]
|
||||
device: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
|
@ -543,13 +545,12 @@ pub fn main() -> Result<()> {
|
|||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
|
||||
};
|
||||
let language_token = None;
|
||||
let mut dc = Decoder::new(
|
||||
let mut decoder = Decoder::new(
|
||||
model,
|
||||
tokenizer.clone(),
|
||||
args.seed,
|
||||
&device,
|
||||
language_token,
|
||||
/* language_token */ None,
|
||||
args.task,
|
||||
args.timestamps,
|
||||
args.verbose,
|
||||
|
@ -565,47 +566,69 @@ pub fn main() -> Result<()> {
|
|||
|
||||
// Set up the input device and stream with the default input config.
|
||||
let host = cpal::default_host();
|
||||
let _device = "default";
|
||||
let _device = if _device == "default" {
|
||||
host.default_input_device()
|
||||
} else {
|
||||
host.input_devices()?
|
||||
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
|
||||
let audio_device = match args.device.as_ref() {
|
||||
None => host.default_input_device(),
|
||||
Some(device) => host
|
||||
.input_devices()?
|
||||
.find(|x| x.name().map_or(false, |y| &y == device)),
|
||||
}
|
||||
.expect("failed to find input device");
|
||||
.expect("failed to find the audio input device");
|
||||
|
||||
let _config = _device
|
||||
let audio_config = audio_device
|
||||
.default_input_config()
|
||||
.expect("Failed to get default input config");
|
||||
println!("audio config {audio_config:?}");
|
||||
|
||||
let channel_count = _config.channels() as usize;
|
||||
|
||||
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
|
||||
let audio_ring_buffer_2 = audio_ring_buffer.clone();
|
||||
|
||||
std::thread::spawn(move || loop {
|
||||
let data = record_audio(&_device, &_config, 300).unwrap();
|
||||
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
|
||||
let max_len = data.len() * 16;
|
||||
let data_len = data.len();
|
||||
let len = audio_ring_buffer.lock().unwrap().len();
|
||||
if len > max_len {
|
||||
let mut data = audio_ring_buffer.lock().unwrap();
|
||||
let new_data = data[data_len..].to_vec();
|
||||
*data = new_data;
|
||||
}
|
||||
});
|
||||
let channel_count = audio_config.channels() as usize;
|
||||
let in_sample_rate = audio_config.sample_rate().0 as usize;
|
||||
let resample_ratio = 16000. / in_sample_rate as f64;
|
||||
let mut resampler = rubato::FastFixedIn::new(
|
||||
resample_ratio,
|
||||
10.,
|
||||
rubato::PolynomialDegree::Septic,
|
||||
1024,
|
||||
1,
|
||||
)?;
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let stream = audio_device.build_input_stream(
|
||||
&audio_config.config(),
|
||||
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let pcm = pcm
|
||||
.iter()
|
||||
.step_by(channel_count)
|
||||
.copied()
|
||||
.collect::<Vec<f32>>();
|
||||
if !pcm.is_empty() {
|
||||
tx.send(pcm).unwrap()
|
||||
}
|
||||
},
|
||||
move |err| {
|
||||
eprintln!("an error occurred on stream: {}", err);
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
stream.play()?;
|
||||
|
||||
// loop to process the audio data forever (until the user stops the program)
|
||||
println!("Transcribing audio...");
|
||||
for (i, _) in iter::repeat(()).enumerate() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
||||
let data = audio_ring_buffer_2.lock().unwrap().clone();
|
||||
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
println!("transcribing audio...");
|
||||
let mut buffered_pcm = vec![];
|
||||
let mut language_token_set = false;
|
||||
while let Ok(pcm) = rx.recv() {
|
||||
use rubato::Resampler;
|
||||
|
||||
buffered_pcm.extend_from_slice(&pcm);
|
||||
if buffered_pcm.len() < 10 * in_sample_rate {
|
||||
continue;
|
||||
}
|
||||
let mut resampled_pcm = vec![];
|
||||
for buffered_pcm in buffered_pcm.chunks(1024) {
|
||||
let pcm = resampler.process(&[&buffered_pcm], None)?;
|
||||
resampled_pcm.extend_from_slice(&pcm[0])
|
||||
}
|
||||
let pcm = resampled_pcm;
|
||||
println!("{} {}", buffered_pcm.len(), pcm.len());
|
||||
buffered_pcm.clear();
|
||||
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(
|
||||
mel,
|
||||
|
@ -614,9 +637,13 @@ pub fn main() -> Result<()> {
|
|||
)?;
|
||||
|
||||
// on the first iteration, we detect the language and set the language token.
|
||||
if i == 0 {
|
||||
if !language_token_set {
|
||||
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
|
||||
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
|
||||
(true, None) => Some(multilingual::detect_language(
|
||||
decoder.model(),
|
||||
&tokenizer,
|
||||
&mel,
|
||||
)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
|
@ -627,47 +654,12 @@ pub fn main() -> Result<()> {
|
|||
}
|
||||
};
|
||||
println!("language_token: {:?}", language_token);
|
||||
dc.set_language_token(language_token);
|
||||
decoder.set_language_token(language_token);
|
||||
language_token_set = true;
|
||||
}
|
||||
dc.run(
|
||||
&mel,
|
||||
Some((
|
||||
i as f64,
|
||||
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
|
||||
)),
|
||||
)?;
|
||||
dc.reset_kv_cache();
|
||||
decoder.run(&mel, None)?;
|
||||
decoder.reset_kv_cache();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn record_audio(
|
||||
device: &cpal::Device,
|
||||
config: &cpal::SupportedStreamConfig,
|
||||
milliseconds: u64,
|
||||
) -> Result<Vec<i16>> {
|
||||
let writer = Arc::new(Mutex::new(Vec::new()));
|
||||
let writer_2 = writer.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config.config(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let processed = data
|
||||
.iter()
|
||||
.map(|v| (v * 32768.0) as i16)
|
||||
.collect::<Vec<i16>>();
|
||||
writer_2.lock().unwrap().extend_from_slice(&processed);
|
||||
},
|
||||
move |err| {
|
||||
eprintln!("an error occurred on stream: {}", err);
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
stream.play()?;
|
||||
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
|
||||
drop(stream);
|
||||
let data = writer.lock().unwrap().clone();
|
||||
let step = 3;
|
||||
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
|
||||
Ok(data)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue