From 60cd1551ca29b2e3049f18ec8e60b6f165cfe941 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 12 Aug 2023 22:17:08 +0200 Subject: [PATCH] Add a KV cache to whisper. (#426) --- candle-examples/examples/whisper/main.rs | 10 +-- candle-examples/examples/whisper/model.rs | 71 ++++++++++++++----- .../examples/whisper/multilingual.rs | 6 +- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 1c24de60..7e614c9c 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -109,8 +109,8 @@ impl Decoder { } fn decode(&mut self, mel: &Tensor, t: f64) -> Result { - let model = &self.model; - let audio_features = model.encoder.forward(mel)?; + let model = &mut self.model; + let audio_features = model.encoder.forward(mel, true)?; println!("audio features: {:?}", audio_features.dims()); let sample_len = model.config.max_target_positions / 2; let mut sum_logprob = 0f64; @@ -126,7 +126,7 @@ impl Decoder { // The model expects a batch dim but this inference loop does not handle // it so we add it at this point. let tokens_t = tokens_t.unsqueeze(0)?; - let logits = model.decoder.forward(&tokens_t, &audio_features)?; + let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?; let logits = logits.squeeze(0)?; // Extract the no speech probability on the first iteration by looking at the first @@ -393,10 +393,10 @@ fn main() -> Result<()> { let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; - let model = Whisper::load(&vb, config)?; + let mut model = Whisper::load(&vb, config)?; let language_token = match (args.model.is_multilingual(), args.language) { - (true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?), + (true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?), (false, None) => None, (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) { Ok(token_id) => Some(token_id), diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index c61882bc..d3ebe02e 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -105,6 +105,7 @@ struct MultiHeadAttention { out: Linear, n_head: usize, span: tracing::Span, + kv_cache: Option<(Tensor, Tensor)>, } impl MultiHeadAttention { @@ -121,14 +122,39 @@ impl MultiHeadAttention { out, n_head, span, + kv_cache: None, }) } - fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_cache: bool, + ) -> Result { let _enter = self.span.enter(); let q = self.query.forward(x)?; - let k = self.key.forward(xa.unwrap_or(x))?; - let v = self.value.forward(xa.unwrap_or(x))?; + let (k, v) = match xa { + None => { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + (k, v) + } + Some(x) => { + if flush_cache { + self.kv_cache = None; + } + if let Some((k, v)) = &self.kv_cache { + (k.clone(), v.clone()) + } else { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + self.kv_cache = Some((k.clone(), v.clone())); + (k, v) + } + } + }; let wv = self.qkv_attention(&q, &k, &v, mask)?; let out = self.out.forward(&wv)?; Ok(out) @@ -201,12 +227,20 @@ impl ResidualAttentionBlock { }) } - fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_kv_cache: bool, + ) -> Result { let _enter = self.span.enter(); - let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; + let attn = self + .attn + .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; let mut x = (x + attn)?; - if let Some((attn, ln)) = &self.cross_attn { - x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?; + if let Some((attn, ln)) = &mut self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; } let mlp = self.mlp_linear2.forward( &self @@ -283,7 +317,7 @@ impl AudioEncoder { }) } - pub fn forward(&self, x: &Tensor) -> Result { + pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result { let _enter = self.span.enter(); let x = { let _enter = self.conv1_span.enter(); @@ -297,8 +331,8 @@ impl AudioEncoder { let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; let mut x = x.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter() { - x = block.forward(&x, None, None)? + for block in self.blocks.iter_mut() { + x = block.forward(&x, None, None, flush_kv_cache)? } let x = self.ln_post.forward(&x)?; Ok(x) @@ -344,15 +378,15 @@ impl TextDecoder { }) } - pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result { + pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result { let _enter = self.span.enter(); let x_dims = x.dims(); let last = x_dims[x_dims.len() - 1]; let token_embedding = self.token_embedding.forward(x)?; let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; let mut x = token_embedding.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter() { - x = block.forward(&x, Some(xa), Some(&self.mask))?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; } let x = self.ln.forward(&x)?; let w = self @@ -383,9 +417,14 @@ impl Whisper { } #[allow(dead_code)] - pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result { - let enc = self.encoder.forward(mel)?; - let dec = self.decoder.forward(tokens, &enc)?; + pub fn forward( + &mut self, + mel: &Tensor, + tokens: &Tensor, + flush_kv_cache: bool, + ) -> Result { + let enc = self.encoder.forward(mel, flush_kv_cache)?; + let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?; Ok(dec) } } diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs index 1342ad55..01722a68 100644 --- a/candle-examples/examples/whisper/multilingual.rs +++ b/candle-examples/examples/whisper/multilingual.rs @@ -105,19 +105,19 @@ const LANGUAGES: [(&str, &str); 99] = [ ]; /// Returns the token id for the selected language. -pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result { +pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result { let device = mel.device(); let language_token_ids = LANGUAGES .iter() .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) .collect::>>()?; let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?; - let audio_features = model.encoder.forward(mel)?; + let audio_features = model.encoder.forward(mel, true)?; let tokens = Tensor::new(&[[sot_token]], device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; let logits = model .decoder - .forward(&tokens, &audio_features)? + .forward(&tokens, &audio_features, true)? .i(0)? .i(0)?; let logits = logits.index_select(&language_token_ids, 0)?;