Embed the mel filters in the whisper binary. (#373)
This commit is contained in:
parent
5b79b38bc7
commit
a3b1699409
|
@ -10,7 +10,7 @@
|
|||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
@ -243,13 +243,6 @@ struct Args {
|
|||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The mel filters in safetensors format.
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
|
||||
)]
|
||||
filters: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
|
@ -301,11 +294,9 @@ fn main() -> Result<()> {
|
|||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
||||
let mel_filters = mel_filters.deserialize()?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let mel_bytes = include_bytes!("melfilters.bytes");
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
|
|
Loading…
Reference in New Issue