!41959 mix precision for layernormgradgrad
Merge pull request !41959 from kisnwang/support-mix-precision-gpu-layernorm
This commit is contained in:
commit
722b9c60e6
|
@ -84,8 +84,8 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
|
|||
input_size_ = input_row_ * input_col_ * sizeof(T);
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_row_ * sizeof(T));
|
||||
input_size_list_.push_back(input_row_ * sizeof(T));
|
||||
input_size_list_.push_back(input_row_ * sizeof(float));
|
||||
input_size_list_.push_back(input_row_ * sizeof(float));
|
||||
input_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(param_dim_ * sizeof(T));
|
||||
|
@ -110,8 +110,8 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
|
|||
// get device ptr input index output
|
||||
T *x = nullptr;
|
||||
T *dy = nullptr;
|
||||
T *var = nullptr;
|
||||
T *mean = nullptr;
|
||||
float *var = nullptr;
|
||||
float *mean = nullptr;
|
||||
T *gamma = nullptr;
|
||||
T *grad_dx = nullptr;
|
||||
T *grad_dg = nullptr;
|
||||
|
@ -130,11 +130,11 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
|
|||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
flag = GetDeviceAddress<T>(input_ptrs, kIndex2, kernel_name_, &var);
|
||||
flag = GetDeviceAddress<float>(input_ptrs, kIndex2, kernel_name_, &var);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
flag = GetDeviceAddress<T>(input_ptrs, kIndex3, kernel_name_, &mean);
|
||||
flag = GetDeviceAddress<float>(input_ptrs, kIndex3, kernel_name_, &mean);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
|
|
@ -40,8 +40,8 @@ inline __device__ half my_pow(half a, double b) {
|
|||
template <typename T>
|
||||
inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
|
||||
const int &mean_dim, const T &epsilon, const T *dy, const T *x,
|
||||
const T *mean, const T *var, const T *grad_dx, T *part1, T *part2,
|
||||
T *part3, const T *global_sum1, const T *global_sum2) {
|
||||
const float *mean, const float *var, const T *grad_dx, T *part1,
|
||||
T *part2, T *part3, const T *global_sum1, const T *global_sum2) {
|
||||
int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
|
@ -53,8 +53,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
|
|||
int pos = row * col_dim + col;
|
||||
int mean_offset = pos / mean_dim;
|
||||
|
||||
T v1 = x[pos] - mean[mean_offset];
|
||||
T v2 = my_pow(var[mean_offset] + epsilon, -0.5);
|
||||
T v1 = x[pos] - static_cast<T>(mean[mean_offset]);
|
||||
T v2 = my_pow(static_cast<T>(var[mean_offset]) + epsilon, -0.5);
|
||||
|
||||
part1[0] += dy[pos] * v1 * v2 * global_sum2[pos];
|
||||
part2[0] += dy[pos] * global_sum1[pos];
|
||||
|
@ -103,7 +103,7 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
|
|||
|
||||
template <typename T>
|
||||
__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
|
||||
const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx,
|
||||
const T *dy, const T *x, const float *mean, const float *var, const T *grad_dx,
|
||||
T *d_gamma, T *global_sum1, T *global_sum2) {
|
||||
for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
|
||||
T part1 = 0;
|
||||
|
@ -119,7 +119,7 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
|
|||
template <typename T>
|
||||
inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col_dim, const int ¶m_dim,
|
||||
const T &epsilon, T *sum1, T *sum2, T *sum3, T *sum4, const T *dy,
|
||||
const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T *x, const float *mean, const float *var, const T *gamma,
|
||||
const T *grad_dx) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
|
@ -131,8 +131,8 @@ inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col
|
|||
int pos = row * col_dim + col;
|
||||
int gamma_offset = pos % param_dim;
|
||||
|
||||
T v1 = x[pos] - mean[row];
|
||||
T v2 = my_pow(var[row] + epsilon, -0.5);
|
||||
T v1 = x[pos] - static_cast<T>(mean[row]);
|
||||
T v2 = my_pow(static_cast<T>(var[row]) + epsilon, -0.5);
|
||||
T v3 = v1 * v2;
|
||||
T v4 = dy[pos] * gamma[gamma_offset];
|
||||
|
||||
|
@ -183,8 +183,8 @@ inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T
|
|||
template <typename T>
|
||||
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim,
|
||||
const T &epsilon, T *sum5, T *sum6, T *sum7, T *share_mem,
|
||||
const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T *grad_dx, const T *grad_dg, T *d_x) {
|
||||
const T *dy, const T *x, const float *mean, const float *var,
|
||||
const T *gamma, const T *grad_dx, const T *grad_dg, T *d_x) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
|
||||
|
@ -220,8 +220,8 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
|
|||
template <>
|
||||
inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int ¶m_dim,
|
||||
const half &epsilon, half *sum5, half *sum6, half *sum7,
|
||||
half *share_mem, const half *dy, const half *x, const half *mean,
|
||||
const half *var, const half *gamma, const half *grad_dx,
|
||||
half *share_mem, const half *dy, const half *x, const float *mean,
|
||||
const float *var, const half *gamma, const half *grad_dx,
|
||||
const half *grad_dg, half *d_x) {
|
||||
int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
|
||||
for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
|
||||
|
@ -233,8 +233,8 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
|
|||
int pos = row * col_dim + col;
|
||||
int gamma_offset = pos % param_dim;
|
||||
|
||||
half v1 = x[pos] - mean[row];
|
||||
half v2 = my_pow(var[row] + epsilon, -0.5);
|
||||
half v1 = x[pos] - __float2half(mean[row]);
|
||||
half v2 = my_pow(__float2half(var[row]) + epsilon, -0.5);
|
||||
half v3 = dy[pos] * gamma[gamma_offset];
|
||||
half v4 = v3 - share_mem[2] * __float2half(1.0 / col_dim) - v1 * v2 * share_mem[3] * __float2half(1.0 / col_dim);
|
||||
half v5 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
|
||||
|
@ -290,9 +290,9 @@ inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T
|
|||
|
||||
template <typename T>
|
||||
inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const T &epsilon,
|
||||
const T *dy, const T *x, const T *mean, const T *var, const T *gamma, const T *grad_dx,
|
||||
const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, const T *share_mem,
|
||||
T *global_sum1, T *global_sum2) {
|
||||
const T *dy, const T *x, const float *mean, const float *var, const T *gamma,
|
||||
const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
|
||||
const T *share_mem, T *global_sum1, T *global_sum2) {
|
||||
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
|
||||
int pos = (row * col_dim + col);
|
||||
int gamma_offset = pos % param_dim;
|
||||
|
@ -317,15 +317,15 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
|
|||
|
||||
template <>
|
||||
inline __device__ void InputProp(const int &row, const int &col_dim, const int ¶m_dim, const half &epsilon,
|
||||
const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
|
||||
const half *dy, const half *x, const float *mean, const float *var, const half *gamma,
|
||||
const half *grad_dx, const half *grad_dg, const half *grad_db, half *d_dy, half *d_x,
|
||||
const half *share_mem, half *global_sum1, half *global_sum2) {
|
||||
for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
|
||||
int pos = (row * col_dim + col);
|
||||
int gamma_offset = pos % param_dim;
|
||||
|
||||
half v1 = x[pos] - mean[row];
|
||||
half v2 = my_pow(var[row] + epsilon, -0.5);
|
||||
half v1 = x[pos] - __float2half(mean[row]);
|
||||
half v2 = my_pow(__float2half(var[row]) + epsilon, -0.5);
|
||||
half v3 = v1 * v2;
|
||||
|
||||
half part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
|
||||
|
@ -334,8 +334,8 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
|
|||
half part4 = v3 * grad_dg[gamma_offset];
|
||||
d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];
|
||||
|
||||
half part5 =
|
||||
v1 * (my_pow(var[row] + epsilon, -1.5) * ((share_mem[4] + share_mem[5]) * __float2half(-1.0 / col_dim)));
|
||||
half part5 = v1 * (my_pow(__float2half(var[row]) + epsilon, -1.5) *
|
||||
((share_mem[4] + share_mem[5]) * __float2half(-1.0 / col_dim)));
|
||||
d_x[pos] += part5 + share_mem[6] * __float2half(1.0 / col_dim);
|
||||
|
||||
global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim);
|
||||
|
@ -345,7 +345,7 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
|
|||
|
||||
template <typename T>
|
||||
__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *dy,
|
||||
const T *x, const T *mean, const T *var, const T *gamma, const T *grad_dx,
|
||||
const T *x, const float *mean, const float *var, const T *gamma, const T *grad_dx,
|
||||
const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *global_sum1, T *global_sum2) {
|
||||
for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
|
||||
T sum1 = 0;
|
||||
|
@ -373,9 +373,9 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
|
|||
|
||||
template <typename T>
|
||||
void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, T *global_sum1, T *global_sum2,
|
||||
const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
|
||||
const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *d_gamma,
|
||||
cudaStream_t stream) {
|
||||
const T &epsilon, const T *dy, const T *x, const float *mean, const float *var,
|
||||
const T *gamma, const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
|
||||
T *d_gamma, cudaStream_t stream) {
|
||||
int share_mem_size = THREAD_PER_BLOCK / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
|
||||
InputPropKernel<<<row_dim, THREAD_PER_BLOCK, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
|
||||
mean, var, gamma, grad_dx, grad_dg, grad_db,
|
||||
|
@ -394,7 +394,7 @@ template CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int
|
|||
cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim,
|
||||
half *global_sum1, half *global_sum2, const half &epsilon,
|
||||
const half *dy, const half *x, const half *mean, const half *var,
|
||||
const half *dy, const half *x, const float *mean, const float *var,
|
||||
const half *gamma, const half *grad_dx, const half *grad_dg,
|
||||
const half *grad_db, half *d_dy, half *d_x, half *d_gamma,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int ¶m_dim, T *global_sum1,
|
||||
T *global_sum2, const T &epsilon, const T *dy, const T *x, const T *mean,
|
||||
const T *var, const T *gamma, const T *grad_dx, const T *grad_dg,
|
||||
T *global_sum2, const T &epsilon, const T *dy, const T *x, const float *mean,
|
||||
const float *var, const T *gamma, const T *grad_dx, const T *grad_dg,
|
||||
const T *grad_db, T *d_dy, T *d_x, T *d_gamma, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAYER_NORM_GRAD_GRAD_IMPL_CUH_
|
||||
|
|
|
@ -50,8 +50,8 @@ const std::vector<std::pair<KernelAttr, LayerNormGradGradPtrCreatorFunc>> kernel
|
|||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
|
|
@ -35,17 +35,6 @@ AbstractBasePtr LayerNormGradGradInfer(const abstract::AnalysisEnginePtr &, cons
|
|||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); // x
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]); // dy
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex4]); // gamma
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("dy", input_args[kInputIndex1]->BuildType());
|
||||
(void)types.emplace("variance", input_args[kInputIndex2]->BuildType());
|
||||
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
|
||||
(void)types.emplace("gamma", input_args[kInputIndex4]->BuildType());
|
||||
(void)types.emplace("d_dx", input_args[kInputIndex5]->BuildType());
|
||||
(void)types.emplace("d_dg", input_args[kInputIndex6]->BuildType());
|
||||
(void)types.emplace("d_db", input_args[kInputIndex7]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto d_dx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
||||
|
|
Loading…
Reference in New Issue