Properly set the is_bf16 flag. (#738)

Laurent Mazare 2023-09-04 17:45:26 +02:00 committed by GitHub
parent f80fd44201
commit ab0d9fbdd1
1 changed file with 10 additions and 6 deletions


@@ -26,6 +26,7 @@ impl FlashAttn {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
         let dev = q.device();
@@ -94,6 +95,7 @@ impl FlashAttn {
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -132,7 +134,7 @@ impl FlashAttn {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }
@@ -168,8 +170,8 @@ impl candle::CustomOp3 for FlashAttn {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
@@ -222,6 +224,7 @@ impl FlashAttnVarLen {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
         let dev = q.device();
@@ -321,6 +324,7 @@ impl FlashAttnVarLen {
             .w()?;
         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -361,7 +365,7 @@ impl FlashAttnVarLen {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }
@@ -397,8 +401,8 @@ impl candle::CustomOp3 for FlashAttnVarLen {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
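
For context, a minimal usage sketch (not part of the commit) that exercises the bf16 path this change wires up. It assumes the crate's public flash_attn(q, k, v, softmax_scale, causal) entry point, a CUDA device, and the core crate imported as candle, as in the diff above; with DType::BF16 inputs, the dispatch now calls cuda_fwd_t with is_bf16 = true, so the kernel is launched with is_bf16 = 1 instead of the previously hard-coded 0.

// Hypothetical usage sketch, not from this commit: run flash-attn on bf16
// tensors so the BF16 match arm passes is_bf16 = true down to the kernel.
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    // (batch, seq_len, num_heads, head_dim); generate in f32, then cast to bf16.
    let q = Tensor::randn(0f32, 1f32, (1, 16, 8, 64), &dev)?.to_dtype(DType::BF16)?;
    let k = q.clone();
    let v = q.clone();
    // softmax_scale = 1 / sqrt(head_dim) = 1 / sqrt(64) = 0.125.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, 0.125, /* causal */ true)?;
    println!("{:?}", out.dims());
    Ok(())
}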