Avoid the attention mask where possible. (#1933)

Laurent Mazare 2024-03-25 15:31:04 +01:00 committed by GitHub
parent cd254074f3
commit d3a8d291d5
3 changed files with 32 additions and 16 deletions
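
For context on the change: with a KV cache, incremental decoding feeds the model one new token per step, and a causal mask for a single query position hides nothing, so building and applying it is pure overhead. A plain-Rust sketch of that counting argument (the function and its name are illustrative, not part of the diff):

// Count the (query, key) pairs a causal mask would hide: query i (offset by
// the cached prefix) must not see key j when j > past_kv_len + i.
fn masked_pairs(seq_len: usize, past_kv_len: usize) -> usize {
    let total = past_kv_len + seq_len;
    (0..seq_len)
        .map(|i| (0..total).filter(|&j| j > past_kv_len + i).count())
        .sum()
}

fn main() {
    // A decode step (one new token): the mask would hide nothing.
    assert_eq!(masked_pairs(1, 7), 0);
    // Prompt processing (several positions at once): masking is required.
    assert_eq!(masked_pairs(3, 0), 3);
}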


@@ -247,7 +247,7 @@ impl FalconAttention {
         }
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
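
The signature now takes Option<&Tensor> rather than &Tensor, so callers with nothing to mask pass None and the attention code can skip the masking work. A borrowed option also lets the same Option<Tensor> be handed to every decoder block in a loop via Option::as_ref, as the last hunk of this file does with causal_mask.as_ref(); a minimal standalone sketch of that idiom (types here are placeholders, not candle's):

// Illustrative only: Option<&T> lets one owned Option<T> be reborrowed on
// every loop iteration without cloning or moving the value out.
fn forward(input: &str, mask: Option<&Vec<u8>>) -> usize {
    input.len() + mask.map_or(0, |m| m.len())
}

fn main() {
    let mask: Option<Vec<u8>> = Some(vec![0, 1, 1]);
    let mut acc = 0;
    for block_input in ["a", "bb", "ccc"] {
        // as_ref converts Option<Vec<u8>> to Option<&Vec<u8>>, so `mask`
        // stays available for the next iteration.
        acc += forward(block_input, mask.as_ref());
    }
    assert_eq!(acc, 15);
}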
@@ -267,7 +267,6 @@ impl FalconAttention {
             (query, key)
         };
         let (mut key, mut value) = (key, value);
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         if self.use_cache {
             if let Some((cache_k, cache_v)) = &self.kv_cache {
                 // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
@@ -293,13 +292,18 @@ impl FalconAttention {
 
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = candle_nn::ops::softmax(
-            &attention_scores
-                .broadcast_add(&mask.squeeze(1)?)?
-                .to_dtype(DType::F32)?,
-            D::Minus1,
-        )?
-        .to_dtype(x.dtype())?;
+        let attention_scores = match mask {
+            None => attention_scores,
+            Some(mask) => {
+                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+                    .to_dtype(query.dtype())?;
+                attention_scores.broadcast_add(&mask.squeeze(1)?)?
+            }
+        };
+
+        let attention_scores =
+            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
+                .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
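
masked_fill is not part of this diff; candle's model sources conventionally define it along the following lines (a sketch, not verified against this exact file). In the hunk above, passing mask.to_dtype(DType::F32)? as on_false yields 0.0 at visible positions and -1e9 at masked ones, i.e. an additive bias that broadcast_add applies to the scores before the softmax:

use candle_core::{Result, Tensor};

// Sketch of the masked_fill helper the hunk calls: wherever `mask` is
// non-zero, replace the corresponding value of `on_false` with `on_true`.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    mask.where_cond(&on_true, on_false)
}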
@@ -372,7 +376,7 @@ impl FalconDecoderLayer {
         })
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let residual = x.clone();
         let ln_attn = self.inp_layernorm.forward(x)?;
         let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
@@ -457,9 +461,13 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+        let causal_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+            hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
         let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
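
prepare_attn_mask is likewise defined outside the diff. Given that the attention code squeezes dimension 1 out of the mask and adds it after masked_fill, a compatible helper plausibly produces a (b_sz, 1, seq_len, seq_len) upper-triangular indicator; the following is an illustration under that assumption, not the file's actual definition:

use candle_core::{Device, Result, Tensor};

// Illustrative stand-in for prepare_attn_mask: 1 marks future key positions
// a query must not attend to, broadcast over batch and a singleton head dim.
fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?
        .broadcast_as((b_sz, 1, seq_len, seq_len))
}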


@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
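
cache.mask(seq_len) builds the causal mask on demand; candle's llama example memoizes it per sequence length, roughly as follows (the HashMap cache and field names are assumptions drawn from that example, not from this diff):

use std::collections::HashMap;

use candle_core::{Device, Result, Tensor};

struct Cache {
    masks: HashMap<usize, Tensor>,
    device: Device,
}

impl Cache {
    // Build (and memoize) an upper-triangular mask: 1 marks the future
    // positions a causal model must not attend to.
    fn mask(&mut self, t: usize) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<u8> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            self.masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
}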


@@ -71,8 +71,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
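
As a quick sanity check of the seq_len <= 1 short-circuit used in the last two files: for a single query position, the softmax of the raw scores is already the final attention distribution, so skipping the mask changes nothing. A runnable snippet against the candle-core and candle-nn crates:

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Attention scores for one query over one cached key: shape (1, 1, 1, 1).
    let att = Tensor::from_slice(&[0.42f32], (1, 1, 1, 1), &device)?;
    // The 1x1 causal mask would hide nothing, so the unmasked softmax is
    // already the answer: all probability mass on the only position.
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    assert_eq!(att.flatten_all()?.to_vec1::<f32>()?, vec![1.0]);
    Ok(())
}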