This reverts commit c8459d199d.
This commit is contained in:
parent
c8459d199d
commit
5f20acf080
|
@ -2,7 +2,3 @@
|
|||
|
||||
This crate contains CUDA kernels used from candle. Some of these implementations
|
||||
come from the [dfdx crate](https://github.com/coreylowman/dfdx).
|
||||
|
||||
The `ln*` files come from the [flash-attention
|
||||
repo](https://github.com/Dao-AILab/flash-attention) and have been edited so as
|
||||
to compile without including the PyTorch codebase.
|
||||
|
|
|
@ -184,7 +184,6 @@ mod cuda {
|
|||
let mut command = std::process::Command::new("nvcc");
|
||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
|
|
|
@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
|||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const LN_FWD_256: &str = include_str!(concat!(env!("OUT_DIR"), "/ln_fwd_256.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
|
|
@ -1,274 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <stdint.h>
|
||||
#include <functional>
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Host-side bundle handed to a registered launcher function.
//
// `elts_per_thread`, `workspace_bytes` and `barrier_size` are written by the
// launcher when it is called in "configure" mode; `props` and `stream` are
// supplied by the caller; `params` is the kernel argument struct.
// All scalar/pointer members are default-initialised so that a freshly
// constructed object never carries indeterminate values.
template<typename Params>
struct LaunchParams {

    size_t elts_per_thread = 0;
    size_t workspace_bytes = 0;
    size_t barrier_size = 0;

    // Device properties (read for multiProcessorCount by the launcher).
    cudaDeviceProp *props = nullptr;

    // Stream the kernel is enqueued on.
    cudaStream_t stream = nullptr;

    // Kernel parameters, passed by value at launch.
    Params params;

};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Parameters shared by the forward and backward layer-norm kernels.
//
// Every member is initialised by the constructor (in declaration order):
// pointers to nullptr, sizes to 0, scale factors to 1.f. The kernels and the
// launcher branch on `residual`, `rowscale`, `colscale` and `x0_subset` being
// nullptr, so none of them may be left indeterminate.
struct ParamsBase {
    ParamsBase()
        : ctas_per_col(0)
        , rows(0)
        , cols(0)
        , x0(nullptr)
        , x1(nullptr)
        , residual(nullptr)
        , x(nullptr)
        , dmask(nullptr)
        , dmask1(nullptr)
        , mu(nullptr)
        , rs(nullptr)
        , gamma(nullptr)
        , gamma1(nullptr)
        , rowscale(nullptr)
        , colscale(nullptr)
        , x0_subset(nullptr)
        , z_subset(nullptr)
        , inverse_cols(0.f)
        , dropout_keep_p(1.f)
        , dropout_scale(1.f)
        , rowscale_const(1.f)
        , is_rms_norm(false)
        , workspace(nullptr)
        , barrier(nullptr)
    {
    }

    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
    int ctas_per_col;

    // Input is interpreted as matrix. We normalize across columns.
    int rows;
    int cols;

    // Common data pointers.
    void *x0;
    void *x1;
    void *residual;
    void *x;
    void *dmask;
    void *dmask1;
    void *mu;
    void *rs;
    void *gamma;
    void *gamma1;
    void *rowscale;
    void *colscale;
    void *x0_subset;
    void *z_subset;

    // Multiplied with the second moment before adding epsilon; the caller is
    // expected to set it (presumably to 1 / cols -- confirm against the host code).
    float inverse_cols;

    float dropout_keep_p;
    float dropout_scale;
    // Row scale used instead of the `rowscale` array when a subset is given.
    float rowscale_const;

    bool is_rms_norm;

    // Multi-CTA workspace in gmem.
    void *workspace;

    // Multi-CTA sync barriers in gmem.
    int *barrier;

};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FwdParams : public ParamsBase {
|
||||
FwdParams()
|
||||
: ParamsBase()
|
||||
, z(nullptr)
|
||||
, z1(nullptr)
|
||||
, beta(nullptr)
|
||||
, beta1(nullptr)
|
||||
, epsilon(0.f)
|
||||
{
|
||||
}
|
||||
|
||||
// Output of LN FWD.
|
||||
void *z;
|
||||
void *z1;
|
||||
void *beta;
|
||||
void *beta1;
|
||||
float epsilon;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct BwdParams : public ParamsBase {
|
||||
BwdParams()
|
||||
: ParamsBase()
|
||||
, dz(nullptr)
|
||||
, dz1(nullptr)
|
||||
, dx(nullptr)
|
||||
, dbeta_part(nullptr)
|
||||
, dgamma_part(nullptr)
|
||||
, dbeta1_part(nullptr)
|
||||
, dgamma1_part(nullptr)
|
||||
, dcolscale_part(nullptr)
|
||||
, dx0(nullptr)
|
||||
, dx1(nullptr)
|
||||
, dresidual(nullptr)
|
||||
, dbeta(nullptr)
|
||||
, dgamma(nullptr)
|
||||
, dbeta1(nullptr)
|
||||
, dgamma1(nullptr)
|
||||
, dcolscale(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
// Input: gradient wrt. LN FWD output.
|
||||
void *dz;
|
||||
void *dz1;
|
||||
// Input: gradient wrt residual.
|
||||
void *dx;
|
||||
|
||||
// Workspace for Wgrad pre-reduction.
|
||||
void *dbeta_part;
|
||||
void *dgamma_part;
|
||||
void *dbeta1_part;
|
||||
void *dgamma1_part;
|
||||
void *dcolscale_part;
|
||||
|
||||
// Output: Dgrad.
|
||||
void *dx0;
|
||||
void *dx1;
|
||||
void *dresidual;
|
||||
// Output: Wgrad.
|
||||
void *dbeta;
|
||||
void *dgamma;
|
||||
void *dbeta1;
|
||||
void *dgamma1;
|
||||
void *dcolscale;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
|
||||
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
|
||||
using FunctionKey = uint64_t;
|
||||
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
|
||||
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
|
||||
|
||||
extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
|
||||
extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Short aliases for the element types the launchers are instantiated with.
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps an element type to a small integer id: fp16 -> 0, bf16 -> 1, fp32 -> 2.
// The primary template is intentionally empty so unsupported types fail to compile.
template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
    constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
    constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
    constexpr static uint32_t Value = 2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Places the 2-bit type id of T at bit offset S.
template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// One 2-bit field per role: weight at bits 0-1, input 2-3, residual 4-5,
// output 6-7, compute 8-9.
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Combines the five per-role fields into `Value`; `get(hidden_size)` builds the
// 64-bit registry key with the type bits in the upper 32 bits and the hidden
// size OR-ed into the lower bits.
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value
                                    | InputType2Key<I>::Value
                                    | ResidualType2Key<R>::Value
                                    | OutputType2Key<O>::Value
                                    | ComputeType2Key<C>::Value;
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct FwdRegistrar{
|
||||
FwdRegistrar(FwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
FWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct BwdRegistrar{
|
||||
BwdRegistrar(BwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
BWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct FwdParallelRegistrar{
|
||||
FwdParallelRegistrar(FwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
PARALLEL_FWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct BwdParallelRegistrar{
|
||||
BwdParallelRegistrar(BwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
PARALLEL_BWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
|
@ -1,15 +0,0 @@
|
|||
#include "ln_fwd_kernels.cuh"
|
||||
|
||||
// Create forward launch function and register. Macro signature:
|
||||
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
|
@ -1,257 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include "ln.h"
|
||||
#include "ln_utils.cuh"
|
||||
#include "ln_kernel_traits.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
|
||||
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
|
||||
void ln_fwd_kernel(FwdParams params) {
|
||||
|
||||
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
|
||||
enum { WARPS_N = Ktraits::WARPS_N };
|
||||
enum { WARPS_M = Ktraits::WARPS_M };
|
||||
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
|
||||
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
|
||||
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
|
||||
enum { LDGS = Ktraits::LDGS };
|
||||
enum { NUM_ELTS = Ktraits::NUM_ELTS };
|
||||
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
|
||||
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using residual_t = typename Ktraits::residual_t;
|
||||
using output_t = typename Ktraits::output_t;
|
||||
using index_t = typename Ktraits::index_t;
|
||||
using compute_t = typename Ktraits::compute_t;
|
||||
using mask_t = typename Ktraits::mask_t;
|
||||
using Ivec = typename Ktraits::Ivec;
|
||||
using Rvec = typename Ktraits::Rvec;
|
||||
using Ovec = typename Ktraits::Ovec;
|
||||
using Wvec = typename Ktraits::Wvec;
|
||||
using Cvec = typename Ktraits::Cvec;
|
||||
using Mvec = typename Ktraits::Mvec;
|
||||
|
||||
using Stats = typename Ktraits::Stats;
|
||||
using stats_t = typename Stats::stats_t;
|
||||
|
||||
const bool has_residual = params.residual != nullptr;
|
||||
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const index_t tidx = threadIdx.x;
|
||||
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
|
||||
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
|
||||
const index_t lane = tidx % THREADS_PER_WARP;
|
||||
const index_t warp = tidx / THREADS_PER_WARP;
|
||||
const index_t warp_m = warp / WARPS_N;
|
||||
const index_t warp_n = warp % WARPS_N;
|
||||
|
||||
const index_t r = bidm * ROWS_PER_CTA + warp_m;
|
||||
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
|
||||
|
||||
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
|
||||
|
||||
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
|
||||
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
|
||||
|
||||
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
|
||||
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
|
||||
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
|
||||
|
||||
const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
|
||||
|
||||
Wvec gamma[LDGS];
|
||||
Wvec beta[LDGS];
|
||||
Wvec colscale[LDGS];
|
||||
index_t idx = c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
gamma[it].load_from(params.gamma, idx);
|
||||
if (params.beta != nullptr) {
|
||||
beta[it].load_from(params.beta, idx);
|
||||
} else {
|
||||
beta[it].zero_();
|
||||
}
|
||||
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
|
||||
idx += VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
|
||||
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
|
||||
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
|
||||
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
|
||||
const int row_z = !Has_subset ? row + 1 : z_subset[row];
|
||||
const bool load_x0 = !Has_subset || row_x0 > 0;
|
||||
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
|
||||
index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
|
||||
compute_t xf[LDGS * NUM_ELTS];
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
Ivec x0;
|
||||
Rvec residual;
|
||||
Rvec x;
|
||||
Mvec dmask;
|
||||
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
|
||||
if (has_residual) { residual.load_from(params.residual, idx_x); }
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
|
||||
// the more efficient curand_uniform4.
|
||||
compute_t x_ij;
|
||||
if (load_x0) {
|
||||
mask_t keep = true;
|
||||
if (Is_dropout) { dmask.data.elt[jt] = keep; }
|
||||
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
|
||||
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
|
||||
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
|
||||
x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
|
||||
} else {
|
||||
x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
|
||||
}
|
||||
if (save_x) { x.data.elt[jt] = x_ij; }
|
||||
xf[it * NUM_ELTS + jt] = x_ij;
|
||||
}
|
||||
if (save_x) { x.store_to(params.x, idx_x); }
|
||||
if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
|
||||
idx_x += VEC_COLS_PER_LDG;
|
||||
idx_x0 += VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
|
||||
static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
|
||||
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
|
||||
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
|
||||
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
|
||||
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
|
||||
// Need to convert to int, otherwise the subtraction will wrap around.
|
||||
const index_t valid_partial_vecs_in_warp =
|
||||
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
|
||||
int(THREADS_PER_WARP));
|
||||
return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
|
||||
};
|
||||
stats_t s = stats.template compute<Is_even_cols>(
|
||||
xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
|
||||
);
|
||||
|
||||
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
|
||||
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
|
||||
|
||||
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
|
||||
mu_ptr[row] = mu;
|
||||
}
|
||||
|
||||
compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
|
||||
|
||||
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
|
||||
rs_ptr[row] = rs;
|
||||
}
|
||||
|
||||
const bool save_z = !Has_subset || row_z > 0;
|
||||
if (save_z) {
|
||||
index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
Ovec z;
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
|
||||
compute_t g_ij = gamma[it].data.elt[jt];
|
||||
compute_t b_ij = beta[it].data.elt[jt];
|
||||
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
|
||||
}
|
||||
z.store_to(params.z, idx_z);
|
||||
idx_z += VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace layer_norm
|
||||
|
||||
using namespace layer_norm;
|
||||
|
||||
template<
|
||||
typename weight_t,
|
||||
typename input_t,
|
||||
typename residual_t,
|
||||
typename output_t,
|
||||
typename compute_t,
|
||||
typename index_t,
|
||||
int HIDDEN_SIZE,
|
||||
int CTAS_PER_ROW,
|
||||
int WARPS_M,
|
||||
int WARPS_N,
|
||||
int BYTES_PER_LDG
|
||||
>
|
||||
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
|
||||
|
||||
using Kernel_traits = Kernel_traits<weight_t,
|
||||
input_t,
|
||||
residual_t,
|
||||
output_t,
|
||||
compute_t,
|
||||
index_t,
|
||||
HIDDEN_SIZE,
|
||||
CTAS_PER_ROW,
|
||||
WARPS_M,
|
||||
WARPS_N,
|
||||
BYTES_PER_LDG
|
||||
>;
|
||||
bool has_colscale = launch_params.params.colscale != nullptr;
|
||||
bool has_subset = launch_params.params.x0_subset != nullptr;
|
||||
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
|
||||
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
|
||||
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
|
||||
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
|
||||
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
||||
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
|
||||
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::Stats::stats_t)
|
||||
* 2;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
|
@ -1,172 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace layer_norm {
|
||||
// Common compile-time description shared by the kernel-trait structs: the five
// element types, the index type, the hidden size and the CTA size.
// THREADS_PER_WARP is fixed at 32.
template<
    uint32_t HIDDEN_SIZE_,
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {

    using weight_t = weight_t_;
    using input_t = input_t_;
    using residual_t = residual_t_;
    using output_t = output_t_;
    using compute_t = compute_t_;
    using index_t = index_t_;

    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
    enum { THREADS_PER_WARP = 32 };

};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
uint32_t HIDDEN_SIZE_,
|
||||
typename weight_t_,
|
||||
typename input_t_,
|
||||
typename residual_t_,
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
bool Has_colscale,
|
||||
uint32_t THREADS_PER_CTA_,
|
||||
uint32_t BYTES_PER_LDG_,
|
||||
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
|
||||
weight_t_,
|
||||
input_t_,
|
||||
residual_t_,
|
||||
output_t_,
|
||||
compute_t_,
|
||||
index_t_,
|
||||
THREADS_PER_CTA_>
|
||||
>
|
||||
struct Kernel_traits_finalize : public Base {
|
||||
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
|
||||
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
|
||||
// Bytes per global load from the input.
|
||||
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
|
||||
// Number of elements fetched by a global load.
|
||||
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
|
||||
// Bytes per global store of the weights.
|
||||
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
|
||||
static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
|
||||
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
|
||||
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
|
||||
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
|
||||
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
|
||||
|
||||
// Shared memory size to transpose the CTA result.
|
||||
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
|
||||
// Shared memory size to coalsece the CTA result.
|
||||
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
|
||||
// Shared memory requirement per CTA.
|
||||
static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
|
||||
enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
|
||||
|
||||
// The type of the reducer.
|
||||
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
|
||||
|
||||
// Condition for the whole CTA to participate in syncthreads.
|
||||
static_assert(COLS % Base::THREADS_PER_WARP == 0);
|
||||
enum { CTAS = COLS / Base::THREADS_PER_WARP };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<
|
||||
typename weight_t_,
|
||||
typename input_t_,
|
||||
typename residual_t_,
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
uint32_t HIDDEN_SIZE_,
|
||||
uint32_t CTAS_PER_ROW_,
|
||||
uint32_t WARPS_M_,
|
||||
uint32_t WARPS_N_,
|
||||
uint32_t BYTES_PER_LDG_ = 16,
|
||||
typename Base = Kernel_traits_base<
|
||||
HIDDEN_SIZE_,
|
||||
weight_t_,
|
||||
input_t_,
|
||||
residual_t_,
|
||||
output_t_,
|
||||
compute_t_,
|
||||
index_t_,
|
||||
WARPS_M_*WARPS_N_*THREADS_PER_WARP
|
||||
>
|
||||
>
|
||||
struct Kernel_traits : public Base {
|
||||
|
||||
using input_t = typename Base::input_t;
|
||||
using residual_t = typename Base::residual_t;
|
||||
using weight_t = typename Base::weight_t;
|
||||
using compute_t = typename Base::compute_t;
|
||||
using output_t = typename Base::output_t;
|
||||
using index_t = typename Base::index_t;
|
||||
// using mask_t = unsigned char;
|
||||
using mask_t = bool;
|
||||
|
||||
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
|
||||
enum { WARPS_M = WARPS_M_ };
|
||||
enum { WARPS_N = WARPS_N_ };
|
||||
enum { COLS = HIDDEN_SIZE_ };
|
||||
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
|
||||
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
|
||||
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
|
||||
|
||||
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
|
||||
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
|
||||
enum { ROWS_PER_CTA = WARPS_M };
|
||||
|
||||
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
|
||||
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
|
||||
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
|
||||
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
|
||||
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
|
||||
|
||||
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
|
||||
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
|
||||
|
||||
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
|
||||
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
|
||||
|
||||
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
|
||||
using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
|
||||
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
|
||||
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
|
||||
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
|
||||
using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
|
||||
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
|
||||
|
||||
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
|
||||
static_assert(sizeof(input_t) == sizeof(output_t));
|
||||
static_assert(sizeof(input_t) <= sizeof(residual_t));
|
||||
// The number of columns fetched per load from input: one per thread.
|
||||
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
|
||||
// The total number of vectorized loads/stores per hidden vector.
|
||||
enum { VEC_COLS = COLS / ELTS_PER_LDG };
|
||||
// The number of loads per thread for the input.
|
||||
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
|
||||
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
|
||||
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
|
||||
|
||||
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
|
||||
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
|
@ -1,783 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "ln.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr uint32_t THREADS_PER_WARP = 32;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Aborts the process when a CUDA runtime call failed.
//
// On any status other than cudaSuccess the error string, file and line are
// printed to stderr and the process exits with the numeric status as exit
// code. Returns normally on success. Intended to be used through CHECK_CUDA.
inline void check_cuda_(cudaError_t status, const char *file, int line) {
    if( status != cudaSuccess ) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
        exit(status);
    }
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Evaluates a CUDA runtime call and routes its status through check_cuda_
// together with the call site. Wrapped in do { } while (0) so the macro acts
// as a single statement and stays correct in un-braced if / else bodies.
#define CHECK_CUDA(ans)                                 \
    do {                                                \
        check_cuda_((ans), __FILE__, __LINE__);         \
    } while (0)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Integer division of x by y rounded up: ((x) + (y) - 1) / (y).
// `y` is expanded twice, so arguments with side effects must be avoided.
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Defines the launcher function ln_fwd_<HIDDEN_SIZE>_<W>_<I>_<R>_<O>_<C>, which
// forwards to launch_<...> with uint32_t as index type, and a file-static
// FwdRegistrar named reg_<...> that registers it at static-initialisation time.
// The invocation supplies the trailing semicolon.
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                  \
    void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,                       \
                                                                                const bool configure_params) {                                 \
        launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(                      \
            launch_params, configure_params);                                                                                                  \
    }                                                                                                                                          \
    static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(   \
        ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Defines the launcher function ln_bwd_<HIDDEN_SIZE>_<W>_<I>_<R>_<O>_<C>, which
// forwards to the backward launch_<...> overload (taking the extra
// BYTES_PER_LDG_FINALIZE argument), and a file-static BwdRegistrar named
// reg_<...> that registers it. The invocation supplies the trailing semicolon.
#define REGISTER_BWD_LAUNCHER(                                                                                                                 \
    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                     \
    void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,                       \
                                                                                const bool configure_params) {                                 \
        launch_<WTYPE,                                                                                                                         \
                ITYPE,                                                                                                                         \
                RTYPE,                                                                                                                         \
                OTYPE,                                                                                                                         \
                CTYPE,                                                                                                                         \
                uint32_t,                                                                                                                      \
                HIDDEN_SIZE,                                                                                                                   \
                CTAS_PER_ROW,                                                                                                                  \
                WARPS_M,                                                                                                                       \
                WARPS_N,                                                                                                                       \
                BYTES_PER_LDG,                                                                                                                 \
                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                                      \
    }                                                                                                                                          \
    static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(   \
        ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Defines the launcher function ln_parallel_residual_fwd_<...>, which forwards
// to launch_parallel_residual_<...> with uint32_t as index type, and a
// file-static FwdParallelRegistrar named reg_<...> that registers it.
// The invocation supplies the trailing semicolon.
#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                 \
    void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,            \
                                                                                                  const bool configure_params) {                      \
        launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(           \
            launch_params, configure_params);                                                                                                         \
    }                                                                                                                                                 \
    static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(  \
        ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Emits `ln_parallel_residual_bwd_<HIDDEN_SIZE>_<WTYPE>_<ITYPE>_<RTYPE>_<OTYPE>_<CTYPE>`, a
// wrapper that forwards to the matching `launch_parallel_residual_<...>` instantiation (with
// the extra BYTES_PER_LDG_FINALIZE argument), and a file-local static `BwdParallelRegistrar`
// named `reg_<same suffix>` constructed from it.
// The macro expansion has no trailing semicolon; the use site supplies it.
#define REGISTER_PARALLEL_BWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
    void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
                                                                                                  const bool configure_params) { \
        launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, \
                                  BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
    } \
    static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
        reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
            ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Component-wise sum of two float2 values (device only).
inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return make_float2(a.x + b.x, a.y + b.y);
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Component-wise in-place accumulation of `b` into `a` (device only).
// Returns void, so the expression cannot be chained.
inline __device__ void operator+=(float2 & a, const float2 & b){
    a.x = a.x + b.x;
    a.y = a.y + b.y;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Binary functor returning `a + b`; used as the combining operation of the reducers.
template<typename T>
struct Sum {
    inline __device__ Sum(){}

    inline __device__ T operator()(const T &a, const T &b){
        const T total = a + b;
        return total;
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Thin wrapper over __shfl_xor_sync with a participation mask naming all 32 lanes.
// `idx` is passed through as the XOR lane mask.
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
    constexpr uint32_t FULL_WARP_MASK = 0xffffffffu;
    return __shfl_xor_sync(FULL_WARP_MASK, x, idx);
}

// float2 specialization: shuffles the two components one after the other (x first, then y).
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
    float2 exchanged;
    exchanged.x = warp_shuffle_xor(x.x, idx);
    exchanged.y = warp_shuffle_xor(x.y, idx);
    return exchanged;
}
|
||||
|
||||
// Thin wrapper over __shfl_down_sync with a participation mask naming all 32 lanes.
// `idx` is passed through as the lane delta.
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
    constexpr uint32_t FULL_WARP_MASK = 0xffffffffu;
    return __shfl_down_sync(FULL_WARP_MASK, x, idx);
}

// float2 specialization: shuffles the two components one after the other (x first, then y).
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
    float2 shifted;
    shifted.x = warp_shuffle_down(x.x, idx);
    shifted.y = warp_shuffle_down(x.y, idx);
    return shifted;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Aggregate of four uint4 members; BytesToType<64> uses it as its 64-byte carrier type.
struct uint16 {
    uint4 u, v, s, t;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Aggregate of two uint4 members; BytesToType<32> uses it as its 32-byte carrier type.
struct uint8 {
    uint4 u, v;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Maps a byte count to a type of exactly that size (exposed as `::Type`).
// The primary template is intentionally empty, so only the sizes specialized
// below (64, 32, 16, 8, 4, 2, 1) are usable.
template<int BYTES>
struct BytesToType {};

// 64 bytes -> uint16 (four uint4).
template<>
struct BytesToType<64> {
    typedef uint16 Type;
    static_assert(sizeof(Type) == 64);
};

// 32 bytes -> uint8 (two uint4).
template<>
struct BytesToType<32> {
    typedef uint8 Type;
    static_assert(sizeof(Type) == 32);
};

// 16 bytes -> uint4.
template<>
struct BytesToType<16> {
    typedef uint4 Type;
    static_assert(sizeof(Type) == 16);
};

// 8 bytes -> uint64_t.
template<>
struct BytesToType<8> {
    typedef uint64_t Type;
    static_assert(sizeof(Type) == 8);
};

// 4 bytes -> uint32_t.
template<>
struct BytesToType<4> {
    typedef uint32_t Type;
    static_assert(sizeof(Type) == 4);
};

// 2 bytes -> uint16_t.
template<>
struct BytesToType<2> {
    typedef uint16_t Type;
    static_assert(sizeof(Type) == 2);
};

// 1 byte -> uint8_t.
template<>
struct BytesToType<1> {
    typedef uint8_t Type;
    static_assert(sizeof(Type) == 1);
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Maps a scalar type to its two-component vector type (exposed as `::Type`).
// The primary template is empty; only float, half and nv_bfloat16 are supported.
template<typename T>
struct TypeToVec2 {};

// float -> float2
template<>
struct TypeToVec2<float> {
    typedef float2 Type;
};

// half -> half2
template<>
struct TypeToVec2<half> {
    typedef half2 Type;
};

// nv_bfloat16 -> nv_bfloat162
template<>
struct TypeToVec2<nv_bfloat16> {
    typedef nv_bfloat162 Type;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Compile-time component selector: Get<I>::of<T, R>(vec) reads the I-th named
// component of `vec` (.x, .y, .z, .w for I = 0..3) and returns it converted to R.
// Only the declaration lives in the primary template; each index is defined below.
template<int INDEX>
struct Get {
    template<typename T, typename R>
    static inline __device__ R of(const T &vec);
};

// Index 0 -> vec.x
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
    const R component = vec.x;
    return component;
}

// Index 1 -> vec.y
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
    const R component = vec.y;
    return component;
}

// Index 2 -> vec.z
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
    const R component = vec.z;
    return component;
}

// Index 3 -> vec.w
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
    const R component = vec.w;
    return component;
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Value conversion between element types. The generic version relies on the
// destination type being constructible from the source via `Dst(from)`.
template<typename Src, typename Dst>
struct Converter{
    static inline __device__ Dst convert(const Src &from) {
        Dst converted = Dst(from);
        return converted;
    }
};

// float2 -> half2 through the __float22half2_rn intrinsic.
template<>
struct Converter<float2, half2>{
    static inline __device__ half2 convert(const float2 &x) {
        const half2 packed = __float22half2_rn(x);
        return packed;
    }
};
|
||||
|
||||
// float2 -> nv_bfloat162 conversion.
//
// When compiling for __CUDA_ARCH__ >= 800 the packed intrinsic __float22bfloat162_rn is
// used. Otherwise each component is converted on its own with __float2bfloat16_rn and
// written to the matching member of the result.
//
// The previous fallback wrote through a union whose `x` and `y` members were both direct
// union members (and therefore both at offset 0): the second store overwrote the first and
// the upper half of the returned value was never written. Assigning the members of an
// nv_bfloat162 directly avoids that.
template<>
struct Converter<float2, nv_bfloat162>{
    static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
        return __float22bfloat162_rn(x);
#else
        nv_bfloat162 packed;
        packed.x = __float2bfloat16_rn(x.x);
        packed.y = __float2bfloat16_rn(x.y);
        return packed;
#endif
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Produces the additive identity for T. The generic version builds it as T(0.f).
template<typename T>
struct Zeros{
    static inline __device__ T get() {
        const T zero = T(0.f);
        return zero;
    }
};

// float2 has no constructor from a scalar, so both components are set explicitly.
template<>
struct Zeros<float2>{
    static inline __device__ float2 get() {
        const float2 zero = make_float2(0.f, 0.f);
        return zero;
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Fixed-size pack of NUM_ELT elements that can be accessed either element by element
// (`data.elt[i]`) or as one value of the same total size (`data.vec`), so that the whole
// pack is loaded/stored with a single assignment of `Vec_type`.
// NUM_ELT * sizeof(Elt_type) must be one of the sizes BytesToType is specialized for.
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {

    // Total size of the pack in bytes.
    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    // Single type covering the whole pack.
    using Vec_type = typename BytesToType<BYTES>::Type;

    // The two views of the storage.
    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    // Converts every element to S and writes it to the same position of `other`.
    template<typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
        #pragma unroll
        for( int i = 0; i < NUM_ELT; ++i ) {
            other.data.elt[i] = S(data.elt[i]);
        }
    }

    // Sets element i to op(i) for every position i.
    template<typename Op>
    inline __device__ void assign(const Op &op) {
        #pragma unroll
        for( int i = 0; i < NUM_ELT; ++i ) {
            data.elt[i] = op(i);
        }
    }

    // Sets every element to Elt_type(0.f).
    inline __device__ void zero_() {
        #pragma unroll
        for( int i = 0; i < NUM_ELT; ++i ) {
            data.elt[i] = Elt_type(0.f);
        }
    }

    // Reads the idx-th pack from `base_ptr`, which is indexed in units of Vec_type.
    inline __device__ void load_from(const void *base_ptr, const size_t idx) {
        const Vec_type *packs = static_cast<const Vec_type *>(base_ptr);
        data.vec = packs[idx];
    }

    // Writes this pack to the idx-th slot of `base_ptr`, indexed in units of Vec_type.
    inline __device__ void store_to(void *base_ptr, const size_t idx) {
        Vec_type *packs = static_cast<Vec_type *>(base_ptr);
        packs[idx] = data.vec;
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Barrier between the CTAS_PER_ROW CTAs that work on the same group of rows, implemented
// with two integer counters that live in `params.barrier`.
//
// Successive sync() calls cycle through four phases: count counter 0 up to CTAS_PER_ROW,
// count counter 1 up to CTAS_PER_ROW, count counter 0 back down to 0, count counter 1 back
// down to 0.
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {

    // `bidm` selects the counters of this CTA group; `bidn` is accepted but not used here.
    template<typename Params>
    inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
        : phase_counter_(0)
        , b0_(params.barrier + bidm)                        // First counter of this CTA group.
        , b1_(params.barrier + bidm + params.ctas_per_col)  // Second counter of this CTA group.
    {
        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
    }

    // Adds `step` to `*barrier` (release, gpu scope), then polls the counter with acquire
    // loads until it reads `expected`.
    inline __device__ void spin_wait_(int *barrier, int step, int expected) {
        asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
        int found = -1;
        while( found != expected ) {
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
        }
    }

    inline __device__ void sync(){
        // ALL THREADS MUST ENTER!

        // Bit 0 of the phase picks the counter, bit 1 picks the counting direction.
        const bool use_second_barrier = (phase_counter_ & 0x1) != 0;
        const bool counting_down = (phase_counter_ & 0x2) != 0;

        int *barrier = use_second_barrier ? b1_ : b0_;

        int step;
        int expected;
        if( counting_down ) {
            step = -1;
            expected = 0;
        } else {
            step = 1;
            expected = static_cast<int>(CTAS_PER_ROW);
        }

        // Only four phases exist (up/down for b0/b1), so wrap around.
        phase_counter_ = (phase_counter_ + 1) & 0x3;

        // A single thread per CTA talks to the global counter ...
        if( threadIdx.x == 0 ) {
            spin_wait_(barrier, step, expected);
        }
        // ... and the rest of the CTA waits for it here.
        __syncthreads();
    }

    int phase_counter_;
    int * b0_;
    int * b1_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Reducer for rows that are split over CTAS_PER_ROW CTAs. The CTA-local part is inherited
// from Reducer<T, 1, WARPS_M, WARPS_N>; the partial results of the CTAs are exchanged
// through `params.workspace` and an InterCTASync barrier.
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
    using Type = typename Base::Type;

    enum { SMEM_BYTES = Base::SMEM_BYTES };

    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };

    // w0_ / w1_ are the two workspace slices for this (bidm, warp_m) pair; they are
    // params.ctas_per_col * WARPS_M * CTAS_PER_ROW elements apart.
    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , inter_cta_(params, bidm, bidn)
        , w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
        , bidn_(bidn) // CTA id within the group.
    {
    }

    // Reduces `data` over all threads of all CTAs working on the row and returns the
    // result in every thread.
    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        // CTA-local reduction first.
        data = Base::reduce(data, op);

        // The workspace slice alternates between calls; it has to be picked before
        // sync() advances the phase counter.
        T *slots = (inter_cta_.phase_counter_ & 0x1) ? w1_ : w0_;

        // The leader thread (warp_n 0, lane 0) holds the CTA-local result and publishes it.
        const bool is_leader = (this->warp_n_ == 0) && (this->lane_ == 0);
        if( is_leader ) {
            slots[bidn_] = data;
        }
        inter_cta_.sync();

        // One lane per CTA picks up a partial result, so a single warp can finish the job.
        static_assert(CTAS_PER_ROW <= 32);
        T total = (this->lane_ < CTAS_PER_ROW) ? slots[this->lane_] : Zeros<T>::get();
        return Reducer<T, 1, 1, 1>::allreduce_(total, op);
    }

    InterCTASync inter_cta_;

    T *w0_;
    T *w1_;
    int bidn_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Warp-level reducer: one warp per row, so everything is done with warp shuffles and
// neither shared memory nor global workspace is needed.
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {

    using Type = T;
    enum { SMEM_BYTES = 0 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    // Only `warp_n` and `lane` are stored; the remaining arguments are ignored here.
    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : warp_n_(warp_n)
        , lane_(lane)
    {
    }

    // XOR-shuffle reduction over the offsets 1, 2, 4, 8, 16; every lane ends up with the
    // combined value.
    template<typename Op>
    static inline __device__ T allreduce_(T data, Op &op) {
        #pragma unroll
        for( int offset = 1; offset < THREADS_PER_WARP; offset <<= 1 ) {
            const T partner = warp_shuffle_xor(data, offset);
            data = op(data, partner);
        }
        return data;
    }

    // Member-function entry point; forwards to the static implementation.
    template<typename Op>
    inline __device__ T allreduce(T data, Op &op) {
        return allreduce_(data, op);
    }

    // Shuffle-down reduction over the offsets 16, 8, 4, 2, 1.
    template<typename Op>
    inline __device__ T reduce(T data, Op &op){
        // only lane 0 holds the result!
        #pragma unroll
        for( int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1 ) {
            const T upper = warp_shuffle_down(data, offset);
            data = op(data, upper);
        }
        return data;
    }

    int warp_n_;
    int lane_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// CTA-level reducer: WARPS_N warps cooperate on a row. Each warp reduces its own values
// with the warp-level base class, lane 0 of every warp stores the partial result in shared
// memory, and the partials are combined after a __syncthreads().
// Two shared-memory buffers are used alternately from call to call.
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {

    using Base = Reducer<T, 1, WARPS_M, 1>;

    using Type = T;

    // Two buffers with one slot per warp of the CTA.
    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
    enum { WORKSPACE_BYTES_PER_GROUP = 0 };

    enum { THREADS_PER_WARP = 32 };

    // smem0_ points at the WARPS_N slots of row `warp_m` in the first buffer, smem1_ at the
    // same row in the second buffer.
    template<typename Params>
    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , smem0_(&static_cast<T *>(smem)[warp_m * WARPS_N])
        , smem1_(smem0_ + WARPS_M * WARPS_N)
        , use0_(true)
    {
    }

    // Every thread combines all WARPS_N partials and returns the full result.
    template<typename Op>
    inline __device__ T allreduce(T data, Op & op) {
        T * partials = publish_warp_partial_(data, op);
        T out = Zeros<T>::get();
        #pragma unroll
        for( int w = 0; w < WARPS_N; w++ ) {
            out = op(out, partials[w]);
        }
        return out;
    }

    // Only the thread with warp_n 0 and lane 0 combines the partials; every other thread
    // returns Zeros<T>::get().
    template<typename Op>
    inline __device__ T reduce(T data, Op &op) {
        T * partials = publish_warp_partial_(data, op);
        T out = Zeros<T>::get();
        // only intra-CTA group leader holds the result!
        if( this->warp_n_ == 0 && this->lane_ == 0 ) {
            #pragma unroll
            for( int w = 0; w < WARPS_N; w++ ) {
                out = op(out, partials[w]);
            }
        }
        return out;
    }

    T * smem0_;
    T * smem1_;
    bool use0_;

private:
    // Common first half of allreduce()/reduce(): picks the buffer for this call (and flips
    // the selector for the next one), reduces `data` within the warp, lets lane 0 store the
    // warp's partial result and synchronizes the CTA. Returns the buffer that was written.
    template<typename Op>
    inline __device__ T * publish_warp_partial_(T data, Op &op) {
        T * buffer = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;
        data = Base::reduce(data, op);
        if( this->lane_ == 0 ) {
            buffer[this->warp_n_] = data;
        }
        __syncthreads();
        return buffer;
    }
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Merges per-lane running statistics (count n, mean m, sum of squared deviations m2) of
// the first `num_active` lanes of a warp with a shuffle-down tree, then broadcasts the
// merged mean and m2 held by lane 0 to all lanes.
//
//   m_a, m2_a : in/out, per-lane mean and m2; on return both hold lane 0's merged values.
//   n_a       : in/out, per-lane element count; it is updated by the merge but NOT
//               broadcast, so only lane 0 holds the total afterwards.
//   num_active: number of leading lanes that carry valid statistics.
//
// All 32 lanes of the warp must call this function (the shuffles use a full-warp mask).
template<typename T, typename int_t>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
    //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
    // With a single active lane there is nothing to merge. Skipping the computation also
    // avoids `1 << -1`, which the original expression produced for num_active == 1
    // because __clz(0) is 32.
    int first_step = 0;
    if( num_active > 1 ) {
        const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
        first_step = 1 << (highest_bit_set - 1);
    }

    #pragma unroll
    for( int step = first_step; step > 0; step /= 2 ) {
        // Exchange
        int_t n_b = warp_shuffle_down(n_a, step);
        T m_b = warp_shuffle_down(m_a, step);
        T m2_b = warp_shuffle_down(m2_a, step);

        // Update
        const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
        const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
        const T delta = m_a - m_b;
        const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
        const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

        n_a = n_ab;
        m_a = m_ab;
        m2_a = m2_ab;
    }
    // Intra-warp broadcast (only lane 0 has valid stats).
    m_a = __shfl_sync(uint32_t(-1), m_a, 0);
    m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Row statistics (mean and sum of squared deviations) for rows that are split over
// CTAS_PER_ROW CTAs. Every CTA computes statistics for its own share of the row with
// BlockStats, publishes them in `params.workspace`, waits on the inter-CTA barrier and then
// merges the CTAS_PER_ROW partial statistics within each warp.
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.

    using InterCTASync = InterCTASync<CTAS_PER_ROW>;
    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
    using stats_t = typename BlockStats::stats_t;

    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

    // w0_ / w1_ are the two workspace slices for this (bidm, warp_m) pair; they are
    // params.ctas_per_col * WARPS_M * CTAS_PER_ROW entries apart.
    // The initializers are listed in member declaration order, which is the order in which
    // they run.
    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : inter_cta_(params, bidm, bidn)
        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
        , bidn_(bidn) // CTA id within the group.
        , warp_n_(warp_n)
        , lane_(lane)
    {
    }

    // Returns { mean, m2 } of the row. `rn` is accepted for interface compatibility but is
    // not used: the per-CTA normalization factor is derived from N, WARPS_N and
    // THREADS_PER_WARP instead.
    // Every CTA is assumed to hold N * WARPS_N * THREADS_PER_WARP valid elements.
    template<uint32_t N>
    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
        // TODO rn is not really needed here..
        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);

        // BlockStats::compute takes the even-columns flag as an explicit template argument
        // plus a functor giving the number of valid elements per warp. All elements are
        // treated as valid here, so the flag is `true`; in that mode the functor only has to
        // be well-formed, its result does not select the element count.
        stats_t block_stats = block_stats_.template compute<true>(
            elts, block_rn, [](int) { return 0; });

        // The workspace slice alternates between calls; it has to be picked before sync()
        // advances the phase counter.
        stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

        if( warp_n_ == 0 && lane_ == 0 ) {
            workspace[bidn_] = block_stats;
        }

        // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
        inter_cta_.sync();

        T n = Zeros<T>::get();
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume CTA group size in N less than 32, such that we can finalize with a single warp.
        static_assert(CTAS_PER_ROW <= 32);

        // Every warp does the final reduction locally.
        if( lane_ < CTAS_PER_ROW ) {
            stats_t result = workspace[lane_];
            n = ELTS_PER_ROW_PER_CTA;
            m = layer_norm::Get<0>::of<stats_t, T>(result);
            m2 = layer_norm::Get<1>::of<stats_t, T>(result);
        }

        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

        return { m, m2 };
    }

    InterCTASync inter_cta_;
    BlockStats block_stats_;

    stats_t *w0_;
    stats_t *w1_;
    int bidn_;
    int warp_n_;
    int lane_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Row statistics when WARPS_N warps of one CTA share a row. Each warp computes its own
// statistics with WarpStats, lane 0 of every warp stores them in shared memory, and after a
// __syncthreads() every warp merges the WARPS_N partial statistics.
// Two shared-memory buffers are used alternately from call to call.
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {

    using WarpStats = Stats<T, 1, WARPS_M, 1>;
    using stats_t = typename WarpStats::stats_t;

    // Two buffers with one stats_t slot per warp of the CTA.
    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

    // smem0_ points at the WARPS_N slots of row `warp_m` in the first buffer, smem1_ at the
    // same row in the second buffer.
    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
        , smem0_(static_cast<stats_t*>(smem) + warp_m * WARPS_N)
        , smem1_(smem0_ + WARPS_M * WARPS_N)
        , use0_(true)
    {
    }

    // Returns { mean, m2 } over the row.
    //   Is_even_cols          : when true every warp is taken to hold N * THREADS_PER_WARP
    //                           elements and `valid_elts_in_warp_fn` is not called.
    //   row_norm_factor       : not used by this specialization; the per-warp factor is
    //                           recomputed below.
    //   valid_elts_in_warp_fn : called with a warp index to obtain that warp's element count
    //                           when Is_even_cols is false.
    //   num_valid_elts        : forwarded unchanged to the warp-level computation.
    template<bool Is_even_cols, uint32_t N, typename function_t>
    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
        // Pick the buffer for this call and flip the selector for the next one.
        stats_t * partials = use0_ ? smem0_ : smem1_;
        use0_ = !use0_;

        // Compute warp local for all WARPS_N
        const auto warp_n = warp_stats_.reducer_.warp_n_;
        const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
        stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
            elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts);

        // Each warp leader stores its stats
        const auto lane = warp_stats_.reducer_.lane_;
        if( lane == 0 ) {
            partials[warp_n] = warp_stats;
        }
        __syncthreads();

        int n = 0;
        T m = Zeros<T>::get();
        T m2 = Zeros<T>::get();

        // Assume that there are less than 32 warps, such that we can finalize with a single warp
        static_assert(WARPS_N <= 32);
        if( lane < WARPS_N ) {
            // Lane i picks up the statistics published by warp i.
            const stats_t published = partials[lane];
            n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
            m = layer_norm::Get<0>::of<stats_t, T>(published);
            m2 = layer_norm::Get<1>::of<stats_t, T>(published);
        }

        warp_chan_upd_dynamic(m, m2, n, WARPS_N);

        return { m, m2 };
    }

    WarpStats warp_stats_;
    stats_t * smem0_;
    stats_t * smem1_;
    bool use0_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Row statistics when a single warp owns a row: two passes over the per-thread elements,
// each followed by a warp all-reduce.
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {

    using stats_t = typename TypeToVec2<T>::Type;
    // The simple Warp reducer.
    using Reducer = Reducer<T, 1, WARPS_M, 1>;

    enum { SMEM_BYTES = 0 };

    template<typename Params>
    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
    {
    }

    // Returns { mean, m2 } where
    //   mean = (sum of the valid elements over the warp) * row_norm_factor
    //   m2   = sum over the warp of (element - mean)^2   (not normalized)
    // Element `it` of a thread is valid when Is_even_cols is true or it < num_valid_elts.
    // `valid_elts_in_warp_fn` is part of the common compute() signature and is not used here.
    template<bool Is_even_cols, uint32_t N, typename function_t>
    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
        Sum<T> add;

        // Pass 1: mean.
        T mean = Zeros<T>::get();
        #pragma unroll
        for( int i = 0; i < N; i++ ) {
            if( Is_even_cols || (i < num_valid_elts) ) {
                mean += elts[i];
            }
        }
        mean = reducer_.allreduce(mean, add) * row_norm_factor;

        // Pass 2: sum of squared deviations from the mean.
        T sq_dev = Zeros<T>::get();
        #pragma unroll
        for( int i = 0; i < N; i++ ) {
            if( Is_even_cols || (i < num_valid_elts) ) {
                const T centered = (elts[i] - mean);
                sq_dev += centered * centered;
            }
        }
        sq_dev = reducer_.allreduce(sq_dev, add);

        return {mean, sq_dev};
    }

    Reducer reducer_;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
|
@ -1,25 +0,0 @@
|
|||
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// Dispatches on a runtime boolean while exposing it to the callee as a compile-time constant.
///
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - name of the `constexpr bool` that is in scope where the callable
///                     is invoked (true in one branch, false in the other)
/// @param ...        - a callable, invoked with no arguments; its result is the value of
///                     the whole expression
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
    [&] { \
        if (COND) { \
            constexpr bool CONST_NAME = true; \
            return __VA_ARGS__(); \
        } \
        constexpr bool CONST_NAME = false; \
        return __VA_ARGS__(); \
    }()
|
Loading…
Reference in New Issue