Make the cache for the llama model explicit too. (#1745)

This commit is contained in:
Laurent Mazare 2024-02-22 12:04:33 +01:00 committed by GitHub
parent 544018b6d0
commit 28057781aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 35 deletions

View File

@ -120,7 +120,7 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, cache) = {
let (llama, tokenizer_filename, mut cache) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
@ -146,7 +146,7 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &config)?, tokenizer_filename, cache)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -172,7 +172,7 @@ fn main() -> Result<()> {
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits

View File

@ -2,7 +2,6 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
@ -84,10 +83,9 @@ impl Config {
#[derive(Debug, Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
masks: HashMap<usize, Tensor>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
kvs: Vec<Option<(Tensor, Tensor)>>,
cos: Tensor,
sin: Tensor,
device: Device,
@ -112,25 +110,24 @@ impl Cache {
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
masks: HashMap::new(),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
kvs: vec![None; config.num_hidden_layers],
device: device.clone(),
cos,
sin,
})
}
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
@ -164,7 +161,6 @@ struct CausalSelfAttention {
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
@ -187,11 +183,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@ -201,7 +197,13 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let q = self.q_proj.forward(x)?;
@ -218,12 +220,11 @@ impl CausalSelfAttention {
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
@ -239,7 +240,7 @@ impl CausalSelfAttention {
.contiguous()?
}
}
cache[block_idx] = Some((k.clone(), v.clone()))
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -258,7 +259,7 @@ impl CausalSelfAttention {
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -283,7 +284,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
@ -301,7 +302,6 @@ impl CausalSelfAttention {
num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
@ -357,19 +357,25 @@ struct Block {
}
impl Block {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::load(
@ -396,11 +402,11 @@ pub struct Llama {
}
impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?;
@ -408,12 +414,12 @@ impl Llama {
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
.collect();
Ok(Self {