From 45e235a7473d473df5c1e50f55504a97e28be822 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 23 May 2024 17:07:21 +0200 Subject: [PATCH] Simplify the KvCache api. (#2207) --- .../examples/quantized-phi/main.rs | 1 - candle-nn/src/kv_cache.rs | 89 +++++++++++-------- .../src/models/quantized_phi3.rs | 8 +- 3 files changed, 54 insertions(+), 44 deletions(-) diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index e046fadb..f567ce2d 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -217,7 +217,6 @@ fn main() -> anyhow::Result<()> { match args.which { Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?), Which::Phi3 => Model::Phi3(Phi3::from_gguf( - 1, args.use_flash_attn, model, &mut file, diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 10e9fe5a..eb5dbfdb 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -1,30 +1,25 @@ -use candle::{DType, Device, Result, Shape, Tensor}; +use candle::{Result, Tensor}; #[derive(Debug, Clone)] pub struct Cache { - all_data: Tensor, + // all_data is an option on a Tensor, this makes it possible to only create the actual tensor + // on the first call where the batch size is easily known. + // Also this makes it safe to clone a KvCache that has been reseted (as in it will not share + // its internal state with the cloned instance). + all_data: Option, dim: usize, current_seq_len: usize, max_seq_len: usize, } impl Cache { - pub fn new, D: candle::shape::Dim>( - dim: D, - shape: S, - dtype: DType, - dev: &Device, - ) -> Result { - let shape = shape.into(); - let dim = dim.to_index(&shape, "kv-cache")?; - let max_seq_len = shape.dims()[dim]; - let all_data = Tensor::zeros(shape, dtype, dev)?; - Ok(Self { - all_data, + pub fn new(dim: usize, max_seq_len: usize) -> Self { + Self { + all_data: None, dim, current_seq_len: 0, max_seq_len, - }) + } } pub fn dim(&self) -> usize { @@ -39,20 +34,34 @@ impl Cache { self.max_seq_len } - pub fn all_data(&self) -> &Tensor { + pub fn all_data(&self) -> &Option { &self.all_data } - pub fn current_data(&self) -> Result { - self.all_data.narrow(self.dim, 0, self.current_seq_len) + pub fn current_data(&self) -> Result> { + let data = match self.all_data.as_ref() { + None => None, + Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?), + }; + Ok(data) } pub fn reset(&mut self) { - self.current_seq_len = 0 + self.current_seq_len = 0; + self.all_data = None; } pub fn append(&mut self, src: &Tensor) -> Result<()> { let seq_len = src.dim(self.dim)?; + // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use + // self.all_data.get_or_insert_with. + if self.all_data.is_none() { + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.max_seq_len; + let ad = Tensor::zeros(shape, src.dtype(), src.device())?; + self.all_data = Some(ad) + }; + let ad = self.all_data.as_mut().unwrap(); if self.current_seq_len + seq_len > self.max_seq_len { candle::bail!( "kv-cache: above max-seq-len {}+{seq_len}>{}", @@ -60,8 +69,7 @@ impl Cache { self.max_seq_len ) } - self.all_data - .slice_set(src, self.dim, self.current_seq_len)?; + ad.slice_set(src, self.dim, self.current_seq_len)?; self.current_seq_len += seq_len; Ok(()) } @@ -74,17 +82,10 @@ pub struct KvCache { } impl KvCache { - pub fn new, D: candle::shape::Dim>( - dim: D, - shape: S, - dtype: DType, - dev: &Device, - ) -> Result { - let shape = shape.into(); - let dim = dim.to_index(&shape, "kv-cache")?; - let k = Cache::new(dim, &shape, dtype, dev)?; - let v = Cache::new(dim, &shape, dtype, dev)?; - Ok(Self { k, v }) + pub fn new(dim: usize, max_seq_len: usize) -> Self { + let k = Cache::new(dim, max_seq_len); + let v = Cache::new(dim, max_seq_len); + Self { k, v } } pub fn k_cache(&self) -> &Cache { @@ -103,19 +104,35 @@ impl KvCache { &mut self.v } - pub fn k(&self) -> Result { + pub fn k(&self) -> Result> { self.k.current_data() } - pub fn v(&self) -> Result { + pub fn v(&self) -> Result> { self.v.current_data() } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { self.k.append(k)?; self.v.append(v)?; - let k = self.k.current_data()?; - let v = self.v.current_data()?; + let out_k = self.k.current_data()?; + let out_v = self.v.current_data()?; + let k = match out_k { + None => { + let mut shape = k.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, k.dtype(), k.device())? + } + Some(k) => k, + }; + let v = match out_v { + None => { + let mut shape = v.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, v.dtype(), v.device())? + } + Some(v) => v, + }; Ok((k, v)) } diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs index a1161722..f9b55d9d 100644 --- a/candle-transformers/src/models/quantized_phi3.rs +++ b/candle-transformers/src/models/quantized_phi3.rs @@ -203,7 +203,6 @@ fn precomput_freqs_cis( impl ModelWeights { pub fn from_gguf( - batch_size: usize, use_flash_attn: bool, ct: gguf_file::Content, reader: &mut R, @@ -252,12 +251,7 @@ impl ModelWeights { )?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let kv_cache = KvCache::new( - 2, - (batch_size, head_count_kv, max_seq_len, head_dim), - DType::F32, - device, - )?; + let kv_cache = KvCache::new(2, max_seq_len); layers.push(LayerWeights { attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,