!40515 optimize gpu layernorm
Merge pull request !40515 from kisnwang/master
This commit is contained in: commit c56e92b96e
@@ -107,6 +107,90 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
  }
}

constexpr int kTileSize = 8;
template <typename T>
struct alignas(sizeof(T) * kTileSize) TArray {
  T data[kTileSize];
};

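The alignas(sizeof(T) * kTileSize) wrapper is what lets the tiled kernels in this patch turn eight scalar global loads into one vectorized load per tile. A minimal standalone sketch of the pattern, not part of the patch (the kernel, struct, and sizes below are hypothetical, and it assumes the buffers are tile-aligned, which cudaMalloc guarantees):

#include <cuda_fp16.h>
#include <cuda_runtime.h>

constexpr int kSketchTile = 8;

template <typename T>
struct alignas(sizeof(T) * kSketchTile) SketchTile {
  T data[kSketchTile];
};

// Copies n elements (n divisible by kSketchTile) using one vectorized
// load/store per tile; with T = half each transfer is a single 128-bit transaction.
template <typename T>
__global__ void TiledCopy(const T *in, T *out, int n) {
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * kSketchTile;
  if (i + kSketchTile <= n) {
    SketchTile<T> tile = *reinterpret_cast<const SketchTile<T> *>(&in[i]);  // vectorized load
    *reinterpret_cast<SketchTile<T> *>(&out[i]) = tile;                     // vectorized store
  }
}

int main() {
  const int n = 1 << 20;
  half *in, *out;
  cudaMalloc(&in, n * sizeof(half));   // cudaMalloc returns allocations aligned to at least 256 bytes,
  cudaMalloc(&out, n * sizeof(half));  // so every tile offset stays 16-byte aligned
  const int threads = 256;
  const int blocks = n / kSketchTile / threads;
  TiledCopy<<<blocks, threads>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}
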
template <typename T>
inline __device__ void TiledGammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
                                                     const int &mean_dim, const T &epsilon, const T *dy, const T *x,
                                                     const T *mean, const T *var, T *dg, T *db) {
  for (int i = 0; i < kTileSize; ++i) {
    dg[i] = 0;
    db[i] = 0;
  }
  for (int i = threadIdx.x; i < row_dim; i += blockDim.x) {
    T dy_tile[kTileSize];
    T x_tile[kTileSize];
    TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
    *dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[i * col_dim + col]);
    TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
    *x_tmp = *reinterpret_cast<const TArray<T> *>(&x[i * col_dim + col]);
    T var_rsqrt = my_pow(var[i] + epsilon, -0.5);
    for (int j = 0; j < kTileSize; ++j) {
      dg[j] += dy_tile[j] * var_rsqrt * (x_tile[j] - mean[i]);
      db[j] += dy_tile[j];
    }
  }
}

template <typename T>
inline __device__ void TiledGammaAndBetaWarpReduce(T *dg, T *db) {
  for (int i = 0; i < kTileSize; ++i) {
    for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
      dg[i] += __shfl_down_sync(0xffffffff, dg[i], delta);
      db[i] += __shfl_down_sync(0xffffffff, db[i], delta);
    }
  }
}

template <typename T>
inline __device__ void TiledGammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
                                                    T *db_addr) {
  DynamicSharedMem<T> share_mem;
  if (threadIdx.x % WARP_SIZE == 0) {
    int offset = threadIdx.x / WARP_SIZE * 2 * kTileSize;
    for (int i = 0; i < kTileSize; ++i) {
      share_mem.addr()[offset + i * 2] = dg[i];
      share_mem.addr()[offset + i * 2 + 1] = db[i];
    }
  }
  __syncthreads();

  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      int offset = (threadIdx.x + stride) * 2 * kTileSize;
      for (int i = 0; i < kTileSize; ++i) {
        share_mem.addr()[threadIdx.x * 2 * kTileSize + 2 * i] += share_mem.addr()[offset + 2 * i];
        share_mem.addr()[threadIdx.x * 2 * kTileSize + 2 * i + 1] += share_mem.addr()[offset + 2 * i + 1];
      }
    }
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    for (int i = 0; i < kTileSize; ++i) {
      dg_addr[col + i] = share_mem.addr()[2 * i];
      db_addr[col + i] = share_mem.addr()[2 * i + 1];
    }
  }
}

template <typename T>
__global__ void TiledGammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
                                            const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr,
                                            T *db_addr) {
  for (int col = blockIdx.x * kTileSize; col < col_dim; col += gridDim.x * kTileSize) {
    T dg[kTileSize];
    T db[kTileSize];
    TiledGammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, dg, db);
    TiledGammaAndBetaWarpReduce(dg, db);
    TiledGammaAndBetaBlockReduce(col, row_dim, dg, db, dg_addr, db_addr);
  }
}
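Restating what the three reduction stages above accumulate for each output column j (this is the standard LayerNorm parameter gradient, with i running over the reduced rows and \mu_i, \sigma_i^2 the saved per-row statistics):

\[ d\gamma_j = \sum_i \frac{dy_{ij}\,(x_{ij} - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad d\beta_j = \sum_i dy_{ij}. \]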

template <typename T>
inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
                                         T *sum1, T *sum2, T *sum3, const T *dy, const T *x, const T *mean,
@@ -163,6 +247,31 @@ inline __device__ void InputThreadReduce(const int &row, const int &col_dim, con
  sum3[0] = __float2half(-2.0) * sum3[0];
}

template <typename T>
inline __device__ void TiledInputThreadReduce(const int &row, const int &col_dim, const int &param_dim,
                                              const T &epsilon, T *sum1, T *sum2, T *sum3, const T *dy, const T *x,
                                              const T *mean, const T *var, const T *gamma) {
  for (int i = threadIdx.x * kTileSize; i < col_dim; i += blockDim.x * kTileSize) {
    int pos = row * col_dim + i;
    T dy_tile[kTileSize];
    T x_tile[kTileSize];
    TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
    *dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[pos]);
    TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
    *x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);

    for (int j = 0; j < kTileSize; ++j) {
      T v1 = dy_tile[j] * gamma[i + j];
      T v2 = x_tile[j] - mean[row];
      sum1[0] += v1 * v2;
      sum2[0] += v1;
      sum3[0] += v2;
    }
  }
  sum1[0] = (T)(-0.5) * sum1[0] * my_pow(var[row] + epsilon, -1.5);
  sum3[0] = (T)(-2.0) * sum3[0];
}
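For row i, the partial sums carried out of this thread reduce (and its non-tiled counterpart InputThreadReduce) are, after the final scaling:

\[ s_1 = -\tfrac{1}{2}\,(\sigma_i^2+\epsilon)^{-3/2} \sum_j dy_{ij}\,\gamma_j\,(x_{ij}-\mu_i), \qquad s_2 = \sum_j dy_{ij}\,\gamma_j, \qquad s_3 = -2 \sum_j (x_{ij}-\mu_i), \]

i.e. s_1 is the gradient with respect to the row variance, while s_2 and s_3 feed the mean gradient in the propagation step further down.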

template <typename T>
inline __device__ void InputWarpReduce(T *sum1, T *sum2, T *sum3) {
  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {

@@ -243,21 +352,71 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
  }
}

template <typename T>
inline __device__ void TiledInputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
                                      const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx,
                                      const T *share_mem) {
  T col_inv = (T)(1.0 / col_dim);
  T v3 = my_pow(var[row] + epsilon, -0.5);
  T v4 = share_mem[0] * col_inv * (T)(2.0);
  T v5 = (col_inv * share_mem[0] * share_mem[2] - v3 * share_mem[1]) * col_inv;
  for (int col = threadIdx.x * kTileSize; col < col_dim; col += blockDim.x * kTileSize) {
    int pos = row * col_dim + col;
    T dy_tile[kTileSize];
    T x_tile[kTileSize];
    T dx_tile[kTileSize];
    TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
    *dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[pos]);
    TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
    *x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);

    for (int j = 0; j < kTileSize; ++j) {
      T v1 = dy_tile[j] * gamma[col + j];
      T v2 = x_tile[j] - mean[row];
      dx_tile[j] = v1 * v3 + v4 * v2 + v5;
    }
    TArray<T> *dx_tmp = reinterpret_cast<TArray<T> *>(&dx[pos]);
    *dx_tmp = *reinterpret_cast<TArray<T> *>(dx_tile);
  }
}
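With N = col_dim and s_1, s_2, s_3 the block-reduced sums read from share_mem[0..2], the per-element expression v1 * v3 + v4 * v2 + v5 above works out to:

\[ dx_{ij} = \frac{dy_{ij}\,\gamma_j}{\sqrt{\sigma_i^2+\epsilon}} + \frac{2 s_1}{N}\,(x_{ij}-\mu_i) + \frac{1}{N}\left(\frac{s_1 s_3}{N} - \frac{s_2}{\sqrt{\sigma_i^2+\epsilon}}\right). \]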

template <typename T>
__global__ void TiledInputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                     const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx) {
  for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
    T sum1 = 0;
    T sum2 = 0;
    T sum3 = 0;
    DynamicSharedMem<T> share_mem;
    TiledInputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
    InputWarpReduce(&sum1, &sum2, &sum3);
    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr());
    TiledInputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr());
  }
}

template <typename T>
void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
                   const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
  const int thread_per_block = 256;
  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
                                                                         mean, var, gamma, dx);

  share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
  // GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x,
  //                                                                               mean,
  //                                                                               var, dg, db);
  int param_reduce_dim = row_dim * col_dim / param_dim;
  GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim,
                                                                                  epsilon, dy, x, mean, var, dg, db);
  int grid_size = param_dim;
  if (col_dim == param_dim && grid_size % kTileSize == 0 && col_dim % kTileSize == 0) {
    TiledInputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon,
                                                                                dy, x, mean, var, gamma, dx);
    share_mem_size = thread_per_block / WARP_SIZE * 2 * kTileSize * sizeof(T);
    grid_size /= kTileSize;
    TiledGammaAndBetaPropKernel<<<grid_size, thread_per_block, share_mem_size, stream>>>(
      param_reduce_dim, param_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
  } else {
    InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
                                                                           mean, var, gamma, dx);
    share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
    GammaAndBetaPropKernel<<<grid_size, thread_per_block, share_mem_size, stream>>>(
      param_reduce_dim, param_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
  }
}

template CUDA_LIB_EXPORT void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim,
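A note on the dynamic shared-memory sizes used in the launcher above: each holds one set of per-warp partial results, so with thread_per_block = 256 and WARP_SIZE = 32 there are 8 warps per block and the expressions in the code amount to

\[ \mathrm{smem}_{\mathrm{input}} = 8 \cdot 3 \cdot \mathrm{sizeof}(T), \qquad \mathrm{smem}_{\gamma\beta} = 8 \cdot 2 \cdot \mathrm{sizeof}(T), \qquad \mathrm{smem}_{\mathrm{tiled}\,\gamma\beta} = 8 \cdot 2 \cdot k\mathrm{TileSize} \cdot \mathrm{sizeof}(T), \]

i.e. three sums (sum1, sum2, sum3) per warp on the input-gradient path, and a (dg, db) pair per warp on the parameter-gradient path, expanded to one pair per tile element in the tiled case.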

@@ -21,6 +21,21 @@

constexpr int NUM_PER_THREAD_REDUCE = 4;
constexpr int WARP_SIZE = 32;
constexpr int kTileSize = 8;
template <typename T>
struct alignas(sizeof(T) * kTileSize) TArray {
  T data[kTileSize];
};

template <typename T>
inline __device__ T general_sqrt(T val) {
  return (T)sqrt(static_cast<float>(val));
}

template <>
inline __device__ half general_sqrt(half val) {
  return hsqrt(val);
}

template <typename T>
inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) {

@@ -55,11 +70,27 @@ inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *
      if (pos >= col_dim) {
        return;
      }

      MeanAndVarAccumulation(mean, var, num, block_addr[pos]);
    }
  }
}

template <typename T>
inline __device__ void TiledThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) {
  for (int i = threadIdx.x * kTileSize; i < col_dim; i += blockDim.x * kTileSize) {
    T block_tile[kTileSize];
    TArray<T> *tmp = reinterpret_cast<TArray<T> *>(&block_tile);
    *tmp = *reinterpret_cast<const TArray<T> *>(&block_addr[i]);
    for (int j = 0; j < kTileSize; ++j) {
      num[0]++;
      T mean_new = mean[0] + (block_tile[j] - mean[0]) / num[0];
      var[0] = var[0] + (block_tile[j] - mean[0]) * (block_tile[j] - mean_new);
      mean[0] = mean_new;
    }
  }
}
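The inner loop above is Welford's online update, the same recurrence MeanAndVarAccumulation applies one element at a time; var here is the running sum of squared deviations, from which the per-row variance used later is derived in the reduction steps:

\[ n \leftarrow n + 1, \qquad \mu' = \mu + \frac{x - \mu}{n}, \qquad M_2 \leftarrow M_2 + (x - \mu)(x - \mu'), \qquad \mu \leftarrow \mu'. \]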

template <typename T>
inline __device__ void WarpReduce(T *mean, T *var, T *num) {
  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {

@@ -120,6 +151,27 @@ inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &
  }
}

template <typename T>
inline __device__ void TiledLayerNorm(const int &row, const int &col_dim, const int &param_dim, const T *x,
                                      const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) {
  for (int col = threadIdx.x * kTileSize; col < col_dim; col += blockDim.x * kTileSize) {
    int pos = row * col_dim + col;
    T y_tile[kTileSize];
    T x_tile[kTileSize];

    TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
    *x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);

    for (int j = 0; j < kTileSize; ++j) {
      int i = col + j;
      y_tile[j] = (x_tile[j] - share_mem[0]) / general_sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
    }

    TArray<T> *y_tmp = reinterpret_cast<TArray<T> *>(&y[pos]);
    *y_tmp = *reinterpret_cast<TArray<T> *>(&y_tile);
  }
}
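With share_mem[0] and share_mem[1] holding the block-reduced mean \mu_i and variance \sigma_i^2 of row i, the element-wise step above is the usual LayerNorm transform:

\[ y_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}\,\gamma_j + \beta_j. \]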

template <typename T>
__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
                                const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {

@@ -139,14 +191,38 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
  }
}

template <typename T>
__global__ void TiledLayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                     const T *x, const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
  for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
    T mean = 0;
    T var = 0;
    T num = 0;
    const T *block_addr = x + row * col_dim;
    DynamicSharedMem<T> share_mem;

    TiledThreadReduce(col_dim, block_addr, &mean, &var, &num);
    WarpReduce(&mean, &var, &num);
    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr());

    __syncthreads();
    TiledLayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y);
  }
}

template <typename T>
void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
  const int thread_per_block = 256;
  // keep the mean/var/num after warp reduce
  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
  LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma,
                                                                         beta, y, mean, var);
  if (col_dim == param_dim && row_dim % kTileSize == 0 && col_dim % kTileSize == 0) {
    TiledLayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x,
                                                                                gamma, beta, y, mean, var);
  } else {
    LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x,
                                                                           gamma, beta, y, mean, var);
  }
}

template CUDA_LIB_EXPORT void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim,
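A minimal host-side sketch, not part of the patch, of driving the forward launcher above. The shapes are hypothetical and chosen so the tiled path applies (param_dim == col_dim, both row_dim and col_dim divisible by kTileSize), and it assumes a float instantiation of LayerNorm is linked in (the explicit-instantiation line is truncated in this view). Error handling omitted.

#include <cuda_runtime.h>

template <typename T>
void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream);

int main() {
  const int row_dim = 64, col_dim = 768, param_dim = 768;  // hypothetical shapes
  const float epsilon = 1e-7f;                             // hypothetical epsilon
  float *x, *gamma, *beta, *y, *mean, *var;
  cudaMalloc(&x, sizeof(float) * row_dim * col_dim);
  cudaMalloc(&y, sizeof(float) * row_dim * col_dim);
  cudaMalloc(&gamma, sizeof(float) * param_dim);
  cudaMalloc(&beta, sizeof(float) * param_dim);
  cudaMalloc(&mean, sizeof(float) * row_dim);
  cudaMalloc(&var, sizeof(float) * row_dim);
  // ... copy real x/gamma/beta to the device before calling ...
  LayerNorm<float>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean, var, nullptr);  // default stream
  cudaDeviceSynchronize();
  cudaFree(x); cudaFree(y); cudaFree(gamma); cudaFree(beta); cudaFree(mean); cudaFree(var);
  return 0;
}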