Avoid the attention mask where possible. (#1933)
This commit is contained in: parent cd254074f3, commit d3a8d291d5
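The change rests on a single observation: a causal mask only hides key positions that lie after the query position, so a forward pass over a single token (seq_len <= 1, the usual case when generating with a KV cache) has nothing to hide and the mask is pure overhead to build, move, and apply. A minimal plain-Rust sketch of that argument; the helper below is illustrative only and is not the crate's prepare_attn_mask or cache.mask, which produce Tensors:

// Minimal illustration of why a causal mask is a no-op for a single query
// position; plain Rust, not the candle helpers used in the diff below.
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    // Entry (i, j) is true when key position j lies in the future of query i
    // and must therefore be masked out.
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j > i).collect())
        .collect()
}

fn main() {
    // With seq_len == 1 the only entry is (0, 0), which is never masked,
    // so applying the mask cannot change the attention scores.
    assert!(causal_mask(1).iter().flatten().all(|&m| !m));
    // For longer sequences the mask does remove entries (here 3 of 9).
    let masked = causal_mask(3).iter().flatten().filter(|&&m| m).count();
    assert_eq!(masked, 3);
}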
@@ -247,7 +247,7 @@ impl FalconAttention {
         }
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
@@ -267,7 +267,6 @@ impl FalconAttention {
             (query, key)
         };
         let (mut key, mut value) = (key, value);
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         if self.use_cache {
             if let Some((cache_k, cache_v)) = &self.kv_cache {
                 // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
@@ -293,13 +292,18 @@ impl FalconAttention {
 
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = candle_nn::ops::softmax(
-            &attention_scores
-                .broadcast_add(&mask.squeeze(1)?)?
-                .to_dtype(DType::F32)?,
-            D::Minus1,
-        )?
-        .to_dtype(x.dtype())?;
+        let attention_scores = match mask {
+            None => attention_scores,
+            Some(mask) => {
+                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+                    .to_dtype(query.dtype())?;
+                attention_scores.broadcast_add(&mask.squeeze(1)?)?
+            }
+        };
+
+        let attention_scores =
+            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
+                .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
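In the Some(mask) branch above, masked positions have -1e9 added to their scores before the softmax instead of being dropped; after exponentiation those entries are effectively zero, so the corresponding keys receive practically no attention weight. A small scalar sketch of that numeric effect, standing in for candle_nn::ops::softmax over Tensors:

// Scalar softmax, just for the illustration below.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // Two raw attention scores, the second "masked" by adding -1e9.
    let weights = softmax(&[2.5, 0.7 + -1e9]);
    // The masked position ends up with ~0 probability, the visible one with ~1.
    assert!(weights[0] > 0.999 && weights[1] < 1e-6);
    println!("{weights:?}");
}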
@@ -372,7 +376,7 @@ impl FalconDecoderLayer {
         })
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
        let residual = x.clone();
        let ln_attn = self.inp_layernorm.forward(x)?;
        let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
@@ -457,9 +461,13 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+        let causal_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+            hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
         let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
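On the caller side, the mask is now built at most once per forward pass as an Option<Tensor> and then re-borrowed for every decoder block: causal_mask.as_ref() turns &Option<Tensor> into the Option<&Tensor> the blocks expect without moving the tensor out of the loop. A tiny sketch of that borrowing pattern; FakeTensor and block_forward are hypothetical stand-ins, not candle types:

// `FakeTensor` stands in for candle's Tensor purely to show the Option
// borrowing pattern used in Falcon::forward above.
struct FakeTensor(Vec<f32>);

fn block_forward(x: f32, mask: Option<&FakeTensor>) -> f32 {
    // A real block would apply the mask to its attention scores.
    x + mask.map_or(0.0, |m| m.0.len() as f32)
}

fn main() {
    let seq_len = 4;
    let causal_mask = if seq_len <= 1 {
        None
    } else {
        Some(FakeTensor(vec![0.0; seq_len * seq_len]))
    };
    let mut hidden = 1.0;
    for _ in 0..3 {
        // `.as_ref()` lends the mask to each block without consuming it.
        hidden = block_forward(hidden, causal_mask.as_ref());
    }
    println!("{hidden}");
}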
@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
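The Llama attention above (and the near-identical copy in the next hunk) routes the mask through a masked_fill helper with f32::NEG_INFINITY rather than a broadcast_add; in candle such helpers are typically a where_cond-style selection between the original scores and a tensor full of the fill value. A scalar sketch of that selection, illustrative only and not the helper from these files:

// Element-wise stand-in for the masked_fill helper: wherever the mask is
// set, replace the score with the fill value (f32::NEG_INFINITY above).
fn masked_fill(scores: &[f32], mask: &[bool], value: f32) -> Vec<f32> {
    scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m { value } else { s })
        .collect()
}

fn main() {
    let scores = [0.3_f32, 1.2, -0.5];
    // Mask out the last key position only.
    let filled = masked_fill(&scores, &[false, false, true], f32::NEG_INFINITY);
    assert_eq!(filled, vec![0.3, 1.2, f32::NEG_INFINITY]);
}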
@@ -71,8 +71,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;