use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572)

Zack Angelo 2024-10-23 11:07:09 -07:00 committed by GitHub
parent 7c09215ef4
commit a2e9d41b20
1 changed file with 2 additions and 1 deletion

@@ -341,7 +341,8 @@ impl CausalSelfAttention {
                 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
                 masked_fill(&att, &mask, f32::NEG_INFINITY)?
             };
-            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
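
For reference, candle_nn::ops::softmax_last_dim computes the same result as candle_nn::ops::softmax(.., D::Minus1), but dispatches to the dedicated Metal and CUDA softmax kernels when the tensor lives on those devices. Below is a minimal standalone sketch, not part of this commit and assuming the candle-core and candle-nn crates, that compares the two calls on a small attention-shaped tensor:

// Standalone sketch (assumption: candle-core and candle-nn as dependencies);
// not part of this commit, for illustration only.
use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    // Attention-score-shaped tensor: (batch, heads, seq, seq).
    let device = Device::Cpu;
    let att = Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &device)?;

    // Generic softmax over an explicit dimension (the call being replaced).
    let a = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Last-dim softmax (the call introduced by this commit); on Metal/CUDA
    // devices it runs a dedicated kernel, on CPU it computes the same math.
    let b = candle_nn::ops::softmax_last_dim(&att)?;

    // The two should agree up to floating-point tolerance.
    let max_diff = (&a - &b)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
    println!("max abs diff between softmax and softmax_last_dim: {max_diff}");
    Ok(())
}

On a CUDA or Metal device the softmax_last_dim call replaces the composition of elementwise ops used by the generic softmax with a single fused kernel, which is the point of this change; the CPU comparison above only illustrates that the numerical result is unchanged.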