feat: add clear_kv_cache to mistral and qmistral models (#1464)
This commit is contained in:
parent
563a79afa1
commit
f6408a3779
|
@ -297,6 +297,10 @@ impl Attention {
|
|||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -340,6 +344,10 @@ impl DecoderLayer {
|
|||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -423,4 +431,10 @@ impl Model {
|
|||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -198,6 +198,10 @@ impl Attention {
|
|||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -241,6 +245,10 @@ impl DecoderLayer {
|
|||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -322,4 +330,10 @@ impl Model {
|
|||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue