Add flash attention (#241)

* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
This commit is contained in:
Laurent Mazare 2023-07-26 07:48:10 +01:00 committed by GitHub
parent c97d51243c
commit d9f9c859af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 2699 additions and 9 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "candle-examples/examples/flash-attn/cutlass"]
path = candle-flash-attn/cutlass
url = https://github.com/NVIDIA/cutlass.git

View File

@ -9,6 +9,7 @@ members = [
"candle-wasm-examples/whisper",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]

View File

@ -65,7 +65,7 @@ pub use dtype::{DType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use op::CustomOp1;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};

View File

@ -14,6 +14,7 @@ readme = "README.md"
candle = { path = "../candle-core" }
candle-nn = { path = "../candle-nn" }
candle-transformers = { path = "../candle-transformers" }
candle-flash-attn = { path = "../candle-flash-attn", optional = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
@ -37,4 +38,5 @@ anyhow = { workspace = true }
[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

View File

@ -6,11 +6,13 @@ use std::path::PathBuf;
struct KernelDirectories {
kernel_dir: &'static str,
rust_target: &'static str,
include_dirs: &'static [&'static str],
}
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_dir: "examples/custom-ops/kernels/",
rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[],
}];
impl KernelDirectories {
@ -32,12 +34,15 @@ impl KernelDirectories {
{
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
let output = command
.spawn()
@ -221,6 +226,7 @@ fn compute_cap() -> Result<usize> {
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()

View File

@ -116,6 +116,9 @@ struct Args {
#[arg(long)]
v2: bool,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -124,7 +127,7 @@ fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let config = Config::config_7b();
let config = Config::config_7b(args.use_flash_attn);
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let (llama, tokenizer_filename) = match args.npy {

View File

@ -13,10 +13,11 @@ pub struct Config {
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
pub use_flash_attn: bool,
}
impl Config {
pub fn config_7b() -> Self {
pub fn config_7b(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
@ -25,6 +26,7 @@ impl Config {
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
}
}
}
@ -140,6 +142,17 @@ struct CausalSelfAttention {
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
}
#[cfg(feature = "flash-attn")]
fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
impl CausalSelfAttention {
@ -202,12 +215,17 @@ impl CausalSelfAttention {
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = if self.use_flash_attn {
flash_attn(&q, &k, &v)?
} else {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y.to_dtype(x_dtype)?;
let y = self.o_proj.forward(&y)?;
@ -245,6 +263,7 @@ impl CausalSelfAttention {
n_key_value_head: cfg.n_key_value_head,
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
})
}
}

View File

@ -0,0 +1,18 @@
[package]
name = "candle-flash-attn"
version = "0.1.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"] }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }

182
candle-flash-attn/build.rs Normal file
View File

@ -0,0 +1,182 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf;
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
println!("cargo:rerun-if-changed=kernels/flash.h");
println!("cargo:rerun-if-changed=kernels/philox.cuh");
println!("cargo:rerun-if-changed=kernels/softmax.h");
println!("cargo:rerun-if-changed=kernels/utils.h");
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
let mut out_dir = PathBuf::from(out_dir);
// TODO: Getting up two levels avoid having to recompile this too often, however it's likely
// not a safe assumption.
out_dir.pop();
out_dir.pop();
set_cuda_include_dir()?;
let compute_cap = compute_cap()?;
let mut command = std::process::Command::new("nvcc");
let out_file = out_dir.join("libflashattention.a");
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
let should_compile = if out_file.exists() {
let out_modified = out_file.metadata()?.modified()?;
let in_modified = cu_file.metadata()?.modified()?;
in_modified.duration_since(out_modified).is_ok()
} else {
true
};
if should_compile {
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", out_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
finishing to run for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

@ -0,0 +1 @@
Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933

View File

@ -0,0 +1,41 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -0,0 +1,141 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Random state.
// at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
// dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

View File

@ -0,0 +1,92 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // For dropout there might be a lot of register spilling?
// // These two are very slow due to register spilling
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
// // This one is slightly slower
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
void *v_ptr,
void *o_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t b,
uint32_t h,
uint32_t h_k,
uint32_t d,
uint32_t d_rounded,
float softmax_scale,
uint32_t seqlen_q,
uint32_t seqlen_k,
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_causal
) {
Flash_fwd_params params;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
// All stride are in elements, not bytes.
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_ptr = o_ptr;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd_<cutlass::half_t, 32>(params, stream);
}

View File

@ -0,0 +1,579 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
// TODO: Shouldn't this be size<1>?
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
float scores_max_cur = !Check_inf
? scores_max(mi)
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scores_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
Tensor scores_sum_cur = make_fragment_like(scores_sum);
flash::reduce_sum(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
// TODO(laurent): reactivate the following
// CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
//
// Copy Atom retiling
//
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
//
// PREDICATES
//
// // Allocate predicate tensors for m and n
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// print(tScQ.layout()); printf("\n");
// for (int i = 0; i < size(tScQ); ++i) {
// printf("%d ", get<0>(tScQ(i)));
// }
// printf("\n");
// for (int i = 0; i < size(tScQ); ++i) {
// printf("%d ", get<1>(tScQ(i)));
// }
// printf("\n");
// }
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}
// Prologue
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem
// // copy(tQrQ, tQsQ);
// flash::cp_async_wait<0>();
// __syncthreads();
// // if (cute::thread(1, 0)) { print(tQsQ); }
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
// // if (cute::thread0()) { print(sQNoSwizzle); }
if (Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<0>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads();
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<1>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
}
// auto seeds = at::cuda::philox::unpack(params.philox_args);
// unsigned long long seed = std::get<0>(seeds);
// unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
unsigned long long seed = 0;
unsigned long long offset = 0;
clear(acc_o);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread0()) { print(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
// static_assert(decltype(size<0>(taccScS))::value == 4);
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
// Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
// flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM);
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps);
}
// if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= 0) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
}
// Epilogue
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor lse = make_fragment_like(scores_sum);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
// if (cute::thread0()) { print(acc_o_rowcol); }
// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
copy(smem_thr_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0>(taccOcO))::value == 4);
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
// them to have the same number of threads or have to traverse the attention matrix
// in the same order.
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -0,0 +1,251 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
// #include <ATen/cuda/CUDAContext.h>
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
// if (smem_size >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// }
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 96;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// These two are always slower
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 128;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// 1st ones are good for H100, A100
// 2nd one is good for A6000 bc we get slightly better occupancy
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 160;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 64 with 8 warps is the fastest for non-causal.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 224;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
// If we have N = 32, there are only 1024 elements to load at once, where each load
// is 8 elements. This means we can only use 128 threads and not 256 threads.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// 96 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}

View File

@ -0,0 +1,366 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
static constexpr bool No_double_buffer = No_double_buffer_;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
static_assert(kNWarps % AtomLayoutMSdP == 0);
static_assert(kNWarps % AtomLayoutNdKV == 0);
static_assert(kNWarps % AtomLayoutMdQ == 0);
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQdO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQdO = decltype(tile_to_shape(
SmemLayoutAtomQdO{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
// SmemLayoutAtomQdO{},
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
static_assert(kBlockN >= 64);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
static constexpr int kPBlockN = 64;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
using SmemLayoutAtomPdS = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{},
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
Stride<Int<kPBlockN>, _1>>{}));
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposed = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{},
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
static constexpr int kSmemSize1colblock = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
+ kSmemdSSize + kSmemPSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
// to affect speed in practice.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,165 @@
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox CUDA.
namespace flash {
struct ull2 {
unsigned long long x;
unsigned long long y;
};
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
: "=l"(tmp)
: "r"(a), "r"(b));
res = (uint2*)(&tmp);
return *res;
}
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
inline __device__ uint4 philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9;
constexpr unsigned long kPhilox10B = 0xBB67AE85;
uint2 key = reinterpret_cast<uint2&>(seed);
uint4 counter;
ull2 *tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset;
tmp->y = subsequence;
#pragma unroll
for (int i = 0; i < 6; i++) {
counter = philox_single_round(counter, key);
key.x += (kPhilox10A);
key.y += (kPhilox10B);
}
uint4 output = philox_single_round(counter, key);
return output;
}
} // namespace flash
namespace {
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset)
: STATE(0)
, seed_(seed)
, offset_(offset)
, key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
//counter.z = (unsigned int)(subsequence);
//counter.w = (unsigned int)(subsequence >> 32);
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset / 4;
tmp->y = subsequence;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ inline uint4 operator()() {
// // if (STATE == 0) {
// uint4 counter_ = counter;
// uint2 key_ = key;
// // 7-round philox
// #pragma unroll
// for (int i = 0; i < 6; i++) {
// counter_ = flash::philox_single_round(counter_, key_);
// key_.x += (kPhilox10A);
// key_.y += (kPhilox10B);
// }
// // output = philox_single_round(counter_, key_);
// uint4 output = flash::philox_single_round(counter_, key_);
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// // }
// incr();
// // }
// // return a float4 directly
// // unsigned long ret;
// // switch(STATE) {
// // case 0: ret = output.x; break;
// // case 1: ret = output.y; break;
// // case 2: ret = output.z; break;
// // case 3: ret = output.w; break;
// //}
// // STATE = (STATE + 1) % 4;
// return output;
return flash::philox(seed_, offset_, offset_);
}
private:
unsigned long long offset_, seed_;
struct ull2 {
uint64_t x;
uint64_t y;
};
uint4 counter;
// uint4 output;
const uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ uint4 incr128 (uint4 ctr)
{
uint4 res;
asm ("add.cc.u32 %0, %4, %8;\n\t"
"addc.cc.u32 %1, %5, %9;\n\t"
"addc.cc.u32 %2, %6, %10;\n\t"
"addc.u32 %3, %7, %11;\n\t"
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
"n"(1), "n"(0), "n"(0), "n"(0));
return res;
}
__device__ inline void incr() {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
counter = incr128(counter);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
// static const unsigned long kPhiloxSA = 0xD2511F53;
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
} // namespace

View File

@ -0,0 +1,272 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "philox.cuh"
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
const uint32_t warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
const uint32_t row_idx_offset = row_idx_offset_;
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8;
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
// if (cute::thread0()) {
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
// print(tensor(make_coord(i, mi), _));
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
// }
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
tensor(mi, ni) = -INFINITY;
}
}
// if (cute::thread0()) {
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
// print(tensor(_, make_coord(j, ni)));
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
// }
}
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
uint32_t block_row_start, uint32_t block_col_start,
uint32_t block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}
} // namespace flash

View File

@ -0,0 +1,66 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 96) { \
constexpr static int kHeadDim = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 160) { \
constexpr static int kHeadDim = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 192) { \
constexpr static int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 224) { \
constexpr static int kHeadDim = 224; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()

View File

@ -0,0 +1,388 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t relu2(const uint32_t x);
template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x);
template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = flash::half2_unpack<T>(a);
float2 bf = flash::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = flash::hfma2_to_float<T>(a.x, b.x);
sum += flash::hfma2_to_float<T>(a.y, b.y);
sum += flash::hfma2_to_float<T>(a.z, b.z);
sum += flash::hfma2_to_float<T>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<0, 1>(l),
get<1, 1, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type;
// HACK: this requires tensor to be "contiguous"
Tensor tensor_uint32 = recast<uint32_t>(tensor);
#pragma unroll
for (int i = 0; i < size(tensor_uint32); ++i) {
tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>);
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// HACK: this requires tensor to be "contiguous"
Tensor tensor_float2 = recast<float2>(tensor);
Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
#pragma unroll
for (int i = 0; i < size(out_uint32); ++i) {
out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
}
Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
Tensor out = flash::convert_type<To_type>(tensor);
flash::relu_(out);
#endif
return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
clear(D(_, m, _));
}
}
// TD [2023-04-13]: Strange that the code below can cause race condition.
// I think it's because the copies are under an if statement.
// if (Is_even_K) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, _));
// }
// }
// } else { // It's slightly faster in this case if iterate over K first
// #pragma unroll
// for (int k = 0; k < size<2>(S); ++k) {
// if (predicate_K(k)) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, k));
// }
// }
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
// if (Clear_OOB_MN || Is_even_MN) {
// clear(D(_, _, k));
// } else {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
// clear(D(_, m, k));
// }
// }
// }
// }
// }
// }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};
extern "C" {
pub(crate) fn run_mha(
q_ptr: *const c_void,
k_ptr: *const c_void,
v_ptr: *const c_void,
o_ptr: *const c_void,
q_batch_stride: u32,
k_batch_stride: u32,
v_batch_stride: u32,
q_row_stride: u32,
k_row_stride: u32,
v_row_stride: u32,
q_head_stride: u32,
k_head_stride: u32,
v_head_stride: u32,
b: u32,
h: u32,
h_k: u32,
d: u32,
d_rounded: u32,
softmax_scale: f32,
seqlen_q: u32,
seqlen_k: u32,
seqlen_q_rounded: u32,
seqlen_k_rounded: u32,
is_causal: c_int,
);
}

View File

@ -0,0 +1,59 @@
mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape};
use half::f16;
pub struct FlashHdim32Sm80;
impl candle::CustomOp3 for FlashHdim32Sm80 {
fn name(&self) -> &'static str {
"flash-hdim32-sm80"
}
fn cpu_fwd(
&self,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
) -> Result<(CpuStorage, Shape)> {
Err(Error::Wrapped("no cpu support for flash-attn".into()))
}
fn cuda_fwd(
&self,
q: &candle::CudaStorage,
_q_l: &Layout,
k: &candle::CudaStorage,
_k_l: &Layout,
v: &candle::CudaStorage,
_v_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
let dev = q.device();
let out_shape = Shape::from(&[1]);
let q = q.as_cuda_slice::<f16>()?;
let k = k.as_cuda_slice::<f16>()?;
let v = v.as_cuda_slice::<f16>()?;
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
ffi::run_mha(
q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1,
1, 1, 1,
)
}
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
Ok((dst, out_shape))
}
}