From 1eee3d69374978e97e19e3aba6f8d5bcf809e49a Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Thu, 30 Jul 2020 11:06:55 +0800
Subject: [PATCH] gpu layernorm

---
 .../gpu/cuda_impl/layer_norm_grad_impl.cu    | 90 +++++++++----------
 .../gpu/cuda_impl/layer_norm_impl.cu         | 14 +--
 tests/st/ops/gpu/test_layer_norm_grad_op.py  | 52 +++++++++++
 tests/st/ops/gpu/test_layer_norm_op.py       | 42 +++++++++
 4 files changed, 139 insertions(+), 59 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
index fcb74189520..8b3173e3a9d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -34,9 +34,9 @@ inline __device__ half my_pow(half a, double b) {
 }

 template <typename T>
-inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
-                                                const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
-                                                T* dg, T* db) {
+inline __device__ void GammaAndBetaThreadReduce(const int &col, const int &row_dim, const int &col_dim,
+                                                const T &epsilon, const T *dy, const T *x, const T *mean, const T *var,
+                                                T *dg, T *db) {
   int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -53,7 +53,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d
 }

 template <typename T>
-inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
+inline __device__ void GammaAndBetaWarpReduce(T *dg, T *db) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta);
     db[0] += __shfl_down_sync(0xffffffff, db[0], delta);
@@ -61,12 +61,8 @@ inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
 }

 template <typename T>
-inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr,
-                                               T* db_addr) {
-  if (threadIdx.x >= row_dim) {
-    return;
-  }
-
+inline __device__ void GammaAndBetaBlockReduce(const int &col, const int &row_dim, T *dg, T *db, T *dg_addr,
+                                               T *db_addr) {
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   DynamicSharedMem<T> share_mem;
@@ -93,8 +89,8 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di
 }

 template <typename T>
-__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x,
-                                       const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) {
+__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T *dy, const T *x,
+                                       const T *mean_addr, const T *var_addr, T *dg_addr, T *db_addr) {
   // row: [0:param_axis]
   // col: [param_axis:]
   // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
@@ -109,9 +105,9 @@ __global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, con
 }

 template <typename T>
-inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
-                                         T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean,
-                                         const T* var, const T* gamma) {
+inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
+                                         T *sum1, T *sum2, T *sum3, const T *dy, const T *x, const T *mean,
+                                         const T *var, const T *gamma) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -133,9 +129,9 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
 }

 template <>
-inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
-                                         half* sum1, half* sum2, half* sum3, const half* dy, const half* x,
-                                         const half* mean, const half* var, const half* gamma) {
+inline __device__ void InputThreadReduce(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
+                                         half *sum1, half *sum2, half *sum3, const half *dy, const half *x,
+                                         const half *mean, const half *var, const half *gamma) {
   int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
   for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
     for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
@@ -157,7 +153,7 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
 }

 template <typename T>
-inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
+inline __device__ void InputWarpReduce(T *sum1, T *sum2, T *sum3) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
     sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
     sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
@@ -166,11 +162,7 @@ inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
 }

 template <typename T>
-inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) {
-  if (threadIdx.x >= col_dim) {
-    return;
-  }
-
+inline __device__ void InputBlockReduce(const int &col_dim, T *sum1, T *sum2, T *sum3, T *share_mem) {
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   if (threadIdx.x % WARP_SIZE == 0) {
@@ -193,9 +185,9 @@ inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T*
 }

 template <typename T>
-inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
-                                 const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx,
-                                 const T* share_mem) {
+inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const T &epsilon,
+                                 const T *dy, const T *x, const T *mean, const T *var, const T *gamma, T *dx,
+                                 const T *share_mem) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
@@ -208,9 +200,9 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
 }

 template <>
-inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
-                                 const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
-                                 half* dx, const half* share_mem) {
+inline __device__ void InputProp(const int &row, const int &col_dim, const int &param_dim, const half &epsilon,
+                                 const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
+                                 half *dx, const half *share_mem) {
   for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
     int pos = (row * col_dim + col);
     int gamma_offset = pos % param_dim;
@@ -218,14 +210,14 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
     half v2 = x[pos] - mean[row];
     half v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
-              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\
-              * __float2half(1.0 / col_dim);
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
+                __float2half(1.0 / col_dim);
   }
 }

 template <typename T>
-__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy,
-                                const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
+__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *dy,
+                                const T *x, const T *mean, const T *var, const T *gamma, T *dx) {
   for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
     T sum1 = 0;
     T sum2 = 0;
@@ -239,21 +231,21 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
 }

 template <typename T>
-void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
-                   const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
-  int share_mem_size =
-    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
-                                                            gamma, dx);
+void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *dy,
+                   const T *x, const T *mean, const T *var, const T *gamma, T *dx, T *dg, T *db, cudaStream_t stream) {
+  const int thread_per_block = 256;
+  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
+  InputPropKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x,
+                                                                         mean, var, gamma, dx);

-  share_mem_size =
-    ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+  share_mem_size = thread_per_block / WARP_SIZE * 2 * sizeof(T);
+  GammaAndBetaPropKernel<<<col_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean,
+                                                                                var, dg, db);
 }

-template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
-                            const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
-                            float* dx, float* dg, float* db, cudaStream_t stream);
-template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon,
-                            const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
-                            half* dx, half* dg, half* db, cudaStream_t stream);
+template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
+                            const float *dy, const float *x, const float *mean, const float *var, const float *gamma,
+                            float *dx, float *dg, float *db, cudaStream_t stream);
+template void LayerNormGrad(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon,
+                            const half *dy, const half *x, const half *mean, const half *var, const half *gamma,
+                            half *dx, half *dg, half *db, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu
index 138300b3034..5797a3d711c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu
@@ -73,10 +73,6 @@ inline __device__ void WarpReduce(T *mean, T *var, T *num) {
 template <typename T>
 inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr,
                                    T *share_mem) {
-  if (threadIdx.x >= col_dim) {
-    return;
-  }
-
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
   if (threadIdx.x % WARP_SIZE == 0) {
@@ -146,13 +142,11 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
 template <typename T>
 void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
                const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
-  const dim3 block(row_dim);
-  const dim3 thread(256);
+  const int thread_per_block = 256;
   // keep the mean/var/num after warp reduce
-  int share_mem_size =
-    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y,
-                                                             mean, var);
+  int share_mem_size = thread_per_block / WARP_SIZE * 3 * sizeof(T);
+  LayerNormKernel<<<row_dim, thread_per_block, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma,
+                                                                         beta, y, mean, var);
 }

 template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
diff --git a/tests/st/ops/gpu/test_layer_norm_grad_op.py b/tests/st/ops/gpu/test_layer_norm_grad_op.py
index 032dee50ac9..81e4dcc868c 100644
--- a/tests/st/ops/gpu/test_layer_norm_grad_op.py
+++ b/tests/st/ops/gpu/test_layer_norm_grad_op.py
@@ -141,3 +141,55 @@ def test_layernormgrad2():
     assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
     assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad3():
+    begin_norm_axis = -1
+    begin_params_axis = -1
+    x_np = np.random.randn(32, 64).astype(np.float32)
+    dy_np = np.random.randn(32, 64).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
+                                                                  begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad4():
+    begin_norm_axis = -1
+    begin_params_axis = -1
+    x_np = np.random.randn(32, 64).astype(np.float32)
+    dy_np = np.random.randn(32, 64).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
+                                                                  begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
diff --git a/tests/st/ops/gpu/test_layer_norm_op.py b/tests/st/ops/gpu/test_layer_norm_op.py
index 776201735bd..8be3ca1ffa6 100644
--- a/tests/st/ops/gpu/test_layer_norm_op.py
+++ b/tests/st/ops/gpu/test_layer_norm_op.py
@@ -133,3 +133,45 @@ def test_layernorm3d_2():
     assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
     assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm2d_2():
+    begin_norm_axis = -1
+    begin_params_axis = 1
+    x_np = np.random.randn(64, 32).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm2d_3():
+    begin_norm_axis = -1
+    begin_params_axis = 1
+    x_np = np.random.randn(128, 128).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
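
Note: the new tests compare against LayerNormReference and LayerNormGradReference, helpers defined earlier in these test files and untouched by this patch. As a rough, hypothetical sketch of the kind of forward reference such a check relies on (assuming population variance over the axes from begin_norm_axis onward and gamma/beta broadcast over the axes from begin_params_axis onward; layer_norm_reference below is illustrative only, not the actual test helper):

import numpy as np

def layer_norm_reference(begin_norm_axis, begin_params_axis, x, gamma, beta, epsilon=1e-12):
    # Normalize over the trailing axes starting at begin_norm_axis.
    norm_axis = begin_norm_axis if begin_norm_axis >= 0 else x.ndim + begin_norm_axis
    axes = tuple(range(norm_axis, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # population variance (np.var default, ddof=0)
    # gamma/beta have shape x.shape[begin_params_axis:], so they broadcast from the right.
    y = (x - mean) / np.sqrt(var + epsilon) * gamma + beta
    return y, mean, var

With begin_norm_axis = -1 this normalizes each row independently, which matches the per-row mean/var the GPU kernels above produce for the 2-D test inputs.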