Add a KV cache to whisper. (#426)
This commit is contained in:
parent
a0908d212c
commit
60cd1551ca
|
@ -109,8 +109,8 @@ impl Decoder {
|
|||
}
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
|
@ -126,7 +126,7 @@ impl Decoder {
|
|||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
|
@ -393,10 +393,10 @@ fn main() -> Result<()> {
|
|||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let model = Whisper::load(&vb, config)?;
|
||||
let mut model = Whisper::load(&vb, config)?;
|
||||
|
||||
let language_token = match (args.model.is_multilingual(), args.language) {
|
||||
(true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
|
||||
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
|
|
|
@ -105,6 +105,7 @@ struct MultiHeadAttention {
|
|||
out: Linear,
|
||||
n_head: usize,
|
||||
span: tracing::Span,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
|
@ -121,14 +122,39 @@ impl MultiHeadAttention {
|
|||
out,
|
||||
n_head,
|
||||
span,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let q = self.query.forward(x)?;
|
||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||
let (k, v) = match xa {
|
||||
None => {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
(k, v)
|
||||
}
|
||||
Some(x) => {
|
||||
if flush_cache {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
if let Some((k, v)) = &self.kv_cache {
|
||||
(k.clone(), v.clone())
|
||||
} else {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||
let out = self.out.forward(&wv)?;
|
||||
Ok(out)
|
||||
|
@ -201,12 +227,20 @@ impl ResidualAttentionBlock {
|
|||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
||||
let attn = self
|
||||
.attn
|
||||
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
|
||||
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||
}
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
|
@ -283,7 +317,7 @@ impl AudioEncoder {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = {
|
||||
let _enter = self.conv1_span.enter();
|
||||
|
@ -297,8 +331,8 @@ impl AudioEncoder {
|
|||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, None, None)?
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||
}
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
|
@ -344,15 +378,15 @@ impl TextDecoder {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x_dims = x.dims();
|
||||
let last = x_dims[x_dims.len() - 1];
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask))?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self
|
||||
|
@ -383,9 +417,14 @@ impl Whisper {
|
|||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel)?;
|
||||
let dec = self.decoder.forward(tokens, &enc)?;
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
mel: &Tensor,
|
||||
tokens: &Tensor,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel, flush_kv_cache)?;
|
||||
let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
|
||||
Ok(dec)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,19 +105,19 @@ const LANGUAGES: [(&str, &str); 99] = [
|
|||
];
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.forward(&tokens, &audio_features)?
|
||||
.forward(&tokens, &audio_features, true)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
|
|
Loading…
Reference in New Issue