Embed the mel filters in the whisper binary. (#373)

This commit is contained in:
Laurent Mazare 2023-08-09 19:27:26 +02:00 committed by GitHub
parent 5b79b38bc7
commit a3b1699409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 13 deletions

View File

@@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -243,13 +243,6 @@ struct Args {
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The mel filters in safetensors format.
#[arg(
long,
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
)]
filters: String,
}
fn main() -> Result<()> {
@@ -301,11 +294,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mel_bytes = include_bytes!("melfilters.bytes");
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;