From 16a4fa4846523575eb8fff08973aa1ffc80d8aab Mon Sep 17 00:00:00 2001
From: kswang
Date: Wed, 14 Sep 2022 16:13:42 +0800
Subject: [PATCH] mix precision for layernorm gradgrad

---
 .../cuda_class/layer_norm_grad_grad_helper.h | 12 ++---
 .../cuda_ops/layer_norm_grad_grad_impl.cu    | 54 +++++++++----------
 .../cuda_ops/layer_norm_grad_grad_impl.cuh   |  4 +-
 .../nn/layer_norm_grad_grad_gpu_kernel.cc    |  4 +-
 .../core/ops/grad/layer_norm_grad_grad.cc    | 11 ----
 5 files changed, 37 insertions(+), 48 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/layer_norm_grad_grad_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/layer_norm_grad_grad_helper.h
index edeab60bcb9..8ab97564103 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/layer_norm_grad_grad_helper.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/layer_norm_grad_grad_helper.h
@@ -84,8 +84,8 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
     input_size_ = input_row_ * input_col_ * sizeof(T);
     input_size_list_.push_back(input_size_);
     input_size_list_.push_back(input_size_);
-    input_size_list_.push_back(input_row_ * sizeof(T));
-    input_size_list_.push_back(input_row_ * sizeof(T));
+    input_size_list_.push_back(input_row_ * sizeof(float));
+    input_size_list_.push_back(input_row_ * sizeof(float));
     input_size_list_.push_back(param_dim_ * sizeof(T));
     input_size_list_.push_back(input_size_);
     input_size_list_.push_back(param_dim_ * sizeof(T));
@@ -110,8 +110,8 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
     // get device ptr input index output
     T *x = nullptr;
     T *dy = nullptr;
-    T *var = nullptr;
-    T *mean = nullptr;
+    float *var = nullptr;
+    float *mean = nullptr;
     T *gamma = nullptr;
     T *grad_dx = nullptr;
     T *grad_dg = nullptr;
@@ -130,11 +130,11 @@ class LayerNormGradGradHelperGpuKernel : public GpuKernelHelperBase {
     if (flag != 0) {
       return flag;
     }
-    flag = GetDeviceAddress<T>(input_ptrs, kIndex2, kernel_name_, &var);
+    flag = GetDeviceAddress<float>(input_ptrs, kIndex2, kernel_name_, &var);
     if (flag != 0) {
       return flag;
     }
-    flag = GetDeviceAddress<T>(input_ptrs, kIndex3, kernel_name_, &mean);
+    flag = GetDeviceAddress<float>(input_ptrs, kIndex3, kernel_name_, &mean);
     if (flag != 0) {
       return flag;
     }
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cu
index d0c8951b722..5160b20bf58 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cu
@@ -40,8 +40,8 @@ inline __device__ half my_pow(half a, double b) {
 template <typename T>
 inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
                                                 const int &mean_dim, const T &epsilon, const T *dy, const T *x,
-                                                const T *mean, const T *var, const T *grad_dx, T *part1, T *part2,
-                                                T *part3, const T *global_sum1, const T *global_sum2) {
+                                                const float *mean, const float *var, const T *grad_dx, T *part1,
+                                                T *part2, T *part3, const T *global_sum1, const T *global_sum2) {
   int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -53,8 +53,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
       int pos = row * col_dim + col;
       int mean_offset = pos / mean_dim;
 
-      T v1 = x[pos] - mean[mean_offset];
-      T v2 = my_pow(var[mean_offset] + epsilon, -0.5);
+      T v1 = x[pos] - static_cast<T>(mean[mean_offset]);
+      T v2 = my_pow(static_cast<T>(var[mean_offset]) + epsilon, -0.5);
       part1[0] += dy[pos] * v1 * v2 * global_sum2[pos];
       part2[0] += dy[pos] * global_sum1[pos];
 
@@ -103,7 +103,7 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
 
 template <typename T>
 __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
-                                       const T *dy, const T *x, const T *mean, const T *var, const T *grad_dx,
+                                       const T *dy, const T *x, const float *mean, const float *var, const T *grad_dx,
                                        T *d_gamma, T *global_sum1, T *global_sum2) {
   for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
     T part1 = 0;
@@ -119,7 +119,7 @@
 template <typename T>
 inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col_dim, const int &param_dim,
                                                   const T &epsilon, T *sum1, T *sum2, T *sum3, T *sum4, const T *dy,
-                                                  const T *x, const T *mean, const T *var, const T *gamma,
+                                                  const T *x, const float *mean, const float *var, const T *gamma,
                                                   const T *grad_dx) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
@@ -131,8 +131,8 @@ inline __device__ void InputThreadReduceInnerMean(const int &row, const int &col
       int pos = row * col_dim + col;
       int gamma_offset = pos % param_dim;
 
-      T v1 = x[pos] - mean[row];
-      T v2 = my_pow(var[row] + epsilon, -0.5);
+      T v1 = x[pos] - static_cast<T>(mean[row]);
+      T v2 = my_pow(static_cast<T>(var[row]) + epsilon, -0.5);
       T v3 = v1 * v2;
       T v4 = dy[pos] * gamma[gamma_offset];
 
@@ -183,8 +183,8 @@ inline __device__ void InputBlockReduceInnerMean(const int &col_dim, T *sum1, T
 template <typename T>
 inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
                                                   const T &epsilon, T *sum5, T *sum6, T *sum7, T *share_mem,
-                                                  const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
-                                                  const T *grad_dx, const T *grad_dg, T *d_x) {
+                                                  const T *dy, const T *x, const float *mean, const float *var,
+                                                  const T *gamma, const T *grad_dx, const T *grad_dg, T *d_x) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -220,8 +220,8 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
 template <>
 inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col_dim, const int &param_dim,
                                                   const half &epsilon, half *sum5, half *sum6, half *sum7,
-                                                  half *share_mem, const half *dy, const half *x, const half *mean,
-                                                  const half *var, const half *gamma, const half *grad_dx,
+                                                  half *share_mem, const half *dy, const half *x, const float *mean,
+                                                  const float *var, const half *gamma, const half *grad_dx,
                                                   const half *grad_dg, half *d_x) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
@@ -233,8 +233,8 @@ inline __device__ void InputThreadReduceOuterMean(const int &row, const int &col
       int pos = row * col_dim + col;
       int gamma_offset = pos % param_dim;
 
-      half v1 = x[pos] - mean[row];
-      half v2 = my_pow(var[row] + epsilon, -0.5);
+      half v1 = x[pos] - __float2half(mean[row]);
+      half v2 = my_pow(__float2half(var[row]) + epsilon, -0.5);
       half v3 = dy[pos] * gamma[gamma_offset];
       half v4 = v3 - share_mem[2] * __float2half(1.0 / col_dim) - v1 * v2 * share_mem[3] * __float2half(1.0 / col_dim);
       half v5 = v3 * share_mem[1] * __float2half(1.0 / col_dim);
@@ -290,9 +290,9 @@ inline __device__ void InputBlockReduceOuterMean(const int &col_dim, T *sum5, T
 
 template <typename T>
 inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
-                                 const T *dy, const T *x, const T *mean, const T *var, const T *gamma, const T *grad_dx,
-                                 const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, const T *share_mem,
-                                 T *global_sum1, T *global_sum2) {
+                                 const T *dy, const T *x, const float *mean, const float *var, const T *gamma,
+                                 const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
+                                 const T *share_mem, T *global_sum1, T *global_sum2) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
@@ -317,15 +317,15 @@
 
 template <>
 inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
-                                 const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
+                                 const half *dy, const half *x, const float *mean, const float *var, const half *gamma,
                                  const half *grad_dx, const half *grad_dg, const half *grad_db, half *d_dy, half *d_x,
                                  const half *share_mem, half *global_sum1, half *global_sum2) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
 
-    half v1 = x[pos] - mean[row];
-    half v2 = my_pow(var[row] + epsilon, -0.5);
+    half v1 = x[pos] - __float2half(mean[row]);
+    half v2 = my_pow(__float2half(var[row]) + epsilon, -0.5);
     half v3 = v1 * v2;
 
     half part1 = gamma[gamma_offset] * grad_dx[pos] * v2;
@@ -334,8 +334,8 @@ inline __device__ void InputProp(const int &row, const int &col_dim, const int &
     half part4 = v3 * grad_dg[gamma_offset];
     d_dy[pos] = part1 + part2 + part3 + part4 + grad_db[gamma_offset];
 
-    half part5 =
-      v1 * (my_pow(var[row] + epsilon, -1.5) * ((share_mem[4] + share_mem[5]) * __float2half(-1.0 / col_dim)));
+    half part5 = v1 * (my_pow(__float2half(var[row]) + epsilon, -1.5) *
+                       ((share_mem[4] + share_mem[5]) * __float2half(-1.0 / col_dim)));
     d_x[pos] += part5 + share_mem[6] * __float2half(1.0 / col_dim);
 
     global_sum1[pos] = share_mem[0] * __float2half(1.0 / col_dim);
@@ -345,7 +345,7 @@
 
 template <typename T>
 __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *dy,
-                                const T *x, const T *mean, const T *var, const T *gamma, const T *grad_dx,
+                                const T *x, const float *mean, const float *var, const T *gamma, const T *grad_dx,
                                 const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *global_sum1, T *global_sum2) {
   for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
     T sum1 = 0;
@@ -373,9 +373,9 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
 
 template <typename T>
 void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, T *global_sum1, T *global_sum2,
-                          const T &epsilon, const T *dy, const T *x, const T *mean, const T *var, const T *gamma,
-                          const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x, T *d_gamma,
-                          cudaStream_t stream) {
+                          const T &epsilon, const T *dy, const T *x, const float *mean, const float *var,
+                          const T *gamma, const T *grad_dx, const T *grad_dg, const T *grad_db, T *d_dy, T *d_x,
+                          T *d_gamma, cudaStream_t stream) {
   int share_mem_size = THREAD_PER_BLOCK / WARP_SIZE * NUM_SHARED_SUM_INPUT * sizeof(T);
   InputPropKernel<<<row_dim, THREAD_PER_BLOCK, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
                                                                          mean, var, gamma, grad_dx, grad_dg, grad_db,
@@ -394,7 +394,7 @@ template CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int
                                                     cudaStream_t stream);
 template CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim,
                                                     half *global_sum1, half *global_sum2, const half &epsilon,
-                                                    const half *dy, const half *x, const half *mean, const half *var,
+                                                    const half *dy, const half *x, const float *mean, const float *var,
                                                     const half *gamma, const half *grad_dx, const half *grad_dg,
                                                     const half *grad_db, half *d_dy, half *d_x, half *d_gamma,
                                                     cudaStream_t stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cuh
index 4114226f7ca..ff5ef1552e6 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/layer_norm_grad_grad_impl.cuh
@@ -20,8 +20,8 @@
 
 template <typename T>
 CUDA_LIB_EXPORT void CalLayerNormGradGrad(const int &row_dim, const int &col_dim, const int &param_dim, T *global_sum1,
-                                          T *global_sum2, const T &epsilon, const T *dy, const T *x, const T *mean,
-                                          const T *var, const T *gamma, const T *grad_dx, const T *grad_dg,
+                                          T *global_sum2, const T &epsilon, const T *dy, const T *x, const float *mean,
+                                          const float *var, const T *gamma, const T *grad_dx, const T *grad_dg,
                                           const T *grad_db, T *d_dy, T *d_x, T *d_gamma, cudaStream_t stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAYER_NORM_GRAD_GRAD_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/layer_norm_grad_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/layer_norm_grad_grad_gpu_kernel.cc
index c3669868c61..fef5c72e2ab 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/layer_norm_grad_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/layer_norm_grad_grad_gpu_kernel.cc
@@ -50,8 +50,8 @@ const std::vector> kernel
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
-      .AddInputAttr(kNumberTypeFloat16)
-      .AddInputAttr(kNumberTypeFloat16)
+      .AddInputAttr(kNumberTypeFloat32)
+      .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
diff --git a/mindspore/core/ops/grad/layer_norm_grad_grad.cc b/mindspore/core/ops/grad/layer_norm_grad_grad.cc
index beb356303fb..1191683a735 100644
--- a/mindspore/core/ops/grad/layer_norm_grad_grad.cc
+++ b/mindspore/core/ops/grad/layer_norm_grad_grad.cc
@@ -35,17 +35,6 @@ AbstractBasePtr LayerNormGradGradInfer(const abstract::AnalysisEnginePtr &, cons
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);  // x
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);  // dy
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex4]);  // gamma
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
-  std::map<std::string, TypePtr> types;
-  (void)types.emplace("x", input_args[kInputIndex0]->BuildType());
-  (void)types.emplace("dy", input_args[kInputIndex1]->BuildType());
-  (void)types.emplace("variance", input_args[kInputIndex2]->BuildType());
-  (void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
-  (void)types.emplace("gamma", input_args[kInputIndex4]->BuildType());
-  (void)types.emplace("d_dx", input_args[kInputIndex5]->BuildType());
-  (void)types.emplace("d_dg", input_args[kInputIndex6]->BuildType());
-  (void)types.emplace("d_db", input_args[kInputIndex7]->BuildType());
-  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
   auto d_dx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
   auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
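
Reviewer note (not part of the patch): the change makes the half-precision LayerNormGradGrad kernels read mean and variance as float and convert them only at the point of use (static_cast<T> / __float2half), while the activations, gamma and the gradient tensors stay in half; the kernel registration and workspace sizes are updated to match. A minimal CUDA sketch of that mixed-precision pattern follows; the kernel name NormalizeWithFp32StatsKernel, its shapes and its launch configuration are illustrative only and do not appear in the patch.

#include <cuda_fp16.h>

// Illustrative sketch (half arithmetic needs __CUDA_ARCH__ >= 530): the
// activation tensor is half, but the per-row statistics stay in float and
// are cast with __float2half exactly where they are consumed, mirroring the
// casts introduced by the patch.
__global__ void NormalizeWithFp32StatsKernel(const half *x, const float *mean, const float *var, half *y,
                                             int row_dim, int col_dim, float epsilon) {
  for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
    // Convert the float statistics for this row once; x and y remain half.
    half m = __float2half(mean[row]);
    half inv_std = __float2half(rsqrtf(var[row] + epsilon));
    for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
      int pos = row * col_dim + col;
      y[pos] = (x[pos] - m) * inv_std;
    }
  }
}

// Illustrative launch, one block per row:
//   NormalizeWithFp32StatsKernel<<<row_dim, 256, 0, stream>>>(x, mean, var, y, row_dim, col_dim, 1e-12f);

Keeping the statistics in float is the usual motivation for this kind of change: fp16 has too little precision and range to accumulate per-row mean and variance reliably, so only the per-element math runs in half.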