Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files to a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate.
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
Parent: c97d51243c
Commit: d9f9c859af
@@ -0,0 +1,3 @@
+[submodule "candle-examples/examples/flash-attn/cutlass"]
+path = candle-flash-attn/cutlass
+url = https://github.com/NVIDIA/cutlass.git
@@ -9,6 +9,7 @@ members = [
 "candle-wasm-examples/whisper",
 ]
 exclude = [
+"candle-flash-attn",
 "candle-kernels",
 ]
 
@@ -65,7 +65,7 @@ pub use dtype::{DType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
-pub use op::CustomOp1;
+pub use op::{CustomOp1, CustomOp2, CustomOp3};
 pub use shape::{Shape, D};
 pub use storage::Storage;
 pub use strided_index::{StridedBlocks, StridedIndex};
@@ -14,6 +14,7 @@ readme = "README.md"
 candle = { path = "../candle-core" }
 candle-nn = { path = "../candle-nn" }
 candle-transformers = { path = "../candle-transformers" }
+candle-flash-attn = { path = "../candle-flash-attn", optional = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 num-traits = { workspace = true }
@@ -37,4 +38,5 @@ anyhow = { workspace = true }
 [features]
 default = []
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@@ -6,11 +6,13 @@ use std::path::PathBuf;
 struct KernelDirectories {
 kernel_dir: &'static str,
 rust_target: &'static str,
+include_dirs: &'static [&'static str],
 }
 
 const DIRS: [KernelDirectories; 1] = [KernelDirectories {
 kernel_dir: "examples/custom-ops/kernels/",
 rust_target: "examples/custom-ops/cuda_kernels.rs",
+include_dirs: &[],
 }];
 
 impl KernelDirectories {
@@ -32,12 +34,15 @@ impl KernelDirectories {
 {
 let mut command = std::process::Command::new("nvcc");
 let out_dir = ptx_file.parent().context("no parent for ptx file")?;
+let include_dirs: Vec<String> =
+self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
 command
 .arg(format!("--gpu-architecture=sm_{compute_cap}"))
 .arg("--ptx")
 .args(["--default-stream", "per-thread"])
 .args(["--output-directory", out_dir.to_str().unwrap()])
 .arg(format!("-I/{}", self.kernel_dir))
+.args(include_dirs)
 .arg(cu_file);
 let output = command
 .spawn()
@@ -221,6 +226,7 @@ fn compute_cap() -> Result<usize> {
 }
 
 println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
 if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
 compute_cap = compute_cap_str
 .parse::<usize>()
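The new include_dirs field lets each kernel directory forward extra -I paths to nvcc in addition to the kernel directory itself. Purely as a hypothetical illustration (this entry is not part of the commit), a directory that also needed the cutlass headers from the submodule added above could be described as:

    const FLASH_DIRS: KernelDirectories = KernelDirectories {
        kernel_dir: "examples/flash-attn/kernels/",         // hypothetical path
        rust_target: "examples/flash-attn/cuda_kernels.rs", // hypothetical path
        include_dirs: &["candle-flash-attn/cutlass/include"],
    };

In the commit itself the field stays empty for the custom-ops example; the flash-attn kernels are instead built by the dedicated candle-flash-attn build script further down.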
@@ -116,6 +116,9 @@ struct Args {
 
 #[arg(long)]
 v2: bool,
+
+#[arg(long)]
+use_flash_attn: bool,
 }
 
 fn main() -> Result<()> {
@@ -124,7 +127,7 @@ fn main() -> Result<()> {
 let args = Args::parse();
 
 let device = candle_examples::device(args.cpu)?;
-let config = Config::config_7b();
+let config = Config::config_7b(args.use_flash_attn);
 let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
 let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
 let (llama, tokenizer_filename) = match args.npy {
@@ -13,10 +13,11 @@ pub struct Config {
 pub n_head: usize,
 pub n_embd: usize,
 pub n_key_value_head: usize,
+pub use_flash_attn: bool,
 }
 
 impl Config {
-pub fn config_7b() -> Self {
+pub fn config_7b(use_flash_attn: bool) -> Self {
 Self {
 hidden_size: 4096,
 intermediate_size: 11008,
@@ -25,6 +26,7 @@ impl Config {
 n_head: 32,
 n_embd: 4096,
 n_key_value_head: 32,
+use_flash_attn,
 }
 }
 }
@@ -140,6 +142,17 @@ struct CausalSelfAttention {
 n_key_value_head: usize,
 head_dim: usize,
 cache: Cache,
+use_flash_attn: bool,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
+unimplemented!("compile with '--features flash-attn'")
 }
 
 impl CausalSelfAttention {
@@ -202,12 +215,17 @@ impl CausalSelfAttention {
 
 let k = self.repeat_kv(k)?;
 let v = self.repeat_kv(v)?;
-let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-let att = att.softmax(D::Minus1)?;
-// Convert to contiguous as matmul doesn't support strided vs for now.
-let y = att.matmul(&v.contiguous()?)?;
+
+let y = if self.use_flash_attn {
+flash_attn(&q, &k, &v)?
+} else {
+let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+let att = att.softmax(D::Minus1)?;
+// Convert to contiguous as matmul doesn't support strided vs for now.
+att.matmul(&v.contiguous()?)?
+};
 let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
 let y = y.to_dtype(x_dtype)?;
 let y = self.o_proj.forward(&y)?;
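Both branches of this forward pass compute the same quantity, scaled dot-product attention over the repeated key/value heads:

    y = softmax(Q K^T / sqrt(head_dim)) V

The else branch materializes the full score matrix, applies the causal mask and the softmax explicitly, and then multiplies by V. The flash-attn branch hands Q, K and V to the fused kernel through custom_op3 and is expected to produce the same result without materializing the score matrix; the scaling and causal masking are handled inside the kernel (see the scale_softmax and is_causal fields set by the C wrapper below).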
@@ -245,6 +263,7 @@ impl CausalSelfAttention {
 n_key_value_head: cfg.n_key_value_head,
 head_dim: cfg.hidden_size / cfg.n_head,
 cache: cache.clone(),
+use_flash_attn: cfg.use_flash_attn,
 })
 }
 }
@@ -0,0 +1,18 @@
+[package]
+name = "candle-flash-attn"
+version = "0.1.0"
+edition = "2021"
+
+description = "Flash attention layer for the candle ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../candle-core", features = ["cuda"] }
+half = { version = "2.3.1", features = ["num-traits"] }
+
+[build-dependencies]
+anyhow = { version = "1", features = ["backtrace"] }
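candle is pulled in with its cuda feature since the kernel only runs on CUDA devices, and the half crate is presumably there for the f16 data this first kernel operates on (the num-traits feature enables the numeric trait implementations on half::f16). A minimal illustration of the conversions that crate provides, not tied to any specific code in this commit:

    // Round-trip between f32 and IEEE half precision using the `half` crate.
    fn half_roundtrip(x: f32) -> f32 {
        half::f16::from_f32(x).to_f32()
    }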
@@ -0,0 +1,182 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::io::Write;
+use std::path::PathBuf;
+
+fn main() -> Result<()> {
+println!("cargo:rerun-if-changed=build.rs");
+println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
+println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
+println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
+println!("cargo:rerun-if-changed=kernels/flash.h");
+println!("cargo:rerun-if-changed=kernels/philox.cuh");
+println!("cargo:rerun-if-changed=kernels/softmax.h");
+println!("cargo:rerun-if-changed=kernels/utils.h");
+println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
+println!("cargo:rerun-if-changed=kernels/block_info.h");
+println!("cargo:rerun-if-changed=kernels/static_switch.h");
+
+let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
+let mut out_dir = PathBuf::from(out_dir);
+// TODO: Getting up two levels avoid having to recompile this too often, however it's likely
+// not a safe assumption.
+out_dir.pop();
+out_dir.pop();
+set_cuda_include_dir()?;
+let compute_cap = compute_cap()?;
+
+let mut command = std::process::Command::new("nvcc");
+let out_file = out_dir.join("libflashattention.a");
+
+let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
+let should_compile = if out_file.exists() {
+let out_modified = out_file.metadata()?.modified()?;
+let in_modified = cu_file.metadata()?.modified()?;
+in_modified.duration_since(out_modified).is_ok()
+} else {
+true
+};
+if should_compile {
+command
+.arg(format!("--gpu-architecture=sm_{compute_cap}"))
+.arg("--lib")
+.args(["-o", out_file.to_str().unwrap()])
+.args(["--default-stream", "per-thread"])
+.arg("-Icutlass/include")
+.arg("--expt-relaxed-constexpr")
+.arg(cu_file);
+let output = command
+.spawn()
+.context("failed spawning nvcc")?
+.wait_with_output()?;
+if !output.status.success() {
+anyhow::bail!(
+"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+String::from_utf8_lossy(&output.stdout),
+String::from_utf8_lossy(&output.stderr)
+)
+}
+}
+println!("cargo:rustc-link-search={}", out_dir.display());
+println!("cargo:rustc-link-lib=flashattention");
+println!("cargo:rustc-link-lib=dylib=cudart");
+println!("cargo:rustc-link-lib=dylib=stdc++");
+
+/* laurent: I tried using the cc cuda integration as below but this lead to ptxas never
+finishing to run for some reason. Calling nvcc manually worked fine.
+cc::Build::new()
+.cuda(true)
+.include("cutlass/include")
+.flag("--expt-relaxed-constexpr")
+.flag("--default-stream")
+.flag("per-thread")
+.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
+.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
+.compile("flashattn");
+*/
+Ok(())
+}
+
+fn set_cuda_include_dir() -> Result<()> {
+// NOTE: copied from cudarc build.rs.
+let env_vars = [
+"CUDA_PATH",
+"CUDA_ROOT",
+"CUDA_TOOLKIT_ROOT_DIR",
+"CUDNN_LIB",
+];
+let env_vars = env_vars
+.into_iter()
+.map(std::env::var)
+.filter_map(Result::ok)
+.map(Into::<PathBuf>::into);
+
+let roots = [
+"/usr",
+"/usr/local/cuda",
+"/opt/cuda",
+"/usr/lib/cuda",
+"C:/Program Files/NVIDIA GPU Computing Toolkit",
+"C:/CUDA",
+];
+let roots = roots.into_iter().map(Into::<PathBuf>::into);
+let root = env_vars
+.chain(roots)
+.find(|path| path.join("include").join("cuda.h").is_file())
+.context("cannot find include/cuda.h")?;
+println!(
+"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+root.join("include").display()
+);
+Ok(())
+}
+
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+// Grab compute code from nvidia-smi
+let mut compute_cap = {
+let out = std::process::Command::new("nvidia-smi")
+.arg("--query-gpu=compute_cap")
+.arg("--format=csv")
+.output()
+.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+let mut lines = out.lines();
+assert_eq!(
+lines.next().context("missing line in stdout")?,
+"compute_cap"
+);
+let cap = lines
+.next()
+.context("missing line in stdout")?
+.replace('.', "");
+cap.parse::<usize>()
+.with_context(|| format!("cannot parse as int {cap}"))?
+};
+
+// Grab available GPU codes from nvcc and select the highest one
+let max_nvcc_code = {
+let out = std::process::Command::new("nvcc")
+.arg("--list-gpu-code")
+.output()
+.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+let out = std::str::from_utf8(&out.stdout).unwrap();
+
+let out = out.lines().collect::<Vec<&str>>();
+let mut codes = Vec::with_capacity(out.len());
+for code in out {
+let code = code.split('_').collect::<Vec<&str>>();
+if !code.is_empty() && code.contains(&"sm") {
+if let Ok(num) = code[1].parse::<usize>() {
+codes.push(num);
+}
+}
+}
+codes.sort();
+if !codes.contains(&compute_cap) {
+anyhow::bail!(
+"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+);
+}
+*codes.last().unwrap()
+};
+
+// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+// then choose the highest gpu code in nvcc
+if compute_cap > max_nvcc_code {
+println!(
+"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+);
+compute_cap = max_nvcc_code;
+}
+
+println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+compute_cap = compute_cap_str
+.parse::<usize>()
+.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+}
+println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+Ok(compute_cap)
+}
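A worked example of the detection logic above: if nvidia-smi reports a compute capability of 8.6, the dot is stripped and the value parsed as the integer 86; that value must also appear among the sm_XX codes printed by nvcc --list-gpu-code, otherwise the build fails early with an error. Setting the CUDA_COMPUTE_CAP environment variable (for example to 75) overrides whatever was detected, and the final choice is re-exported to the crate as a rustc env variable in sm_XX form.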
@@ -0,0 +1 @@
+Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933
@@ -0,0 +1,41 @@
+/******************************************************************************
+* Copyright (c) 2023, Tri Dao.
+******************************************************************************/
+
+#pragma once
+
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<bool Varlen=true>
+struct BlockInfo {
+
+template<typename Params>
+__device__ BlockInfo(const Params &params, const int bidb)
+: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
+, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+{
+}
+
+template <typename index_t>
+inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
+}
+
+template <typename index_t>
+inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+}
+
+const int sum_s_q;
+const int sum_s_k;
+const uint32_t actual_seqlen_q;
+const uint32_t actual_seqlen_k;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
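A quick reading of the offset helpers above. In the variable-length case (Varlen with a non-null cu_seqlens_q), sum_s_q is the cumulative token count before batch entry bidb: with cu_seqlens_q = [0, 3, 7] and bidb = 1, sum_s_q = 3, actual_seqlen_q = 7 - 3 = 4, and q_offset(...) returns 3 * row_stride, i.e. the row where that sequence starts in the packed buffer. In the fixed-length case sum_s_q stays -1 and the offset falls back to bidb * batch_stride. The k_offset helper follows the same pattern for the key/value side.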
@@ -0,0 +1,141 @@
+/******************************************************************************
+* Copyright (c) 2023, Tri Dao.
+******************************************************************************/
+
+#pragma once
+
+#include <cuda.h>
+#include <vector>
+
+// #ifdef OLD_GENERATOR_PATH
+// #include <ATen/CUDAGeneratorImpl.h>
+// #else
+// #include <ATen/cuda/CUDAGeneratorImpl.h>
+// #endif
+//
+// #include <ATen/cuda/CUDAGraphsUtils.cuh>
+
+
+constexpr int TOTAL_DIM = 0;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Qkv_params {
+using index_t = uint32_t;
+// The QKV matrices.
+void *__restrict__ q_ptr;
+void *__restrict__ k_ptr;
+void *__restrict__ v_ptr;
+
+// The stride between rows of the Q, K and V matrices.
+index_t q_batch_stride;
+index_t k_batch_stride;
+index_t v_batch_stride;
+index_t q_row_stride;
+index_t k_row_stride;
+index_t v_row_stride;
+index_t q_head_stride;
+index_t k_head_stride;
+index_t v_head_stride;
+
+// The number of heads.
+int h, h_k;
+// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
+// different from nheads (query).
+int h_h_k_ratio; // precompute h / h_k,
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_fwd_params : public Qkv_params {
+
+// The O matrix (output).
+void * __restrict__ o_ptr;
+
+// The stride between rows of O.
+index_t o_batch_stride;
+index_t o_row_stride;
+index_t o_head_stride;
+
+// The pointer to the P matrix.
+void * __restrict__ p_ptr;
+
+// The pointer to the softmax sum.
+void * __restrict__ softmax_lse_ptr;
+
+// The dimensions.
+int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+
+// The scaling factors for the kernel.
+float scale_softmax;
+float scale_softmax_log2;
+
+// array of length b+1 holding starting offset of each sequence.
+int * __restrict__ cu_seqlens_q;
+int * __restrict__ cu_seqlens_k;
+
+int *__restrict__ blockmask;
+
+// The dropout probability (probability of keeping an activation).
+float p_dropout;
+// uint32_t p_dropout_in_uint;
+// uint16_t p_dropout_in_uint16_t;
+uint8_t p_dropout_in_uint8_t;
+
+// Scale factor of 1 / (1 - p_dropout).
+float rp_dropout;
+float scale_softmax_rp_dropout;
+
+// Random state.
+// at::PhiloxCudaState philox_args;
+
+bool is_bf16;
+bool is_causal;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_bwd_params : public Flash_fwd_params {
+
+// The dO and dQKV matrices.
+void *__restrict__ do_ptr;
+void *__restrict__ dq_ptr;
+void *__restrict__ dk_ptr;
+void *__restrict__ dv_ptr;
+
+// To accumulate dQ
+void *__restrict__ dq_accum_ptr;
+void *__restrict__ dk_accum_ptr;
+void *__restrict__ dv_accum_ptr;
+
+// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
+// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
+// dv_accum_ptr;
+
+// The stride between rows of the dO, dQ, dK and dV matrices.
+// TD [2022-04-16]: We're using 32-bit indexing to save registers.
+// The code probably won't work for arrays larger than 2GB.
+index_t do_batch_stride;
+index_t do_row_stride;
+index_t do_head_stride;
+index_t dq_batch_stride;
+index_t dk_batch_stride;
+index_t dv_batch_stride;
+index_t dq_row_stride;
+index_t dk_row_stride;
+index_t dv_row_stride;
+index_t dq_head_stride;
+index_t dk_head_stride;
+index_t dv_head_stride;
+
+// The pointer to the softmax d sum.
+void *__restrict__ dsoftmax_sum;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
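All strides in these structs are counted in elements, not bytes. As a sketch of how they relate to a contiguous layout — assuming, and this is only an assumption about the caller, that Q is stored as a contiguous (batch, seqlen, num_heads, head_dim) tensor:

    // Hypothetical helper: element strides for a contiguous (b, seqlen, h, d) tensor.
    fn qkv_strides(seqlen: usize, num_heads: usize, head_dim: usize) -> (usize, usize, usize) {
        let head_stride = head_dim;                       // step to the next head within a token
        let row_stride = num_heads * head_dim;            // step to the next token (row)
        let batch_stride = seqlen * num_heads * head_dim; // step to the next sequence in the batch
        (batch_stride, row_stride, head_stride)
    }

The same reasoning would apply to K/V (with seqlen_k and h_k) and to the O matrix in Flash_fwd_params.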
@@ -0,0 +1,92 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// // For dropout there might be a lot of register spilling?
+// // These two are very slow due to register spilling
+// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
+// // This one is slightly slower
+// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
+run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+}
+
+
+extern "C" void run_mha(
+void *q_ptr,
+void *k_ptr,
+void *v_ptr,
+void *o_ptr,
+
+uint32_t q_batch_stride,
+uint32_t k_batch_stride,
+uint32_t v_batch_stride,
+uint32_t q_row_stride,
+uint32_t k_row_stride,
+uint32_t v_row_stride,
+uint32_t q_head_stride,
+uint32_t k_head_stride,
+uint32_t v_head_stride,
+
+uint32_t b,
+uint32_t h,
+uint32_t h_k,
+uint32_t d,
+uint32_t d_rounded,
+float softmax_scale,
+
+uint32_t seqlen_q,
+uint32_t seqlen_k,
+uint32_t seqlen_q_rounded,
+uint32_t seqlen_k_rounded,
+
+int is_causal
+) {
+Flash_fwd_params params;
+// Reset the parameters
+memset(&params, 0, sizeof(params));
+
+// Set the pointers and strides.
+params.q_ptr = q_ptr;
+params.k_ptr = k_ptr;
+params.v_ptr = v_ptr;
+// All stride are in elements, not bytes.
+params.q_row_stride = q_row_stride;
+params.k_row_stride = k_row_stride;
+params.v_row_stride = v_row_stride;
+params.q_head_stride = q_head_stride;
+params.k_head_stride = k_head_stride;
+params.v_head_stride = v_head_stride;
+params.o_ptr = o_ptr;
+
+// Set the dimensions.
+params.b = b;
+params.h = h;
+params.h_k = h_k;
+params.h_h_k_ratio = h / h_k;
+params.seqlen_q = seqlen_q;
+params.seqlen_k = seqlen_k;
+params.seqlen_q_rounded = seqlen_q_rounded;
+params.seqlen_k_rounded = seqlen_k_rounded;
+params.d = d;
+params.d_rounded = d_rounded;
+params.is_causal = is_causal;
+
+// Set the different scale values.
+params.scale_softmax = softmax_scale;
+params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+cudaStream_t stream = 0; // Use the default stream.
+run_mha_fwd_<cutlass::half_t, 32>(params, stream);
+}
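The Rust side of candle-flash-attn is not included in this excerpt, but it is expected to link against libflashattention.a (see the build script above) and declare this entry point over FFI. A minimal sketch mirroring the C signature shown here — the exact binding used by the crate may differ:

    use std::ffi::c_void;

    extern "C" {
        // Hypothetical declaration of the `run_mha` entry point from libflashattention.a.
        // Pointers are device pointers; all strides are element counts, not bytes.
        fn run_mha(
            q_ptr: *const c_void,
            k_ptr: *const c_void,
            v_ptr: *const c_void,
            o_ptr: *mut c_void,
            q_batch_stride: u32,
            k_batch_stride: u32,
            v_batch_stride: u32,
            q_row_stride: u32,
            k_row_stride: u32,
            v_row_stride: u32,
            q_head_stride: u32,
            k_head_stride: u32,
            v_head_stride: u32,
            b: u32,
            h: u32,
            h_k: u32,
            d: u32,
            d_rounded: u32,
            softmax_scale: f32,
            seqlen_q: u32,
            seqlen_k: u32,
            seqlen_q_rounded: u32,
            seqlen_k_rounded: u32,
            is_causal: i32,
        );
    }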
@ -0,0 +1,579 @@
|
||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2023, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cute/algorithm/copy.hpp>
|
||||||
|
#include <cute/algorithm/gemm.hpp>
|
||||||
|
|
||||||
|
#include <cutlass/cutlass.h>
|
||||||
|
#include <cutlass/array.h>
|
||||||
|
#include <cutlass/numeric_types.h>
|
||||||
|
#include <cutlass/numeric_conversion.h>
|
||||||
|
|
||||||
|
#include "block_info.h"
|
||||||
|
#include "kernel_traits.h"
|
||||||
|
#include "utils.h"
|
||||||
|
#include "softmax.h"
|
||||||
|
#include "philox.cuh"
|
||||||
|
|
||||||
|
namespace flash {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <int MMA_M,
|
||||||
|
class... Args,
|
||||||
|
class TiledMMA>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
auto
|
||||||
|
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||||
|
TiledMMA const& tiled_mma) {
|
||||||
|
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||||
|
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||||
|
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||||
|
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||||
|
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||||
|
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||||
|
Stride<_1, Int<MMAStride_M>> >{},
|
||||||
|
make_layout(size<2>(TileShape_MNK{})));
|
||||||
|
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
|
||||||
|
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <int MMA_M,
|
||||||
|
class... Args,
|
||||||
|
class TiledMMA>
|
||||||
|
CUTE_HOST_DEVICE
|
||||||
|
auto
|
||||||
|
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||||
|
TiledMMA const& tiled_mma) {
|
||||||
|
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||||
|
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||||
|
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||||
|
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||||
|
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||||
|
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||||
|
Stride<_1, Int<MMAStride_M>> >{},
|
||||||
|
// TODO: Shouldn't this be size<1>?
|
||||||
|
make_layout(size<2>(TileShape_MNK{})));
|
||||||
|
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
|
||||||
|
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||||
|
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||||
|
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||||
|
if (Is_first) {
|
||||||
|
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
||||||
|
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||||
|
flash::reduce_sum(scores, scores_sum);
|
||||||
|
} else {
|
||||||
|
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||||
|
copy(scores_max, scores_max_prev);
|
||||||
|
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||||
|
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||||
|
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size(scores_max); ++mi) {
|
||||||
|
float scores_max_cur = !Check_inf
|
||||||
|
? scores_max(mi)
|
||||||
|
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
||||||
|
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||||
|
scores_sum(mi) *= scores_scale;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||||
|
}
|
||||||
|
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||||
|
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
||||||
|
flash::reduce_sum(scores, scores_sum_cur);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
|
||||||
|
inline __device__ void write_softmax_to_gmem(
|
||||||
|
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
|
||||||
|
) {
|
||||||
|
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
|
||||||
|
Layout l = tOrP.layout();
|
||||||
|
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
|
||||||
|
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
|
||||||
|
// TODO(laurent): reactivate the following
|
||||||
|
// CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
|
||||||
|
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||||
|
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||||
|
|
||||||
|
using Element = typename Kernel_traits::Element;
|
||||||
|
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||||
|
using index_t = typename Kernel_traits::index_t;
|
||||||
|
|
||||||
|
// Shared memory.
|
||||||
|
extern __shared__ char smem_[];
|
||||||
|
|
||||||
|
// The thread index.
|
||||||
|
const int tidx = threadIdx.x;
|
||||||
|
|
||||||
|
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||||
|
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||||
|
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||||
|
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||||
|
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||||
|
|
||||||
|
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
|
||||||
|
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
||||||
|
|
||||||
|
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
||||||
|
if (Is_causal) {
|
||||||
|
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
|
||||||
|
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
||||||
|
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||||
|
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
|
||||||
|
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
|
||||||
|
|
||||||
|
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||||
|
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||||
|
// We move K and V to the last block.
|
||||||
|
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||||
|
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||||
|
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||||
|
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||||
|
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
|
||||||
|
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
|
||||||
|
|
||||||
|
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||||
|
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||||
|
make_stride(params.q_row_stride, _1{}));
|
||||||
|
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||||
|
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||||
|
make_stride(params.k_row_stride, _1{}));
|
||||||
|
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||||
|
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||||
|
make_stride(params.v_row_stride, _1{}));
|
||||||
|
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
|
||||||
|
Shape<Int<kBlockM>, Int<kBlockN>>{},
|
||||||
|
make_stride(params.seqlen_k_rounded, _1{}));
|
||||||
|
|
||||||
|
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||||
|
typename Kernel_traits::SmemLayoutQ{});
|
||||||
|
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
|
||||||
|
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
|
||||||
|
typename Kernel_traits::SmemLayoutKV{});
|
||||||
|
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||||
|
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||||
|
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||||
|
|
||||||
|
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
|
||||||
|
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
|
||||||
|
|
||||||
|
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||||
|
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||||
|
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||||
|
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||||
|
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||||
|
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||||
|
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
|
||||||
|
|
||||||
|
typename Kernel_traits::TiledMma tiled_mma;
|
||||||
|
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||||
|
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||||
|
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||||
|
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
||||||
|
|
||||||
|
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
||||||
|
|
||||||
|
//
|
||||||
|
// Copy Atom retiling
|
||||||
|
//
|
||||||
|
|
||||||
|
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
|
||||||
|
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||||
|
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
|
||||||
|
|
||||||
|
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||||
|
|
||||||
|
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||||
|
|
||||||
|
// TODO: this might need to change if we change the mma instruction in SM70
|
||||||
|
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||||
|
Tensor scores_sum = make_fragment_like(scores_max);
|
||||||
|
|
||||||
|
//
|
||||||
|
// PREDICATES
|
||||||
|
//
|
||||||
|
|
||||||
|
// // Allocate predicate tensors for m and n
|
||||||
|
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
|
||||||
|
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
|
||||||
|
|
||||||
|
// Construct identity layout for sQ and sK
|
||||||
|
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||||
|
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||||
|
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
|
||||||
|
// if (cute::thread0()) {
|
||||||
|
// print(tScQ.layout()); printf("\n");
|
||||||
|
// for (int i = 0; i < size(tScQ); ++i) {
|
||||||
|
// printf("%d ", get<0>(tScQ(i)));
|
||||||
|
// }
|
||||||
|
// printf("\n");
|
||||||
|
// for (int i = 0; i < size(tScQ); ++i) {
|
||||||
|
// printf("%d ", get<1>(tScQ(i)));
|
||||||
|
// }
|
||||||
|
// printf("\n");
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Repeat the partitioning with identity layouts
|
||||||
|
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||||
|
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||||
|
|
||||||
|
// Allocate predicate tensors for k
|
||||||
|
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||||
|
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||||
|
|
||||||
|
// Set predicates for k bounds
|
||||||
|
if (!Is_even_K) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prologue
|
||||||
|
|
||||||
|
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||||
|
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||||
|
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||||
|
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||||
|
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||||
|
|
||||||
|
// // Copy rmem to smem
|
||||||
|
// // copy(tQrQ, tQsQ);
|
||||||
|
// flash::cp_async_wait<0>();
|
||||||
|
// __syncthreads();
|
||||||
|
// // if (cute::thread(1, 0)) { print(tQsQ); }
|
||||||
|
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
||||||
|
// // if (cute::thread0()) { print(sQNoSwizzle); }
|
||||||
|
|
||||||
|
if (Kernel_traits::Share_Q_K_smem) {
|
||||||
|
flash::cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||||
|
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||||
|
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
int n_block = n_block_max - 1;
|
||||||
|
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||||
|
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||||
|
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||||
|
cute::cp_async_fence();
|
||||||
|
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||||
|
// __syncthreads();
|
||||||
|
|
||||||
|
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
|
||||||
|
flash::cp_async_wait<1>();
|
||||||
|
__syncthreads();
|
||||||
|
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||||
|
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||||
|
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||||
|
}
|
||||||
|
|
||||||
|
// auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||||
|
// unsigned long long seed = std::get<0>(seeds);
|
||||||
|
// unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||||
|
unsigned long long seed = 0;
|
||||||
|
unsigned long long offset = 0;
|
||||||
|
|
||||||
|
clear(acc_o);
|
||||||
|
|
||||||
|
// For performance reason, we separate out two kinds of iterations:
|
||||||
|
// those that need masking on S, and those that don't.
|
||||||
|
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||||
|
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||||
|
// We will have at least 1 "masking" iteration.
|
||||||
|
|
||||||
|
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||||
|
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||||
|
clear(acc_s);
|
||||||
|
flash::cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Advance gV
|
||||||
|
if (masking_step > 0) {
|
||||||
|
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||||
|
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||||
|
} else {
|
||||||
|
// Clear the smem tiles to account for predicated off loads
|
||||||
|
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||||
|
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||||
|
);
|
||||||
|
}
|
||||||
|
cute::cp_async_fence();
|
||||||
|
|
||||||
|
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||||
|
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||||
|
);
|
||||||
|
// if (cute::thread0()) { print(acc_s); }
|
||||||
|
|
||||||
|
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||||
|
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||||
|
// if (cute::thread0()) { print(scores); }
|
||||||
|
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||||
|
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||||
|
// can produce Inf / NaN.
|
||||||
|
if (!Is_causal) {
|
||||||
|
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||||
|
} else {
|
||||||
|
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||||
|
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||||
|
// static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||||
|
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
|
||||||
|
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||||
|
// Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
|
||||||
|
// flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||||
|
// m_block * kBlockM);
|
||||||
|
// Idk why it's get<1> and not get<0> of the stride.
|
||||||
|
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||||
|
// I can't get the stride from idx_row
|
||||||
|
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||||
|
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||||
|
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||||
|
kNWarps * 16);
|
||||||
|
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
|
||||||
|
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
flash::cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
if (n_block > 0) {
|
||||||
|
// Advance gK
|
||||||
|
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||||
|
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||||
|
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||||
|
// isn't right and we get race conditions.
|
||||||
|
cute::cp_async_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||||
|
masking_step == 0
|
||||||
|
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||||
|
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||||
|
|
||||||
|
// Convert scores from fp32 to fp16/bf16
|
||||||
|
Tensor rP = flash::convert_type<Element>(scores);
|
||||||
|
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||||
|
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||||
|
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||||
|
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||||
|
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||||
|
if (Return_softmax) {
|
||||||
|
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||||
|
copy(tOrP, tOrP_copy);
|
||||||
|
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||||
|
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||||
|
block_row_idx, block_col_idx, kNWarps
|
||||||
|
);
|
||||||
|
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||||
|
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||||
|
}
|
||||||
|
if (Is_dropout) {
|
||||||
|
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||||
|
block_row_idx, block_col_idx, kNWarps);
|
||||||
|
}
|
||||||
|
// if (cute::thread0()) { print(tOrP); }
|
||||||
|
|
||||||
|
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||||
|
// if (cute::thread0()) { print(scores); }
|
||||||
|
|
||||||
|
// This check is at the end of the loop since we always have at least 1 iteration
|
||||||
|
if (n_masking_steps > 1 && n_block <= 0) {
|
||||||
|
--n_block;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// These are the iterations where we don't need masking on S
|
||||||
|
for (; n_block >= 0; --n_block) {
|
||||||
|
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||||
|
clear(acc_s);
|
||||||
|
flash::cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
// Advance gV
|
||||||
|
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||||
|
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||||
|
cute::cp_async_fence();
|
||||||
|
|
||||||
|
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||||
|
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||||
|
);
|
||||||
|
|
||||||
|
flash::cp_async_wait<0>();
|
||||||
|
__syncthreads();
|
||||||
|
if (n_block > 0) {
|
||||||
|
// Advance gK
|
||||||
|
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||||
|
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||||
|
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||||
|
// isn't right and we get race conditions.
|
||||||
|
cute::cp_async_fence();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||||
|
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||||
|
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||||
|
|
||||||
|
Tensor rP = flash::convert_type<Element>(scores);
|
||||||
|
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||||
|
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||||
|
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||||
|
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||||
|
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||||
|
if (Return_softmax) {
|
||||||
|
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||||
|
copy(tOrP, tOrP_copy);
|
||||||
|
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||||
|
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||||
|
block_row_idx, block_col_idx, kNWarps
|
||||||
|
);
|
||||||
|
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||||
|
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||||
|
}
|
||||||
|
if (Is_dropout) {
|
||||||
|
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||||
|
block_row_idx, block_col_idx, kNWarps);
|
||||||
|
}
|
||||||
|
|
||||||
|
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Epilogue
|
||||||
|
|
||||||
|
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||||
|
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||||
|
Tensor lse = make_fragment_like(scores_sum);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||||
|
float sum = scores_sum(mi);
|
||||||
|
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||||
|
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||||
|
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||||
|
|
||||||
|
// Convert acc_o from fp32 to fp16/bf16
|
||||||
|
Tensor rO = flash::convert_type<Element>(acc_o);
|
||||||
|
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||||
|
// Partition sO to match the accumulator partitioning
|
||||||
|
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||||
|
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||||
|
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||||
|
|
||||||
|
// sO has the same size as sQ, so we don't need to sync here.
|
||||||
|
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
|
||||||
|
|
||||||
|
copy(smem_thr_copy_O, taccOrO, taccOsO);
|
||||||
|
|
||||||
|
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||||
|
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||||
|
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||||
|
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||||
|
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||||
|
make_stride(params.o_row_stride, _1{}));
|
||||||
|
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||||
|
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||||
|
|
||||||
|
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
|
||||||
|
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||||
|
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||||
|
copy(gmem_thr_copy_O, tOsO, tOrO);
|
||||||
|
|
||||||
|
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||||
|
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
||||||
|
static_assert(decltype(size<0>(taccOcO))::value == 4);
|
||||||
|
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
|
||||||
|
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||||
|
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||||
|
if (get<1>(taccOcO_row(0)) == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size(lse); ++mi) {
|
||||||
|
const int row = get<0>(taccOcO_row(mi));
|
||||||
|
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct identity layout for sO
|
||||||
|
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||||
|
// Repeat the partitioning with identity layouts
|
||||||
|
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||||
|
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
||||||
|
if (!Is_even_K) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||||
|
}
|
||||||
|
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||||
|
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||||
|
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||||
|
inline __device__ void compute_attn(const Params &params) {
|
||||||
|
const int m_block = blockIdx.x;
|
||||||
|
// The block index for the batch.
|
||||||
|
const int bidb = blockIdx.y;
|
||||||
|
// The block index for the head.
|
||||||
|
const int bidh = blockIdx.z;
|
||||||
|
|
||||||
|
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
|
||||||
|
// them to have the same number of threads or have to traverse the attention matrix
|
||||||
|
// in the same order.
|
||||||
|
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
|
||||||
|
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
|
||||||
|
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||||
|
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
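// Concretely, the subsequence handed to the Philox call is the packed (block_row, block_col)
// pair of the 16 x 32 tile (see apply_dropout later in this diff), while the offset folds in
// something like (bidb * num_heads + bidh) * 32 + lane; the exact packing lives in
// compute_attn_1rowblock and is only sketched here.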
|
||||||
|
|
||||||
|
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace flash
|
|
@ -0,0 +1,251 @@
|
||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2023, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// #include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "static_switch.h"
|
||||||
|
#include "flash.h"
|
||||||
|
#include "flash_fwd_kernel.h"
|
||||||
|
|
||||||
|
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||||
|
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||||
|
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||||
|
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||||
|
// printf("smem_size = %d\n", smem_size);
|
||||||
|
|
||||||
|
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||||
|
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||||
|
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||||
|
|
||||||
|
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||||
|
dim3 grid(num_m_block, params.b, params.h);
|
||||||
|
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||||
|
// for cu_seqlens_q as well.
|
||||||
|
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||||
|
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||||
|
const bool return_softmax = params.p_ptr != nullptr;
|
||||||
|
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||||
|
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||||
|
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||||
|
// Will only return softmax if dropout, to reduce compilation time.
|
||||||
|
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||||
|
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||||
|
// if (smem_size >= 48 * 1024) {
|
||||||
|
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||||
|
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
|
// }
|
||||||
|
int ctas_per_sm;
|
||||||
|
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
|
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||||
|
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||||
|
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||||
|
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
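The BOOL_SWITCH helper used above comes from static_switch.h in the imported flash-attn sources; it turns a runtime boolean into a compile-time template parameter by compiling the body twice. A rough stand-in (hypothetical name, shown only to illustrate the dispatch pattern, not something this commit adds) would be:

// Rough equivalent of BOOL_SWITCH(cond, CONST_NAME, body): the body is instantiated once
// with CONST_NAME = true and once with CONST_NAME = false, and a plain runtime branch
// picks between the two instantiations.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)       \
    [&] {                                               \
        if (COND) {                                     \
            static constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            static constexpr bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()

Each branch instantiates the lambda body, so flash_fwd_kernel ends up specialized for every combination of IsEvenNConst / IsEvenKConst / ReturnSoftmaxConst while the choice itself stays a cheap runtime branch.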
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 32;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 64;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
if constexpr(!Is_dropout) {
|
||||||
|
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||||
|
// Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||||
|
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 96;
|
||||||
|
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
|
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||||
|
if (is_sm8x) {
|
||||||
|
if constexpr(!Is_causal) {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// These two are always slower
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 128;
|
||||||
|
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
|
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
if constexpr(!Is_dropout) {
|
||||||
|
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||||
|
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
|
||||||
|
if (is_sm8x) {
|
||||||
|
if constexpr(!Is_causal) {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// 1st ones are good for H100, A100
|
||||||
|
// 2nd one is good for A6000 bc we get slightly better occupancy
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 160;
|
||||||
|
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
|
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
// For A100, H100, 128 x 32 is the fastest.
|
||||||
|
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||||
|
// and 128 x 64 with 8 warps is the fastest for non-causal.
|
||||||
|
if (is_sm8x) {
|
||||||
|
if constexpr(!Is_causal) {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 192;
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
if constexpr(!Is_dropout) {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 224;
|
||||||
|
int device;
|
||||||
|
cudaGetDevice(&device);
|
||||||
|
int max_smem_per_block;
|
||||||
|
cudaError status_ = cudaDeviceGetAttribute(
|
||||||
|
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||||
|
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
|
||||||
|
// If we have N = 32, there are only 1024 elements to load at once, where each load
|
||||||
|
// is 8 elements. This means we can only use 128 threads and not 256 threads.
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
|
||||||
|
constexpr int Headdim = 256;
|
||||||
|
int device;
|
||||||
|
cudaGetDevice(&device);
|
||||||
|
int max_smem_per_sm, max_smem_per_block;
|
||||||
|
cudaError status_ = cudaDeviceGetAttribute(
|
||||||
|
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
|
||||||
|
status_ = cudaDeviceGetAttribute(
|
||||||
|
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||||
|
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||||
|
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||||
|
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||||
|
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
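// (With 2-byte elements the thresholds below are 2 * 256 * (128 + 2 * 64) = 128 KB and
// 4 * 256 * (64 + 2 * 64) = 192 KB, i.e. two copies of the 96 KB 64 x 64 config; in practice
// that separates A100-class parts (~164 KB of smem per SM) from H100 (~228 KB per SM).)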
|
||||||
|
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
} else {
|
||||||
|
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
}
|
||||||
|
// 64 KB
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
// 96 KB
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
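These per-head-dimension launchers are normally selected at runtime from params.d; a minimal sketch of such a dispatcher is shown below (hypothetical helper, the actual entry point in the flash-attn sources does the equivalent via a head-dim switch):

// Hypothetical dispatcher over the head dimension; rounds params.d up to the next
// supported kernel size, mirroring how the per-hdim launchers above are meant to be used.
template<typename T>
void run_mha_fwd_sketch(Flash_fwd_params &params, cudaStream_t stream) {
    if (params.d <= 32)       { run_mha_fwd_hdim32<T>(params, stream); }
    else if (params.d <= 64)  { run_mha_fwd_hdim64<T>(params, stream); }
    else if (params.d <= 96)  { run_mha_fwd_hdim96<T>(params, stream); }
    else if (params.d <= 128) { run_mha_fwd_hdim128<T>(params, stream); }
    else if (params.d <= 160) { run_mha_fwd_hdim160<T>(params, stream); }
    else if (params.d <= 192) { run_mha_fwd_hdim192<T>(params, stream); }
    else if (params.d <= 224) { run_mha_fwd_hdim224<T>(params, stream); }
    else                      { run_mha_fwd_hdim256<T>(params, stream); }
}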
|
|
@ -0,0 +1,366 @@
|
||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2023, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cute/algorithm/copy.hpp"
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/layout/layout.h"
|
||||||
|
#include <cutlass/numeric_types.h>
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||||
|
struct Flash_kernel_traits {
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
|
using Element = elem_type;
|
||||||
|
static constexpr bool Has_cp_async = true;
|
||||||
|
#else
|
||||||
|
using Element = cutlass::half_t;
|
||||||
|
static constexpr bool Has_cp_async = false;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using ElementAccum = float;
|
||||||
|
using index_t = uint32_t;
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
|
using MMA_Atom_Arch = std::conditional_t<
|
||||||
|
std::is_same_v<elem_type, cutlass::half_t>,
|
||||||
|
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||||
|
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||||
|
>;
|
||||||
|
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||||
|
#else
|
||||||
|
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||||
|
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||||
|
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||||
|
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||||
|
#else
|
||||||
|
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
|
||||||
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||||
|
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||||
|
struct Flash_fwd_kernel_traits : public Base {
|
||||||
|
using Element = typename Base::Element;
|
||||||
|
using ElementAccum = typename Base::ElementAccum;
|
||||||
|
using index_t = typename Base::index_t;
|
||||||
|
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||||
|
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||||
|
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||||
|
|
||||||
|
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||||
|
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||||
|
|
||||||
|
// The number of threads.
|
||||||
|
static constexpr int kNWarps = kNWarps_;
|
||||||
|
static constexpr int kNThreads = kNWarps * 32;
|
||||||
|
|
||||||
|
static constexpr int kBlockM = kBlockM_;
|
||||||
|
static constexpr int kBlockN = kBlockN_;
|
||||||
|
static constexpr int kHeadDim = kHeadDim_;
|
||||||
|
static_assert(kHeadDim % 32 == 0);
|
||||||
|
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||||
|
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||||
|
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||||
|
|
||||||
|
using TiledMma = TiledMMA<
|
||||||
|
typename Base::MMA_Atom_Arch,
|
||||||
|
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||||
|
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||||
|
|
||||||
|
using SmemLayoutAtomQ = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||||
|
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutQ = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomQ{},
|
||||||
|
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||||
|
|
||||||
|
using SmemLayoutKV = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomQ{},
|
||||||
|
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||||
|
|
||||||
|
using SmemLayoutAtomVtransposed = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||||
|
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||||
|
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||||
|
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomVtransposed{},
|
||||||
|
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||||
|
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||||
|
// And the strides don't matter?
|
||||||
|
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||||
|
|
||||||
|
using SmemLayoutAtomO = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutO = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomO{},
|
||||||
|
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||||
|
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
|
||||||
|
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||||
|
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||||
|
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
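// Worked example (fp16, kHeadDim = 128, kBlockM = 128, kBlockN = 32, no Q/K sharing):
//   kSmemQSize  = 128 * 128 * 2 B = 32 KB
//   kSmemKVSize = 2 * 32 * 128 * 2 B = 16 KB
//   kSmemSize   = 48 KB, matching the "48 KB smem, 2 CTAs per SM" note in the hdim128 launcher.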
|
||||||
|
|
||||||
|
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||||
|
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||||
|
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||||
|
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||||
|
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||||
|
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||||
|
// to the same banks.
|
||||||
|
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||||
|
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||||
|
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||||
|
Stride<Int<kGmemThreadsPerRow>, _1>>;
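// Worked example (fp16, kHeadDim = 128, 4 warps): kGmemElemsPerLoad = 16 / 2 = 8,
// kGmemThreadsPerRow = 64 / 8 = 8, so GmemLayoutAtom is a 16 x 8 thread arrangement and
// every 128-bit load stays inside one 64-element smem "page", avoiding the bank
// conflicts described above.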
|
||||||
|
|
||||||
|
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||||
|
// from the same address by the same threadblock. This is slightly faster.
|
||||||
|
using Gmem_copy_struct = std::conditional_t<
|
||||||
|
Has_cp_async,
|
||||||
|
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||||
|
DefaultCopy
|
||||||
|
>;
|
||||||
|
using GmemTiledCopyQKV = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||||
|
using GmemTiledCopyO = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||||
|
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||||
|
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||||
|
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||||
|
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||||
|
|
||||||
|
using GmemTiledCopyP = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||||
|
GmemLayoutAtomP{},
|
||||||
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
|
||||||
|
// No_double_buffer is another option to reduce smem usage, but will slow things down.
|
||||||
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
|
||||||
|
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
|
||||||
|
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
|
||||||
|
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||||
|
struct Flash_bwd_kernel_traits : public Base {
|
||||||
|
using Element = typename Base::Element;
|
||||||
|
using ElementAccum = typename Base::ElementAccum;
|
||||||
|
using index_t = typename Base::index_t;
|
||||||
|
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||||
|
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||||
|
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||||
|
|
||||||
|
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
|
||||||
|
static constexpr bool No_double_buffer = No_double_buffer_;
|
||||||
|
|
||||||
|
// The number of threads.
|
||||||
|
static constexpr int kNWarps = kNWarps_;
|
||||||
|
static constexpr int kNThreads = kNWarps * 32;
|
||||||
|
|
||||||
|
static constexpr int kBlockM = kBlockM_;
|
||||||
|
static constexpr int kBlockN = kBlockN_;
|
||||||
|
static constexpr int kHeadDim = kHeadDim_;
|
||||||
|
static_assert(kHeadDim % 32 == 0);
|
||||||
|
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||||
|
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||||
|
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||||
|
|
||||||
|
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
|
||||||
|
static_assert(kNWarps % AtomLayoutMSdP == 0);
|
||||||
|
static_assert(kNWarps % AtomLayoutNdKV == 0);
|
||||||
|
static_assert(kNWarps % AtomLayoutMdQ == 0);
|
||||||
|
|
||||||
|
using TiledMmaSdP = TiledMMA<
|
||||||
|
typename Base::MMA_Atom_Arch,
|
||||||
|
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
|
||||||
|
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||||
|
|
||||||
|
using TiledMmadKV = TiledMMA<
|
||||||
|
typename Base::MMA_Atom_Arch,
|
||||||
|
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
|
||||||
|
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||||
|
|
||||||
|
using TiledMmadQ = TiledMMA<
|
||||||
|
typename Base::MMA_Atom_Arch,
|
||||||
|
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
|
||||||
|
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||||
|
|
||||||
|
using SmemLayoutAtomQdO = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutQdO = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomQdO{},
|
||||||
|
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||||
|
|
||||||
|
using SmemLayoutAtomKV = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutKV = decltype(tile_to_shape(
|
||||||
|
// SmemLayoutAtomQdO{},
|
||||||
|
SmemLayoutAtomKV{},
|
||||||
|
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||||
|
|
||||||
|
using SmemLayoutAtomKtransposed = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||||
|
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||||
|
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomKtransposed{},
|
||||||
|
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||||
|
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||||
|
// And the strides don't matter?
|
||||||
|
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||||
|
|
||||||
|
// TODO: generalize to other values of kBlockN
|
||||||
|
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||||
|
// static constexpr int kPBlockN = kBlockN;
|
||||||
|
static_assert(kBlockN >= 64);
|
||||||
|
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
|
||||||
|
static constexpr int kPBlockN = 64;
|
||||||
|
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
|
||||||
|
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
|
||||||
|
static constexpr int kSwizzlePdS = 3;
|
||||||
|
using SmemLayoutAtomPdS = decltype(
|
||||||
|
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||||
|
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
|
||||||
|
Stride<Int<kPBlockN>, _1>>{}));
|
||||||
|
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomPdS{},
|
||||||
|
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||||
|
using SmemLayoutAtomPdStransposed = decltype(
|
||||||
|
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||||
|
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||||
|
Stride<_1, Int<kPBlockN>>>{}));
|
||||||
|
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomPdStransposed{},
|
||||||
|
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||||
|
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||||
|
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
|
||||||
|
using SmemLayoutAtomQdOtransposed = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||||
|
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||||
|
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomQdOtransposed{},
|
||||||
|
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||||
|
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||||
|
|
||||||
|
using SmemLayoutAtomdKV = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutdKV = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomdKV{},
|
||||||
|
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||||
|
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
|
||||||
|
using SmemLayoutAtomdQ = decltype(
|
||||||
|
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||||
|
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||||
|
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||||
|
using SmemLayoutdQ = decltype(tile_to_shape(
|
||||||
|
SmemLayoutAtomdQ{},
|
||||||
|
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||||
|
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
|
||||||
|
|
||||||
|
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
|
||||||
|
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||||
|
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
|
||||||
|
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
|
||||||
|
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
|
||||||
|
static constexpr int kSmemdPsumCount = kBlockM;
|
||||||
|
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||||
|
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
|
||||||
|
static constexpr int kSmemSize = kSmemQdOSize
|
||||||
|
+ (!Is_V_in_regs
|
||||||
|
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||||
|
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
|
||||||
|
static constexpr int kSmemSize1colblock = kSmemQdOSize
|
||||||
|
+ (!Is_V_in_regs
|
||||||
|
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||||
|
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
|
||||||
|
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
|
||||||
|
+ kSmemdSSize + kSmemPSize;
|
||||||
|
|
||||||
|
|
||||||
|
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||||
|
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||||
|
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
|
||||||
|
// to affect speed in practice.
|
||||||
|
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||||
|
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||||
|
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||||
|
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||||
|
|
||||||
|
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||||
|
// from the same address by the same threadblock. This is slightly faster.
|
||||||
|
using Gmem_copy_struct = std::conditional_t<
|
||||||
|
Has_cp_async,
|
||||||
|
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||||
|
DefaultCopy
|
||||||
|
>;
|
||||||
|
using GmemTiledCopyQKV = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||||
|
using GmemTiledCopydO = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||||
|
using GmemTiledCopydKV = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||||
|
using GmemTiledCopydQ = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||||
|
GmemLayoutAtom{},
|
||||||
|
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||||
|
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||||
|
kBlockKSmem == 32,
|
||||||
|
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
|
||||||
|
Stride< _8, _1>>,
|
||||||
|
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
|
||||||
|
Stride< _16, _1>>
|
||||||
|
>;
|
||||||
|
using GmemTiledCopydQaccum = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||||
|
GmemLayoutAtomdQaccum{},
|
||||||
|
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||||
|
|
||||||
|
using GmemTiledCopydQaccumAtomicAdd = decltype(
|
||||||
|
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||||
|
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
|
||||||
|
Stride<_32, _1>>{},
|
||||||
|
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
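For a concrete picture of what these traits resolve to, the sketch below instantiates the configuration used by run_mha_fwd_hdim64 without dropout (illustration only; nothing here is added to the build):

// The <64, 128, 128, 4> forward configuration: 128 x 128 tiles, 4 warps, fp16 elements.
using Hdim64FwdTraits = Flash_fwd_kernel_traits</*kHeadDim_=*/64, /*kBlockM_=*/128,
                                                /*kBlockN_=*/128, /*kNWarps_=*/4,
                                                /*Is_Q_in_regs_=*/false,
                                                /*Share_Q_K_smem_=*/false, cutlass::half_t>;
static_assert(Hdim64FwdTraits::kNThreads == 128);   // 4 warps of 32 threads
static_assert(Hdim64FwdTraits::kBlockKSmem == 64);  // 64 % 64 == 0, so full-width smem pages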
|
|
@ -0,0 +1,165 @@
|
||||||
|
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
|
||||||
|
#pragma once
|
||||||
|
// Philox CUDA.
|
||||||
|
|
||||||
|
namespace flash {
|
||||||
|
|
||||||
|
struct ull2 {
|
||||||
|
unsigned long long x;
|
||||||
|
unsigned long long y;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||||
|
uint2 *res;
|
||||||
|
unsigned long long tmp;
|
||||||
|
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||||
|
: "=l"(tmp)
|
||||||
|
: "r"(a), "r"(b));
|
||||||
|
res = (uint2*)(&tmp);
|
||||||
|
return *res;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||||
|
constexpr unsigned long kPhiloxSA = 0xD2511F53;
|
||||||
|
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||||
|
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||||
|
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||||
|
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ uint4 philox(unsigned long long seed,
|
||||||
|
unsigned long long subsequence,
|
||||||
|
unsigned long long offset) {
|
||||||
|
constexpr unsigned long kPhilox10A = 0x9E3779B9;
|
||||||
|
constexpr unsigned long kPhilox10B = 0xBB67AE85;
|
||||||
|
uint2 key = reinterpret_cast<uint2&>(seed);
|
||||||
|
uint4 counter;
|
||||||
|
ull2 *tmp = reinterpret_cast<ull2*>(&counter);
|
||||||
|
tmp->x = offset;
|
||||||
|
tmp->y = subsequence;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 6; i++) {
|
||||||
|
counter = philox_single_round(counter, key);
|
||||||
|
key.x += (kPhilox10A);
|
||||||
|
key.y += (kPhilox10B);
|
||||||
|
}
|
||||||
|
uint4 output = philox_single_round(counter, key);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace flash
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class Philox {
|
||||||
|
public:
|
||||||
|
__device__ inline Philox(unsigned long long seed,
|
||||||
|
unsigned long long subsequence,
|
||||||
|
unsigned long long offset)
|
||||||
|
: STATE(0)
|
||||||
|
, seed_(seed)
|
||||||
|
, offset_(offset)
|
||||||
|
, key(reinterpret_cast<const uint2&>(seed)) {
|
||||||
|
//key.x = (unsigned int)seed;
|
||||||
|
//key.y = (unsigned int)(seed >> 32);
|
||||||
|
//counter = make_uint4(0, 0, 0, 0);
|
||||||
|
//counter.z = (unsigned int)(subsequence);
|
||||||
|
//counter.w = (unsigned int)(subsequence >> 32);
|
||||||
|
//STATE = 0;
|
||||||
|
//incr_n(offset / 4);
|
||||||
|
|
||||||
|
// key = reinterpret_cast<const uint2&>(seed);
|
||||||
|
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
|
||||||
|
tmp->x = offset / 4;
|
||||||
|
tmp->y = subsequence;
|
||||||
|
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||||
|
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
__device__ inline uint4 operator()() {
|
||||||
|
// // if (STATE == 0) {
|
||||||
|
// uint4 counter_ = counter;
|
||||||
|
// uint2 key_ = key;
|
||||||
|
// // 7-round philox
|
||||||
|
// #pragma unroll
|
||||||
|
// for (int i = 0; i < 6; i++) {
|
||||||
|
// counter_ = flash::philox_single_round(counter_, key_);
|
||||||
|
// key_.x += (kPhilox10A);
|
||||||
|
// key_.y += (kPhilox10B);
|
||||||
|
// }
|
||||||
|
// // output = philox_single_round(counter_, key_);
|
||||||
|
// uint4 output = flash::philox_single_round(counter_, key_);
|
||||||
|
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||||
|
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||||
|
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||||
|
// // }
|
||||||
|
// incr();
|
||||||
|
// // }
|
||||||
|
// // return a float4 directly
|
||||||
|
// // unsigned long ret;
|
||||||
|
// // switch(STATE) {
|
||||||
|
// // case 0: ret = output.x; break;
|
||||||
|
// // case 1: ret = output.y; break;
|
||||||
|
// // case 2: ret = output.z; break;
|
||||||
|
// // case 3: ret = output.w; break;
|
||||||
|
// //}
|
||||||
|
// // STATE = (STATE + 1) % 4;
|
||||||
|
// return output;
|
||||||
|
return flash::philox(seed_, offset_, offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned long long offset_, seed_;
|
||||||
|
struct ull2 {
|
||||||
|
uint64_t x;
|
||||||
|
uint64_t y;
|
||||||
|
};
|
||||||
|
uint4 counter;
|
||||||
|
// uint4 output;
|
||||||
|
const uint2 key;
|
||||||
|
unsigned int STATE;
|
||||||
|
__device__ inline void incr_n(unsigned long long n) {
|
||||||
|
unsigned int nlo = (unsigned int)(n);
|
||||||
|
unsigned int nhi = (unsigned int)(n >> 32);
|
||||||
|
counter.x += nlo;
|
||||||
|
if (counter.x < nlo)
|
||||||
|
nhi++;
|
||||||
|
counter.y += nhi;
|
||||||
|
if (nhi <= counter.y)
|
||||||
|
return;
|
||||||
|
if (++counter.z)
|
||||||
|
return;
|
||||||
|
++counter.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ uint4 incr128 (uint4 ctr)
|
||||||
|
{
|
||||||
|
uint4 res;
|
||||||
|
asm ("add.cc.u32 %0, %4, %8;\n\t"
|
||||||
|
"addc.cc.u32 %1, %5, %9;\n\t"
|
||||||
|
"addc.cc.u32 %2, %6, %10;\n\t"
|
||||||
|
"addc.u32 %3, %7, %11;\n\t"
|
||||||
|
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
|
||||||
|
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
|
||||||
|
"n"(1), "n"(0), "n"(0), "n"(0));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline void incr() {
|
||||||
|
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||||
|
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||||
|
// }
|
||||||
|
counter = incr128(counter);
|
||||||
|
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||||
|
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||||
|
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||||
|
// static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||||
|
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
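The device-side flash::philox above is a standard 7-round Philox-4x32 generator with the seed as key and (offset, subsequence) as the 128-bit counter. A small host-side mirror can be handy for spot-checking dropout masks off-device; the sketch below is an illustration with made-up names, not part of this commit:

#include <cstdint>

// Host-side mirror of flash::philox: same key/counter packing, 6 keyed rounds plus a final one.
struct U4 { uint32_t x, y, z, w; };

static inline void mulhilo32_host(uint32_t a, uint32_t b, uint32_t &hi, uint32_t &lo) {
    uint64_t p = (uint64_t)a * b;
    hi = (uint32_t)(p >> 32);
    lo = (uint32_t)p;
}

static inline U4 philox_round_host(U4 ctr, uint32_t kx, uint32_t ky) {
    uint32_t hi0, lo0, hi1, lo1;
    mulhilo32_host(0xD2511F53u, ctr.x, hi0, lo0);
    mulhilo32_host(0xCD9E8D57u, ctr.z, hi1, lo1);
    return {hi1 ^ ctr.y ^ kx, lo1, hi0 ^ ctr.w ^ ky, lo0};
}

static inline U4 philox_host(uint64_t seed, uint64_t subsequence, uint64_t offset) {
    uint32_t kx = (uint32_t)seed, ky = (uint32_t)(seed >> 32);
    U4 ctr = {(uint32_t)offset, (uint32_t)(offset >> 32),
              (uint32_t)subsequence, (uint32_t)(subsequence >> 32)};
    for (int i = 0; i < 6; ++i) {
        ctr = philox_round_host(ctr, kx, ky);
        kx += 0x9E3779B9u;
        ky += 0xBB67AE85u;
    }
    return philox_round_host(ctr, kx, ky);
}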
|
|
@ -0,0 +1,272 @@
|
||||||
|
/******************************************************************************
|
||||||
|
* Copyright (c) 2023, Tri Dao.
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <cute/tensor.hpp>
|
||||||
|
|
||||||
|
#include <cutlass/cutlass.h>
|
||||||
|
#include <cutlass/array.h>
|
||||||
|
|
||||||
|
#include "philox.cuh"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
namespace flash {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||||
|
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||||
|
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||||
|
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||||
|
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||||
|
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||||
|
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||||
|
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||||
|
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < size(dst); i++){
|
||||||
|
dst(i) = Allreduce<4>::run(src(i), op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||||
|
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||||
|
thread_reduce_<zero_init>(tensor, summary, op);
|
||||||
|
quad_allreduce_(summary, summary, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||||
|
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||||
|
MaxOp<float> max_op;
|
||||||
|
reduce_<zero_init>(tensor, max, max_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||||
|
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||||
|
SumOp<float> sum_op;
|
||||||
|
reduce_(tensor, sum, sum_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the exp to all the elements.
|
||||||
|
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||||
|
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||||
|
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||||
|
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||||
|
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||||
|
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||||
|
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||||
|
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
||||||
|
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||||
|
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||||
|
// max * log_2(e)) This allows the compiler to use the ffma
|
||||||
|
// instruction instead of fadd and fmul separately.
|
||||||
|
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
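// Note that when Scale_max is set the caller is expected to pass a scale that already folds
// the softmax scaling into log2(e) (scale_softmax_log2 in the forward kernel), so the exp2f
// above evaluates exp(softmax_scale * (x - max)) with a single ffma per element.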
|
||||||
|
|
||||||
|
// Apply the exp to all the elements.
|
||||||
|
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||||
|
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||||
|
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||||
|
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||||
|
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||||
|
MaxOp<float> max_op;
|
||||||
|
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||||
|
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||||
|
}
|
||||||
|
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||||
|
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||||
|
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||||
|
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||||
|
sum(mi) = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||||
|
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||||
|
// max * log_2(e)) This allows the compiler to use the ffma
|
||||||
|
// instruction instead of fadd and fmul separately.
|
||||||
|
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||||
|
sum(mi) += tensor(mi, ni);
|
||||||
|
}
|
||||||
|
SumOp<float> sum_op;
|
||||||
|
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Engine, typename Layout>
|
||||||
|
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
|
||||||
|
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||||
|
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||||
|
const uint32_t lane_id = threadIdx.x % 32;
|
||||||
|
#pragma unroll
|
||||||
|
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||||
|
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
|
||||||
|
if (col_idx >= max_seqlen_k) {
|
||||||
|
// Without the "make_coord" we get wrong results
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||||
|
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Engine, typename Layout>
|
||||||
|
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||||
|
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
||||||
|
const uint32_t warp_row_stride) {
|
||||||
|
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||||
|
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||||
|
const uint32_t lane_id = threadIdx.x % 32;
|
||||||
|
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||||
|
const uint32_t row_idx_offset = row_idx_offset_;
|
||||||
|
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||||
|
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||||
|
const uint32_t row_idx = row_idx_base + i * 8;
|
||||||
|
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
||||||
|
#pragma unroll
|
||||||
|
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||||
|
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||||
|
const uint32_t col_idx = col_idx_base + j;
|
||||||
|
if (col_idx >= col_idx_limit) {
|
||||||
|
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if (cute::thread0()) {
|
||||||
|
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||||
|
// print(tensor(make_coord(i, mi), _));
|
||||||
|
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}

template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                     unsigned long long seed, unsigned long long offset,
                                     uint32_t block_row_start, uint32_t block_col_start,
                                     uint32_t block_row_stride) {
    // tensor has shape (8, MMA_M, MMA_N / 2)
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
        return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
    };
    static_assert(decltype(size<2>(tensor))::value % 2 == 0);
    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
    const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
    // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
    #pragma unroll
    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
        uint2 rowcol = make_uint2(block_row_start, block_col_start);
        #pragma unroll
        for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
            // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
            // Special implementation for 16-bit types: we duplicate the threshold to the
            // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
            // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
            // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
            // the random value is less than the threshold.
            // We then do a bit-wise AND between the mask and the original value (in 32-bit).
            // We're exploiting the fact that floating point comparison is equivalent to integer
            // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
            if (!encode_dropout_in_sign_bit
                && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                uint16_t rnd_16[16];
                #pragma unroll
                for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    #pragma unroll
                    for (int i = 0; i < 4; i++) {
                        uint32_t mask;
                        asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                        tensor_uint32(i) &= mask;
                    }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            } else {
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    #pragma unroll
                    for (int i = 0; i < 8; i++) {
                        tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                    }
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            }
            // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
            // // }
        }
    }
}
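
// Minimal sketch of the byte-threshold test above (hypothetical helper, not used
// by the kernel). Each element gets one Philox byte and is kept when that byte is
// <= an 8-bit threshold; assuming the threshold is derived from the keep
// probability as roughly floor(p_keep * 255), an element survives with probability
// close to p_keep. The f16x2 path above evaluates eight of these tests at a time
// and turns the results into 0xffff/0x0000 bit masks.
inline __device__ float dropout_one_element(float val, uint8_t rnd_byte, uint8_t keep_threshold) {
    return rnd_byte <= keep_threshold ? val : 0.f;
}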

} // namespace flash
@ -0,0 +1,66 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

#define FP16_SWITCH(COND, ...)                        \
    [&] {                                             \
        if (COND) {                                   \
            using elem_type = cutlass::half_t;        \
            return __VA_ARGS__();                     \
        } else {                                      \
            using elem_type = cutlass::bfloat16_t;    \
            return __VA_ARGS__();                     \
        }                                             \
    }()

#define FWD_HEADDIM_SWITCH(HEADDIM, ...)              \
    [&] {                                             \
        if (HEADDIM <= 32) {                          \
            constexpr static int kHeadDim = 32;       \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 64) {                   \
            constexpr static int kHeadDim = 64;       \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 96) {                   \
            constexpr static int kHeadDim = 96;       \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 128) {                  \
            constexpr static int kHeadDim = 128;      \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 160) {                  \
            constexpr static int kHeadDim = 160;      \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 192) {                  \
            constexpr static int kHeadDim = 192;      \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 224) {                  \
            constexpr static int kHeadDim = 224;      \
            return __VA_ARGS__();                     \
        } else if (HEADDIM <= 256) {                  \
            constexpr static int kHeadDim = 256;      \
            return __VA_ARGS__();                     \
        }                                             \
    }()
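
// Hypothetical usage sketch (not part of this header): the switches nest so that a
// runtime dtype / head-dim / flag combination selects a fully specialized template.
// `run_flash_fwd_example` and the `is_fp16` / `head_dim` / `is_causal` arguments are
// made-up names for illustration; the real launcher wires these up elsewhere.
template <typename elem_type, int kHeadDim, bool Is_causal>
void run_flash_fwd_example();  // assumed to be defined in some other translation unit

inline void dispatch_flash_fwd_example(bool is_fp16, int head_dim, bool is_causal) {
    FP16_SWITCH(is_fp16, [&] {
        FWD_HEADDIM_SWITCH(head_dim, [&] {
            BOOL_SWITCH(is_causal, Is_causal, [&] {
                run_flash_fwd_example<elem_type, kHeadDim, Is_causal>();
            });
        });
    });
}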
@ -0,0 +1,388 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t relu2(const uint32_t x);

template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n"
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x);

template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);

template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
    return __half22float2(reinterpret_cast<__half2 (&)>(a));
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
    float2 af = flash::half2_unpack<T>(a);
    float2 bf = flash::half2_unpack<T>(b);
    return af.x * bf.x + af.y * bf.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
    float sum;
    sum  = flash::hfma2_to_float<T>(a.x, b.x);
    sum += flash::hfma2_to_float<T>(a.y, b.y);
    sum += flash::hfma2_to_float<T>(a.z, b.z);
    sum += flash::hfma2_to_float<T>(a.w, b.w);
    return sum;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
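
// Hypothetical usage sketch (not part of this header): a warp-wide softmax-style
// reduction combines MaxOp/SumOp with Allreduce. With THREADS = 32 every lane ends
// up holding the same reduced value, which is how a row max / row sum of attention
// scores would be shared across a warp.
template <typename T>
inline __device__ void warp_max_and_sum_example(T thread_val, T &row_max, T &row_sum) {
    MaxOp<T> max_op;
    SumOp<T> sum_op;
    row_max = Allreduce<32>::run(thread_val, max_op);  // butterfly shuffle reduction over 32 lanes
    row_sum = Allreduce<32>::run(thread_val, sum_op);
}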

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopy0, typename TiledCopy1>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                            Tensor4 const& tCsB, TiledMma tiled_mma,
                            TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));          // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
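
// Worked shape example (illustrative, with made-up sizes): for MMA_M = 2 and
// MMA_N = 3, an accumulator layout of shape (4, 2, 3) is first split into
// ((2, 2), 2, 3) by logical_divide, and the returned layout then has shape
// ((2, 2), (2, 3)), i.e. (nrow=(2, MMA_M), ncol=(2, MMA_N)).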

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
                       get<0, 1>(l),
                       get<1, 1, 1>(l));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast<uint32_t>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
    static_assert(std::is_same_v<float, From_type>);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast<float2>(tensor);
    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
    Tensor out = flash::convert_type<To_type>(tensor);
    flash::relu_(out);
#endif
    return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    copy(thr_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.
    // I think it's because the copies are under an if statement.
    // if (Is_even_K) {
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //             copy(thr_copy, S(_, m, _), D(_, m, _));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, _));
    //         }
    //     }
    // } else {  // It's slightly faster in this case if iterate over K first
    //     #pragma unroll
    //     for (int k = 0; k < size<2>(S); ++k) {
    //         if (predicate_K(k)) {
    //             #pragma unroll
    //             for (int m = 0; m < size<1>(S); ++m) {
    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //                     copy(thr_copy, S(_, m, k), D(_, m, k));
    //                 } else if (Clear_OOB_MN) {
    //                     clear(D(_, m, k));
    //                 }
    //             }
    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
    //             if (Clear_OOB_MN || Is_even_MN) {
    //                 clear(D(_, _, k));
    //             } else {
    //                 #pragma unroll
    //                 for (int m = 0; m < size<1>(S); ++m) {
    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
    //                         clear(D(_, m, k));
    //                     }
    //                 }
    //             }
    //         }
    //     }
    // }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};

extern "C" {
    pub(crate) fn run_mha(
        q_ptr: *const c_void,
        k_ptr: *const c_void,
        v_ptr: *const c_void,
        o_ptr: *const c_void,

        q_batch_stride: u32,
        k_batch_stride: u32,
        v_batch_stride: u32,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        q_head_stride: u32,
        k_head_stride: u32,
        v_head_stride: u32,

        b: u32,
        h: u32,
        h_k: u32,
        d: u32,
        d_rounded: u32,
        softmax_scale: f32,

        seqlen_q: u32,
        seqlen_k: u32,
        seqlen_q_rounded: u32,
        seqlen_k_rounded: u32,

        is_causal: c_int,
    );
}
@ -0,0 +1,59 @@
mod ffi;

use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape};
use half::f16;

pub struct FlashHdim32Sm80;

impl candle::CustomOp3 for FlashHdim32Sm80 {
    fn name(&self) -> &'static str {
        "flash-hdim32-sm80"
    }

    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        Err(Error::Wrapped("no cpu support for flash-attn".into()))
    }

    fn cuda_fwd(
        &self,
        q: &candle::CudaStorage,
        _q_l: &Layout,
        k: &candle::CudaStorage,
        _k_l: &Layout,
        v: &candle::CudaStorage,
        _v_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        let dev = q.device();
        let out_shape = Shape::from(&[1]);
        let q = q.as_cuda_slice::<f16>()?;
        let k = k.as_cuda_slice::<f16>()?;
        let v = v.as_cuda_slice::<f16>()?;
        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
            ffi::run_mha(
                q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1,
                1, 1, 1,
            )
        }

        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
        Ok((dst, out_shape))
    }
}