Revert "Add the layer norm files. (#222)" (#223)

This reverts commit c8459d199d.
Laurent Mazare 2023-07-22 17:51:11 +02:00 committed by GitHub
parent c8459d199d
commit 5f20acf080
9 changed files with 0 additions and 1532 deletions

View File

@@ -2,7 +2,3 @@
This crate contains CUDA kernels used from candle. Some of these implementations
come from the [dfdx crate](https://github.com/coreylowman/dfdx).
The `ln*` files come from the [flash-attention
repo](https://github.com/Dao-AILab/flash-attention) and have been edited so as
to compile without including the PyTorch codebase.

View File

@@ -184,7 +184,6 @@ mod cuda {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.arg("--expt-relaxed-constexpr")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
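
Note: the `--expt-relaxed-constexpr` flag dropped here is what lets the `ln*` kernels call host `constexpr` functions such as `std::min`/`std::max` from device code (see `valid_elts_in_warp_fn` in the forward kernel file further down). A minimal sketch, not taken from the diff, of the kind of code that only compiles with the flag:

    // Requires `nvcc --expt-relaxed-constexpr`: std::min/std::max are host
    // constexpr functions (C++14), and the flag allows device code to call them.
    #include <algorithm>

    __global__ void clamp01(const float *in, float *out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            out[i] = std::min(std::max(in[i], 0.0f), 1.0f);
        }
    }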

View File

@@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const LN_FWD_256: &str = include_str!(concat!(env!("OUT_DIR"), "/ln_fwd_256.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

View File

@@ -1,274 +0,0 @@
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdint.h>
#include <functional>
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t elts_per_thread;
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, gamma1(nullptr)
, rowscale(nullptr)
, colscale(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, is_rms_norm(false)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as a matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x0;
void *x1;
void *residual;
void *x;
void *dmask;
void *dmask1;
void *mu;
void *rs;
void *gamma;
void *gamma1;
void *rowscale;
void *colscale;
void *x0_subset;
void *z_subset;
float inverse_cols;
float dropout_keep_p;
float dropout_scale;
float rowscale_const;
bool is_rms_norm;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, z1(nullptr)
, beta(nullptr)
, beta1(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *z1;
void *beta;
void *beta1;
float epsilon;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dz1(nullptr)
, dx(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dbeta1_part(nullptr)
, dgamma1_part(nullptr)
, dcolscale_part(nullptr)
, dx0(nullptr)
, dx1(nullptr)
, dresidual(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
, dbeta1(nullptr)
, dgamma1(nullptr)
, dcolscale(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
void *dz1;
// Input: gradient wrt residual.
void *dx;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
void *dbeta1_part;
void *dgamma1_part;
void *dcolscale_part;
// Output: Dgrad.
void *dx0;
void *dx1;
void *dresidual;
// Output: Wgrad.
void *dbeta;
void *dgamma;
void *dbeta1;
void *dgamma1;
void *dcolscale;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
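
For reference, `Types2Key` packs the five 2-bit type ids into the upper half of a 64-bit key and the hidden size into the lower half. A worked example, assuming the all-fp32 combination registered for hidden size 256 further down:

    // Worked value of Types2Key<fp32, fp32, fp32, fp32, fp32>::get(256):
    // TypeId<fp32>::Value == 2 and the shifts are 0, 2, 4, 6, 8 bits.
    #include <cassert>
    #include <cstdint>

    int main() {
        uint32_t type_key = (2u << 0) | (2u << 2) | (2u << 4) | (2u << 6) | (2u << 8); // 0x2AA
        uint64_t key = (uint64_t(type_key) << 32) | 256u;
        assert(key == 0x2AA00000100ull);
        return 0;
    }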
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdParallelRegistrar{
FwdParallelRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
PARALLEL_FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdParallelRegistrar{
BwdParallelRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
PARALLEL_BWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
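
The structs above are plain parameter bags. Judging from `launch_` in ln_fwd_kernels.cuh below, the intended host-side flow is two calls: one with `configure_params = true` to learn the workspace and barrier sizes, then a second one to launch. A hedged sketch of that flow; `alloc_device` is a hypothetical allocator and `barrier_size` is assumed to count `int` entries:

    #include "ln.h"
    #include <cuda_runtime.h>

    void *alloc_device(size_t bytes); // placeholder, not part of the reverted files

    void run_ln_fwd(layer_norm::FwdFunction launcher,
                    layer_norm::LaunchParams<layer_norm::FwdParams> &lp) {
        // First call only fills ctas_per_col, workspace_bytes and barrier_size.
        launcher(lp, /*configure_params=*/true);
        if (lp.workspace_bytes > 0) {
            lp.params.workspace = alloc_device(lp.workspace_bytes);
        }
        if (lp.barrier_size > 0) {
            lp.params.barrier = static_cast<int *>(alloc_device(lp.barrier_size * sizeof(int)));
            // The inter-CTA sync assumes barriers start out zeroed.
            cudaMemsetAsync(lp.params.barrier, 0, lp.barrier_size * sizeof(int), lp.stream);
        }
        // Second call performs the actual kernel launch on lp.stream.
        launcher(lp, /*configure_params=*/false);
    }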

View File

@@ -1,15 +0,0 @@
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
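
Each `REGISTER_FWD_LAUNCHER` line expands (via the macro in ln_utils.cuh below) into a launcher function plus a static `FwdRegistrar`, so `FWD_FUNCS` ends up holding one entry per registered (type combination, hidden size). A sketch of how a caller might look one up, assuming the registry objects are defined somewhere in the build:

    #include "ln.h"
    #include <stdexcept>

    // Fetch the launcher for fp32 weights, fp16 input/output, fp32 residual and
    // compute, hidden size 256 -- one of the combinations registered above.
    layer_norm::FwdFunction &get_fp16_256_launcher() {
        using namespace layer_norm;
        const uint64_t key = Types2Key<fp32, fp16, fp32, fp16, fp32>::get(256);
        auto it = FWD_FUNCS.find(key);
        if (it == FWD_FUNCS.end()) {
            throw std::runtime_error("no ln_fwd kernel registered for this configuration");
        }
        return it->second;
    }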

View File

@@ -1,257 +0,0 @@
#pragma once
#include <curand_kernel.h>
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using input_t = typename Ktraits::input_t;
using residual_t = typename Ktraits::residual_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using mask_t = typename Ktraits::mask_t;
using Ivec = typename Ktraits::Ivec;
using Rvec = typename Ktraits::Rvec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Mvec = typename Ktraits::Mvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
const bool has_residual = params.residual != nullptr;
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
Wvec gamma[LDGS];
Wvec beta[LDGS];
Wvec colscale[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx);
if (params.beta != nullptr) {
beta[it].load_from(params.beta, idx);
} else {
beta[it].zero_();
}
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
idx += VEC_COLS_PER_LDG;
}
}
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
const int row_z = !Has_subset ? row + 1 : z_subset[row];
const bool load_x0 = !Has_subset || row_x0 > 0;
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec x0;
Rvec residual;
Rvec x;
Mvec dmask;
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
if (has_residual) { residual.load_from(params.residual, idx_x); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
compute_t x_ij;
if (load_x0) {
mask_t keep = true;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
} else {
x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
}
if (save_x) { x.store_to(params.x, idx_x); }
if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
idx_x += VEC_COLS_PER_LDG;
idx_x0 += VEC_COLS_PER_LDG;
}
}
static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
// Need to convert to int, otherwise the subtraction will wrap around.
const index_t valid_partial_vecs_in_warp =
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
int(THREADS_PER_WARP));
return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
};
stats_t s = stats.template compute<Is_even_cols>(
xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
const bool save_z = !Has_subset || row_z > 0;
if (save_z) {
index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
}
z.store_to(params.z, idx_z);
idx_z += VEC_COLS_PER_LDG;
}
}
}
}
}
} // namespace layer_norm
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
bool has_colscale = launch_params.params.colscale != nullptr;
bool has_subset = launch_params.params.x0_subset != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
});
}
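
Stripped of vectorization, dropout, rowscale/colscale and the subset handling, the per-row math in `ln_fwd_kernel` is a standard LayerNorm, or an RMSNorm when `is_rms_norm` is set (the mean is then folded into the variance term instead of being subtracted). A scalar reference sketch, not part of the reverted file, of what each row computes:

    #include <cmath>

    // Plain C++ restatement of the per-row forward pass above.
    void ln_fwd_row_reference(const float *x, const float *gamma, const float *beta,
                              int cols, float epsilon, bool is_rms_norm,
                              float *z, float *mu_out, float *rs_out) {
        float mu = 0.f;
        for (int c = 0; c < cols; ++c) mu += x[c];
        mu /= cols;
        float m2 = 0.f; // sum of squared deviations, as produced by Stats
        for (int c = 0; c < cols; ++c) { float d = x[c] - mu; m2 += d * d; }
        // Mirrors: rs = rsqrtf(m2 * inverse_cols + epsilon + (is_rms_norm ? mu * mu : 0.f))
        float rs = 1.f / std::sqrt(m2 / cols + epsilon + (is_rms_norm ? mu * mu : 0.f));
        for (int c = 0; c < cols; ++c) {
            float y = rs * (x[c] - (is_rms_norm ? 0.f : mu));
            z[c] = gamma[c] * y + (beta != nullptr ? beta[c] : 0.f);
        }
        *mu_out = mu;
        *rs_out = rs;
    }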

View File

@@ -1,172 +0,0 @@
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using residual_t = residual_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
bool Has_colscale,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using residual_t = typename Base::residual_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
// using mask_t = unsigned char;
using mask_t = bool;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) == sizeof(output_t));
static_assert(sizeof(input_t) <= sizeof(residual_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
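
As a sanity check of the derived constants, the 256-wide all-fp16 configuration registered further up (REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16)) works out as below; the numbers follow directly from the definitions above, with sizeof(half) written out as 2:

    constexpr int THREADS_PER_WARP  = 32;
    constexpr int BYTES_PER_LDG     = 16;
    constexpr int NUM_ELTS          = BYTES_PER_LDG / 2;           // 8 half elements per load
    constexpr int THREADS_PER_ROW   = 1 * THREADS_PER_WARP;        // WARPS_N * 32 = 32
    constexpr int THREADS_PER_CTA   = 4 * THREADS_PER_ROW;         // WARPS_M * 32 = 128
    constexpr int ROWS_PER_CTA      = 4;                           // WARPS_M rows per iteration
    constexpr int VEC_COLS          = 256 / NUM_ELTS;              // 32 vector columns per row
    constexpr int VEC_COLS_PER_LDG  = 1 * THREADS_PER_ROW;         // CTAS_PER_ROW * 32 = 32
    constexpr int LDGS              = VEC_COLS / VEC_COLS_PER_LDG; // 1 load per thread per row
    static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS, "each row is covered exactly");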

View File

@@ -1,783 +0,0 @@
#pragma once
#include <cassert>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "ln.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void check_cuda_(cudaError_t status, const char *file, int line) {
if( status != cudaSuccess ) {
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
exit(status);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(ans) \
{ check_cuda_((ans), __FILE__, __LINE__); }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, \
ITYPE, \
RTYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, \
ITYPE, \
RTYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b){
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Sum {
inline __device__ Sum(){}
inline __device__ T operator()(const T &a, const T &b){
return a + b;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
return __shfl_xor_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
return __shfl_down_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeToVec2 {};
template<>
struct TypeToVec2<float> {
using Type = float2;
};
template<>
struct TypeToVec2<half> {
using Type = half2;
};
template<>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
}
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
}
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
}
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
};
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 x;
nv_bfloat16 y;
} tmp;
tmp.x = __float2bfloat16_rn(x.x);
tmp.y = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void zero_() {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = Elt_type(0.f);
}
}
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
inline __device__ void store_to(void *base_ptr, const size_t idx) {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
};
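
`Vec` is a thin union view: one `BytesToType<BYTES>::Type` word aliased with `NUM_ELT` scalar elements, so a row fragment moves as a single wide load/store while the math still addresses individual elements. A standalone sketch of the same trick, not taken from the reverted file:

    #include <cuda_fp16.h>

    // uint4 is the 16-byte vector word (BytesToType<16>::Type above); the
    // __half array is the element view of the same bytes.
    union HalfVec8 {
        uint4  vec;
        __half elt[8];
    };

    __global__ void scale_by_two(const __half *in, __half *out, int n_vecs) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n_vecs) return;
        HalfVec8 v;
        v.vec = reinterpret_cast<const uint4 *>(in)[i];              // one 128-bit load
    #pragma unroll
        for (int j = 0; j < 8; ++j) {
            v.elt[j] = __float2half(2.0f * __half2float(v.elt[j]));  // elementwise math
        }
        reinterpret_cast<uint4 *>(out)[i] = v.vec;                   // one 128-bit store
    }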
////////////////////////////////////////////////////////////////////////////////////////////////////
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {
template<typename Params>
inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
: phase_counter_(0)
, b0_(params.barrier + bidm) // The barrier for this group of CTAs.
, b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.
{
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync(){
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : CTAS_PER_ROW;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params, bidm, bidn)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
{
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Warp leaders 0 hold the CTA-local results.
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if(this->lane_ < CTAS_PER_ROW){
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T *w0_;
T *w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane)
{
}
template<typename Op>
static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
return allreduce_(data, op);
}
template<typename Op>
inline __device__ T reduce(T data, Op &op){
// only lane 0 holds the result!
#pragma unroll
for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
}
int warp_n_;
int lane_;
};
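
The warp-level `allreduce_` above is the classic XOR-butterfly: after log2(32) shuffle-and-add steps every lane holds the full warp sum. A minimal standalone kernel, not part of the diff, showing the same pattern:

    // Butterfly allreduce over one 32-thread warp, mirroring Reducer::allreduce_.
    __global__ void warp_sum(const float *in, float *out) {
        float v = in[threadIdx.x];               // launch with blockDim.x == 32
        for (int it = 1; it < 32; it *= 2) {
            v += __shfl_xor_sync(0xffffffffu, v, it);
        }
        out[threadIdx.x] = v;                    // every lane now holds the warp total
    }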
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<typename Op>
inline __device__ T allreduce(T data, Op & op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, Op &op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
}
T * smem0_;
T * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, typename int_t>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
//Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange
int_t n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step);
// Update
const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
const T delta = m_a - m_b;
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
n_a = n_ab;
m_a = m_ab;
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
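
The loop body above is the pairwise update of Chan et al. for merging two sets of running statistics. A host-side restatement, useful as a reference when checking the device code:

    // Merge (mean, sum of squared deviations, count) of two partitions, as in the
    // shuffle loop above:
    //   n_ab  = n_a + n_b
    //   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
    //   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab
    struct RunningStats { float m; float m2; int n; };

    RunningStats chan_merge(RunningStats a, RunningStats b) {
        const int   n_ab  = a.n + b.n;
        const float rn_ab = 1.f / n_ab;
        const float delta = a.m - b.m;
        return { (a.n * a.m + b.n * b.m) * rn_ab,
                 a.m2 + b.m2 + delta * delta * a.n * b.n * rn_ab,
                 n_ab };
    }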
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
// This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
using stats_t = typename BlockStats::stats_t;
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: inter_cta_(params, bidm, bidn)
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
, warp_n_(warp_n)
, lane_(lane)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
// TODO rn is not really needed here..
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
stats_t block_stats = block_stats_.compute(elts, block_rn);
stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if( warp_n_ == 0 && lane_ == 0 ) {
workspace[bidn_] = block_stats;
}
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
inter_cta_.sync();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume the CTA group size in N is at most 32, so that we can finalize with a single warp.
static_assert(CTAS_PER_ROW <= 32);
// Every warp does the final reduction locally.
if( lane_ < CTAS_PER_ROW ) {
stats_t result = workspace[lane_];
n = ELTS_PER_ROW_PER_CTA;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
return { m, m2 };
}
InterCTASync inter_cta_;
BlockStats block_stats_;
stats_t *w0_;
stats_t *w1_;
int bidn_;
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
using WarpStats = Stats<T, 1, WARPS_M, 1>;
using stats_t = typename WarpStats::stats_t;
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<bool Is_even_cols, uint32_t N, typename function_t>
inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// Compute warp local for all WARPS_N
const auto warp_n = warp_stats_.reducer_.warp_n_;
const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts
);
// Each warp leader stores its stats
const auto lane = warp_stats_.reducer_.lane_;
if( lane == 0 ) {
smem[warp_n] = warp_stats;
}
__syncthreads();
int n = 0;
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume that there are at most 32 warps, so that we can finalize with a single warp
static_assert(WARPS_N <= 32);
if(lane < WARPS_N){
stats_t result = smem[lane];
n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
return { m, m2 };
}
WarpStats warp_stats_;
stats_t * smem0_;
stats_t * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
using stats_t = typename TypeToVec2<T>::Type;
// The simple Warp reducer.
using Reducer = Reducer<T, 1, WARPS_M, 1>;
enum { SMEM_BYTES = 0 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
{
}
template<bool Is_even_cols, uint32_t N, typename function_t>
inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
// const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) {
function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
auto sum = Sum<T>();
T m = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
if (Is_even_cols || (it < num_valid_elts)) {
m += elts[it];
}
}
m = reducer_.allreduce(m, sum) * row_norm_factor;
T m2 = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
if (Is_even_cols || (it < num_valid_elts)) {
T diff = (elts[it] - m);
m2 += diff * diff;
}
}
m2 = reducer_.allreduce(m2, sum);
return {m, m2};
}
Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm

View File

@@ -1,25 +0,0 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
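
In `launch_` above, four of these switches are nested so that the runtime flags (dropout, colscale, subset, even columns) become template parameters and each combination compiles to its own branch-free kernel. A small self-contained example of the same dispatch pattern, with `run_kernel` standing in for the real kernel selection:

    #include <cstdio>

    #define BOOL_SWITCH(COND, CONST_NAME, ...)      \
        [&] {                                       \
            if (COND) {                             \
                constexpr bool CONST_NAME = true;   \
                return __VA_ARGS__();               \
            } else {                                \
                constexpr bool CONST_NAME = false;  \
                return __VA_ARGS__();               \
            }                                       \
        }()

    template <bool IsDropout, bool IsEvenCols>
    void run_kernel() {
        // The real launcher would take &ln_fwd_kernel<Ktraits, IsDropout, ...> here.
        std::printf("dropout=%d even_cols=%d\n", IsDropout, IsEvenCols);
    }

    void dispatch(bool is_dropout, bool is_even_cols) {
        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                run_kernel<IsDropoutConst, IsEvenColsConst>();
            });
        });
    }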