Properly set the is_bf16 flag. (#738)
This commit is contained in:
parent
f80fd44201
commit
ab0d9fbdd1
|
@ -26,6 +26,7 @@ impl FlashAttn {
|
|||
k_l: &Layout,
|
||||
v: &candle::CudaStorage,
|
||||
v_l: &Layout,
|
||||
is_bf16: bool,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
// https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
|
||||
let dev = q.device();
|
||||
|
@ -94,6 +95,7 @@ impl FlashAttn {
|
|||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
||||
|
||||
let causal = if self.causal { 1 } else { 0 };
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
|
@ -132,7 +134,7 @@ impl FlashAttn {
|
|||
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
|
||||
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
|
||||
/* is_causal */ causal,
|
||||
/* is_bf16 */ 0,
|
||||
/* is_bf16 */ is_bf16,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -168,8 +170,8 @@ impl candle::CustomOp3 for FlashAttn {
|
|||
v_l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
match q.dtype() {
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
|
||||
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
|
||||
}
|
||||
}
|
||||
|
@ -222,6 +224,7 @@ impl FlashAttnVarLen {
|
|||
k_l: &Layout,
|
||||
v: &candle::CudaStorage,
|
||||
v_l: &Layout,
|
||||
is_bf16: bool,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
// https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
|
||||
let dev = q.device();
|
||||
|
@ -321,6 +324,7 @@ impl FlashAttnVarLen {
|
|||
.w()?;
|
||||
|
||||
let causal = if self.causal { 1 } else { 0 };
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
|
@ -361,7 +365,7 @@ impl FlashAttnVarLen {
|
|||
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
|
||||
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
|
||||
/* is_causal */ causal,
|
||||
/* is_bf16 */ 0,
|
||||
/* is_bf16 */ is_bf16,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -397,8 +401,8 @@ impl candle::CustomOp3 for FlashAttnVarLen {
|
|||
v_l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
match q.dtype() {
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
|
||||
candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
|
||||
candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
|
||||
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue