Make the cache for the llama model explicit too. (#1745)
parent 544018b6d0
commit 28057781aa
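The change drops the `Arc<Mutex<…>>` interior mutability from the llama `Cache` and removes the cache clone that each attention layer used to hold: the caller now owns the cache and passes it down as `&mut Cache` on every forward call, while `load` takes only weights and config. The first three hunks below are in the example's `main`; the remaining hunks are in the llama model definition itself. As orientation, here is a minimal sketch of the calling pattern after this change; `config`, `dtype`, `device`, `filenames`, and `tokens` stand in for values the example sets up elsewhere, and the sampling loop is omitted:

    // Sketch only: assumes the usual imports used by the candle llama example.
    use candle_transformers::models::llama as model;

    // The cache is created and owned by the caller, not by the model.
    let mut cache = model::Cache::new(/* use_kv_cache */ true, dtype, &config, &device)?;

    // `Llama::load` no longer receives the cache; the model holds only weights.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let llama = model::Llama::load(vb, &config)?;

    // Each forward pass threads the cache through explicitly and updates it in place.
    let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
    let logits = llama.forward(&input, /* index_pos */ 0, &mut cache)?;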
@@ -120,7 +120,7 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, cache) = {
+    let (llama, tokenizer_filename, mut cache) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
@@ -146,7 +146,7 @@ fn main() -> Result<()> {
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@@ -172,7 +172,7 @@ fn main() -> Result<()> {
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, context_index)?;
+        let logits = llama.forward(&input, context_index, &mut cache)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits
@@ -2,7 +2,6 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex};

 pub const MAX_SEQ_LEN: usize = 4096;

@@ -84,10 +83,9 @@ impl Config {

 #[derive(Debug, Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    masks: HashMap<usize, Tensor>,
     pub use_kv_cache: bool,
-    #[allow(clippy::type_complexity)]
-    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+    kvs: Vec<Option<(Tensor, Tensor)>>,
     cos: Tensor,
     sin: Tensor,
     device: Device,
@@ -112,25 +110,24 @@ impl Cache {
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
+            masks: HashMap::new(),
             use_kv_cache,
-            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
+            kvs: vec![None; config.num_hidden_layers],
             device: device.clone(),
             cos,
             sin,
         })
     }

-    fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
+    fn mask(&mut self, t: usize) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
+            self.masks.insert(t, mask.clone());
             Ok(mask)
         }
     }
@@ -164,7 +161,6 @@ struct CausalSelfAttention {
     num_attention_heads: usize,
     num_key_value_heads: usize,
     head_dim: usize,
-    cache: Cache,
     use_flash_attn: bool,
     span: tracing::Span,
     span_rot: tracing::Span,
@@ -187,11 +183,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
 }

 impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
+        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
         let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
         let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@@ -201,7 +197,13 @@ impl CausalSelfAttention {
         Ok(rope)
     }

-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
@@ -218,12 +220,11 @@ impl CausalSelfAttention {
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
             .transpose(1, 2)?;

-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
+        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;

-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
+        if cache.use_kv_cache {
+            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
@@ -239,7 +240,7 @@ impl CausalSelfAttention {
                         .contiguous()?
                 }
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
+            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
         }

         let k = self.repeat_kv(k)?;
@@ -258,7 +259,7 @@ impl CausalSelfAttention {
             let k = k.to_dtype(DType::F32)?;
             let v = v.to_dtype(DType::F32)?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
             let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
             let att = candle_nn::ops::softmax(&att, D::Minus1)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -283,7 +284,7 @@ impl CausalSelfAttention {
         }
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "attn");
         let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
@@ -301,7 +302,6 @@ impl CausalSelfAttention {
             num_attention_heads: cfg.num_attention_heads,
             num_key_value_heads: cfg.num_key_value_heads,
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
-            cache: cache.clone(),
             use_flash_attn: cfg.use_flash_attn,
             span,
             span_rot,
@@ -357,19 +357,25 @@ struct Block {
 }

 impl Block {
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "block");
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
         let rms_2 = RmsNorm::load(
@@ -396,11 +402,11 @@ pub struct Llama {
 }

 impl Llama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
+            x = block.forward(&x, index_pos, block_idx, cache)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.i((.., seq_len - 1, ..))?;
@@ -408,12 +414,12 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }

-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
             .collect();

         Ok(Self {
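One practical consequence of moving the state out of the model: since `Llama` now holds only weights, a single loaded model can drive several independent sequences by giving each one its own `Cache`, and sharing across threads becomes an explicit caller-side decision rather than something baked in through `Arc<Mutex<…>>`. A hypothetical sketch reusing the names from the example above (the `input_a`/`input_b` tensors and surrounding setup are assumed, not part of this commit):

    // Hypothetical: one model, two independent sequences, each with its own cache.
    let mut cache_a = model::Cache::new(true, dtype, &config, &device)?;
    let mut cache_b = model::Cache::new(true, dtype, &config, &device)?;

    // The KV state for each sequence lives in its own cache, never in `llama`.
    let logits_a = llama.forward(&input_a, 0, &mut cache_a)?;
    let logits_b = llama.forward(&input_b, 0, &mut cache_b)?;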