diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
index 8b3173e3a9d..35d200b92a6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -35,8 +35,8 @@ inline __device__ half my_pow(half a, double b) {
 
 template <typename T>
 inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
-                                                const T &epsilon, const T *dy, const T *x, const T *mean, const T *var,
-                                                T *dg, T *db) {
+                                                const int &mean_dim, const T &epsilon, const T *dy, const T *x,
+                                                const T *mean, const T *var, T *dg, T *db) {
   int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -46,7 +46,8 @@ inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_d
       }
 
       int pos = row * col_dim + col;
-      dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+      int mean_offset = pos / mean_dim;
+      dg[0] += dy[pos] * my_pow(var[mean_offset] + epsilon, -0.5) * (x[pos] - mean[mean_offset]);
       db[0] += dy[pos];
     }
   }
@@ -89,8 +90,9 @@ inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_di
 }
 
 template <typename T>
-__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x,
-                                       const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) {
+__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const int mean_dim, const T epsilon,
+                                       const T *dy, const T *x, const T *mean_addr, const T *var_addr, T *dg_addr,
+                                       T *db_addr) {
   // row: [0:param_axis]
   // col: [param_axis:]
   // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
@@ -98,7 +100,7 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
   for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
     T dg = 0;
     T db = 0;
-    GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
+    GammaAndBetaThreadReduce(col, row_dim, col_dim, mean_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
     GammaAndBetaWarpReduce(&dg, &db);
     GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr);
   }
@@ -239,8 +241,12 @@ void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim,
                                                                          mean, var, gamma, dx);
 
   share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
-                                                                                var, dg, db);
+  // GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x,
+  //                                                                               mean,
+  //                                                                               var, dg, db);
+  int param_reduce_dim = row_dim * col_dim / param_dim;
+  GammaAndBetaPropKernel<<<param_dim, thread_per_block, share_mem_size, stream>>>(param_reduce_dim, param_dim, col_dim,
+                                                                                  epsilon, dy, x, mean, var, dg, db);
 }
 
 template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
diff --git a/tests/st/ops/gpu/test_layer_norm_grad_op.py b/tests/st/ops/gpu/test_layer_norm_grad_op.py
index 81e4dcc868c..f7a91e7cdf7 100644
--- a/tests/st/ops/gpu/test_layer_norm_grad_op.py
+++ b/tests/st/ops/gpu/test_layer_norm_grad_op.py
@@ -193,3 +193,29 @@ def test_layernormgrad4():
     assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
     assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad5():
+    begin_norm_axis = 2
+    begin_params_axis = 1
+    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    dy_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
+                                                                  begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
diff --git a/tests/st/ops/gpu/test_layer_norm_op.py b/tests/st/ops/gpu/test_layer_norm_op.py
index 8be3ca1ffa6..040bc2c1bc6 100644
--- a/tests/st/ops/gpu/test_layer_norm_op.py
+++ b/tests/st/ops/gpu/test_layer_norm_op.py
@@ -175,3 +175,25 @@ def test_layernorm2d_3():
     assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm2d_4():
+    begin_norm_axis = 2
+    begin_params_axis = 1
+    np.random.seed(42)
+    x_np = np.random.randn(128, 2, 16, 32).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
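
Note on the fix, with a minimal NumPy sketch (not part of the patch): the patched kernel reduces dg/db over the leading [0:begin_params_axis) axes while looking up the normalization statistics per group, i.e. for a flat element index pos it reads var[pos / mean_dim] and mean[pos / mean_dim], and the launcher now passes param_reduce_dim = row_dim * col_dim / param_dim rows and param_dim columns to the kernel. The sketch below expresses the same reduction with NumPy broadcasting; the helper name gamma_beta_grad_ref and the shapes (taken from test_layernormgrad5, begin_norm_axis=2, begin_params_axis=1) are illustrative assumptions, not code from the repository.

import numpy as np

def gamma_beta_grad_ref(x, dy, mean, var, epsilon, begin_params_axis):
    # dg = sum over the leading axes of dy * (x - mean) * (var + eps)^-0.5
    # db = sum over the leading axes of dy; mean/var broadcast over the norm axes
    norm = (x - mean) / np.sqrt(var + epsilon)
    reduce_axes = tuple(range(begin_params_axis))
    dg = np.sum(dy * norm, axis=reduce_axes)   # same shape as gamma
    db = np.sum(dy, axis=reduce_axes)          # same shape as beta
    return dg, db

# Shapes mirroring test_layernormgrad5: normalize over axes (2, 3), params start at axis 1.
x = np.random.randn(128, 2, 16, 32).astype(np.float32)
dy = np.random.randn(128, 2, 16, 32).astype(np.float32)
mean = x.mean(axis=(2, 3), keepdims=True)      # shape (128, 2, 1, 1)
var = x.var(axis=(2, 3), keepdims=True)
dg, db = gamma_beta_grad_ref(x, dy, mean, var, 1e-12, begin_params_axis=1)
assert dg.shape == db.shape == (2, 16, 32)     # matches gamma/beta

Assuming the launcher flattens this case as row_dim = 256, col_dim = 512 and param_dim = 1024, the new call gives param_reduce_dim = 256 * 512 / 1024 = 128 rows to reduce per parameter element and mean_dim = 512, so pos / mean_dim selects one of the 256 per-(sample, group) mean/var entries, which is what the broadcasting above does implicitly.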