!40515 optimize gpu layernorm

Merge pull request !40515 from kisnwang/master
commit c56e92b96e, authored by i-robot, 2022-08-23 06:34:32 +00:00, committed by Gitee
2 changed files with 245 additions and 10 deletions
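
The optimization adds tiled variants of the LayerNorm forward and backward kernels. Each thread now processes kTileSize (8) contiguous elements per loop iteration, reading and writing them through an over-aligned TArray<T> wrapper so the compiler can emit wide vector memory transactions instead of eight scalar ones; the host launchers fall back to the original kernels whenever the shapes are not tile-aligned. A minimal standalone sketch of the load/store idiom used throughout the diff (the TileScale kernel is illustrative only, not part of the patch):

#include <cuda_runtime.h>

constexpr int kTileSize = 8;

// Over-aligned 8-element block: assigning one TArray to another compiles down
// to a small number of wide vector loads/stores (e.g. 2 x 128-bit for float).
template <typename T>
struct alignas(sizeof(T) * kTileSize) TArray {
  T data[kTileSize];
};

// Illustrative kernel: scale a buffer tile-by-tile, the same way the tiled
// LayerNorm kernels read dy/x and write y/dx. Requires n % kTileSize == 0 and
// tile-aligned pointers, mirroring the dispatch checks in the patch.
template <typename T>
__global__ void TileScale(const T *in, T *out, int n, T factor) {
  for (int i = (blockIdx.x * blockDim.x + threadIdx.x) * kTileSize; i < n;
       i += gridDim.x * blockDim.x * kTileSize) {
    T tile[kTileSize];
    // One vectorized load instead of kTileSize scalar loads.
    *reinterpret_cast<TArray<T> *>(tile) = *reinterpret_cast<const TArray<T> *>(&in[i]);
    for (int j = 0; j < kTileSize; ++j) {
      tile[j] *= factor;
    }
    // One vectorized store back.
    *reinterpret_cast<TArray<T> *>(&out[i]) = *reinterpret_cast<TArray<T> *>(tile);
  }
}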

View File

@@ -107,6 +107,90 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
}
}
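// Tile width for the vectorized backward kernels below: each thread handles
// kTileSize contiguous elements per iteration, moved through the over-aligned
// TArray wrapper so a whole tile becomes a wide vector load/store.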
constexpr int kTileSize = 8;
template <typename T>
struct alignas(sizeof(T) * kTileSize) TArray {
T data[kTileSize];
};
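// Per-thread partial reduction for dgamma/dbeta: stride over the rows, load
// kTileSize-wide dy/x tiles, and accumulate one dg/db value per tiled column.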
template <typename T>
inline __device__ void TiledGammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
const int &mean_dim, const T &epsilon, const T *dy, const T *x,
const T *mean, const T *var, T *dg, T *db) {
for (int i = 0; i < kTileSize; ++i) {
dg[i] = 0;
db[i] = 0;
}
for (int i = threadIdx.x; i < row_dim; i += blockDim.x) {
T dy_tile[kTileSize];
T x_tile[kTileSize];
TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
*dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[i * col_dim + col]);
TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
*x_tmp = *reinterpret_cast<const TArray<T> *>(&x[i * col_dim + col]);
T var_rsqrt = my_pow(var[i] + epsilon, -0.5);
for (int j = 0; j < kTileSize; ++j) {
dg[j] += dy_tile[j] * var_rsqrt * (x_tile[j] - mean[i]);
db[j] += dy_tile[j];
}
}
}
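// Reduce each of the kTileSize dg/db accumulators across the warp with
// __shfl_down_sync, so lane 0 ends up holding the warp-wide sums.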
template <typename T>
inline __device__ void TiledGammaAndBetaWarpReduce(T *dg, T *db) {
for (int i = 0; i < kTileSize; ++i) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
dg[i] += __shfl_down_sync(0xffffffff, dg[i], delta);
db[i] += __shfl_down_sync(0xffffffff, db[i], delta);
}
}
}
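// Combine the per-warp results in shared memory (dg/db interleaved per tile
// element); thread 0 writes the final sums for this group of kTileSize columns.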
template <typename T>
inline __device__ void TiledGammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
T *db_addr) {
DynamicSharedMem<T> share_mem;
if (threadIdx.x % WARP_SIZE == 0) {
int offset = threadIdx.x / WARP_SIZE * 2 * kTileSize;
for (int i = 0; i < kTileSize; ++i) {
share_mem.addr()[offset + i * 2] = dg[i];
share_mem.addr()[offset + i * 2 + 1] = db[i];
}
}
__syncthreads();
for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
int offset = (threadIdx.x + stride) * 2 * kTileSize;
for (int i = 0; i < kTileSize; ++i) {
share_mem.addr()[threadIdx.x * 2 * kTileSize + 2 * i] += share_mem.addr()[offset + 2 * i];
share_mem.addr()[threadIdx.x * 2 * kTileSize + 2 * i + 1] += share_mem.addr()[offset + 2 * i + 1];
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kTileSize; ++i) {
dg_addr[col + i] = share_mem.addr()[2 * i];
db_addr[col + i] = share_mem.addr()[2 * i + 1];
}
}
}
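// Grid-stride over groups of kTileSize columns: thread, warp, and block
// reduction of dgamma/dbeta for each group.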
template <typename T>
__global__ void TiledGammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr,
T *db_addr) {
for (int col = blockIdx.x * kTileSize; col < col_dim; col += gridDim.x * kTileSize) {
T dg[kTileSize];
T db[kTileSize];
TiledGammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, dg, db);
TiledGammaAndBetaWarpReduce(dg, db);
TiledGammaAndBetaBlockReduce(col, row_dim, dg, db, dg_addr, db_addr);
}
}
template <typename T>
inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
T *sum1, T *sum2, T *sum3, const T *dy, const T *x, const T *mean,
@@ -163,6 +247,31 @@ inline __device__ void InputThreadReduce(const int &row, const int &col_dim, con
sum3[0] = __float2half(-2.0) * sum3[0];
}
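// Tiled variant of InputThreadReduce: accumulates the three per-row sums used
// by the dx formula, reading dy and x through vectorized tile loads.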
template <typename T>
inline __device__ void TiledInputThreadReduce(const int &row, const int &col_dim, const int &param_dim,
const T &epsilon, T *sum1, T *sum2, T *sum3, const T *dy, const T *x,
const T *mean, const T *var, const T *gamma) {
for (int i = threadIdx.x * kTileSize; i < col_dim; i += blockDim.x * kTileSize) {
int pos = row * col_dim + i;
T dy_tile[kTileSize];
T x_tile[kTileSize];
TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
*dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[pos]);
TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
*x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);
for (int j = 0; j < kTileSize; ++j) {
T v1 = dy_tile[j] * gamma[i + j];
T v2 = x_tile[j] - mean[row];
sum1[0] += v1 * v2;
sum2[0] += v1;
sum3[0] += v2;
}
}
sum1[0] = (T)(-0.5) * sum1[0] * my_pow(var[row] + epsilon, -1.5);
sum3[0] = (T)(-2.0) * sum3[0];
}
template <typename T>
inline __device__ void InputWarpReduce(T *sum1, T *sum2, T *sum3) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
@@ -243,21 +352,71 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
}
}
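// Tiled variant of InputProp: rebuilds dx from the block-reduced sums held in
// shared memory and writes it back with a vectorized tile store.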
template <typename T>
inline __device__ void TiledInputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx,
const T *share_mem) {
T col_inv = (T)(1.0 / col_dim);
T v3 = my_pow(var[row] + epsilon, -0.5);
T v4 = share_mem[0] * col_inv * (T)(2.0);
T v5 = (col_inv * share_mem[0] * share_mem[2] - v3 * share_mem[1]) * col_inv;
for (int col = threadIdx.x * kTileSize; col < col_dim; col += blockDim.x * kTileSize) {
int pos = row * col_dim + col;
T dy_tile[kTileSize];
T x_tile[kTileSize];
T dx_tile[kTileSize];
TArray<T> *dy_tmp = reinterpret_cast<TArray<T> *>(&dy_tile);
*dy_tmp = *reinterpret_cast<const TArray<T> *>(&dy[pos]);
TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
*x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);
for (int j = 0; j < kTileSize; ++j) {
T v1 = dy_tile[j] * gamma[col + j];
T v2 = x_tile[j] - mean[row];
dx_tile[j] = v1 * v3 + v4 * v2 + v5;
}
TArray<T> *dx_tmp = reinterpret_cast<TArray<T> *>(&dx[pos]);
*dx_tmp = *reinterpret_cast<TArray<T> *>(dx_tile);
}
}
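// One block per row: the tiled thread reduction feeds the existing warp/block
// reductions, then the tiled dx pass consumes the shared-memory results.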
template <typename T>
__global__ void TiledInputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx) {
for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
T sum1 = 0;
T sum2 = 0;
T sum3 = 0;
DynamicSharedMem<T> share_mem;
TiledInputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
InputWarpReduce(&sum1, &sum2, &sum3);
InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr());
TiledInputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr());
}
}
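// Host launcher: the tiled kernels are only used when the whole row is
// normalized (col_dim == param_dim) and the sizes are multiples of kTileSize;
// otherwise fall back to the original kernels.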
template <typename T>
void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
const int thread_per_block = 256;
int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
mean, var, gamma, dx);
share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
// GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x,
// mean,
// var, dg, db);
int param_reduce_dim = row_dim * col_dim / param_dim;
GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim,
epsilon, dy, x, mean, var, dg, db);
int grid_size = param_dim;
if (col_dim == param_dim && grid_size % kTileSize == 0 && col_dim % kTileSize == 0) {
TiledInputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon,
dy, x, mean, var, gamma, dx);
share_mem_size = thread_per_block / WARP_SIZE * 2 * kTileSize * sizeof(T);
grid_size /= kTileSize;
TiledGammaAndBetaPropKernel<<<grid_size, thread_per_block, share_mem_size, stream>>>(
param_reduce_dim, param_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
} else {
InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
mean, var, gamma, dx);
share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
GammaAndBetaPropKernel<<<grid_size, thread_per_block, share_mem_size, stream>>>(
param_reduce_dim, param_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
}
}
template CUDA_LIB_EXPORT void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim,

View File

@@ -21,6 +21,21 @@
constexpr int NUM_PER_THREAD_REDUCE = 4;
constexpr int WARP_SIZE = 32;
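// Tile width for the vectorized forward kernel; the over-aligned TArray
// wrapper lets a whole tile be moved with wide vector loads/stores.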
constexpr int kTileSize = 8;
template <typename T>
struct alignas(sizeof(T) * kTileSize) TArray {
T data[kTileSize];
};
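// sqrt helper: generic types go through float sqrt, half uses hsqrt so the
// computation stays in fp16.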
template <typename T>
inline __device__ T general_sqrt(T val) {
return (T)sqrt(static_cast<float>(val));
}
template <>
inline __device__ half general_sqrt(half val) {
return hsqrt(val);
}
template <typename T>
inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) {
@@ -55,11 +70,27 @@ inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *
if (pos >= col_dim) {
return;
}
MeanAndVarAccumulation(mean, var, num, block_addr[pos]);
}
}
}
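// Tiled mean/variance accumulation over kTileSize-wide vectorized loads;
// num/mean/var follow the same online (Welford) update as
// MeanAndVarAccumulation in the scalar path.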
template <typename T>
inline __device__ void TiledThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) {
for (int i = threadIdx.x * kTileSize; i < col_dim; i += blockDim.x * kTileSize) {
T block_tile[kTileSize];
TArray<T> *tmp = reinterpret_cast<TArray<T> *>(&block_tile);
*tmp = *reinterpret_cast<const TArray<T> *>(&block_addr[i]);
for (int j = 0; j < kTileSize; ++j) {
num[0]++;
T mean_new = mean[0] + (block_tile[j] - mean[0]) / num[0];
var[0] = var[0] + (block_tile[j] - mean[0]) * (block_tile[j] - mean_new);
mean[0] = mean_new;
}
}
}
template <typename T>
inline __device__ void WarpReduce(T *mean, T *var, T *num) {
for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
@@ -120,6 +151,27 @@ inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &
}
}
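// Tiled normalization: share_mem[0]/share_mem[1] hold the block-reduced mean
// and variance; the normalized tile is written back with a vectorized store.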
template <typename T>
inline __device__ void TiledLayerNorm(const int &row, const int &col_dim, const int &param_dim, const T *x,
const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) {
for (int col = threadIdx.x * kTileSize; col < col_dim; col += blockDim.x * kTileSize) {
int pos = row * col_dim + col;
T y_tile[kTileSize];
T x_tile[kTileSize];
TArray<T> *x_tmp = reinterpret_cast<TArray<T> *>(x_tile);
*x_tmp = *reinterpret_cast<const TArray<T> *>(&x[pos]);
for (int j = 0; j < kTileSize; ++j) {
int i = col + j;
y_tile[j] = (x_tile[j] - share_mem[0]) / general_sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
}
TArray<T> *y_tmp = reinterpret_cast<TArray<T> *>(&y[pos]);
*y_tmp = *reinterpret_cast<TArray<T> *>(&y_tile);
}
}
template <typename T>
__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
@@ -139,14 +191,38 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
}
}
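// One block per row: tiled thread reduction, then the existing warp/block
// reductions, then the tiled normalize-and-store pass.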
template <typename T>
__global__ void TiledLayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
const T *x, const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
T mean = 0;
T var = 0;
T num = 0;
const T *block_addr = x + row * col_dim;
DynamicSharedMem<T> share_mem;
TiledThreadReduce(col_dim, block_addr, &mean, &var, &num);
WarpReduce(&mean, &var, &num);
BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr());
__syncthreads();
TiledLayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y);
}
}
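// Host launcher: use the tiled kernel only when the whole row is normalized
// (col_dim == param_dim) and the shapes are multiples of kTileSize; otherwise
// keep the original kernel.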
template <typename T>
void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
const int thread_per_block = 256;
// keep the mean/var/num after warp reduce
int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma,
beta, y, mean, var);
if (col_dim == param_dim && row_dim % kTileSize == 0 && col_dim % kTileSize == 0) {
TiledLayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x,
gamma, beta, y, mean, var);
} else {
LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x,
gamma, beta, y, mean, var);
}
}
template CUDA_LIB_EXPORT void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim,