From 53b45295585196307af039fdd03f6bac9213c7fb Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Mon, 27 Apr 2020 20:06:47 +0800
Subject: [PATCH] Gpu support LayerNorm kernel

---
 .../gpu/cuda_impl/layer_norm_grad_impl.cu     | 205 ++++++++++++++++++
 .../gpu/cuda_impl/layer_norm_grad_impl.cuh    |  26 +++
 .../kernel/gpu/cuda_impl/layer_norm_impl.cu   | 148 +++++++++++++
 .../kernel/gpu/cuda_impl/layer_norm_impl.cuh  |  26 +++
 .../kernel/gpu/nn/layer_norm_gpu_kernel.cc    |  31 +++
 .../kernel/gpu/nn/layer_norm_gpu_kernel.h     | 103 +++++++++
 .../gpu/nn/layer_norm_grad_gpu_kernel.cc      |  33 +++
 .../gpu/nn/layer_norm_grad_gpu_kernel.h       | 107 +++++++++
 tests/st/ops/gpu/test_layer_norm_grad_op.py   | 140 ++++++++++++
 tests/st/ops/gpu/test_layer_norm_op.py        | 134 ++++++++++++
 10 files changed, 953 insertions(+)
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
 create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_layer_norm_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_layer_norm_op.py

diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
new file mode 100644
index 00000000000..f8377fd7219
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -0,0 +1,205 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
+
+constexpr int NUM_PER_THREAD_REDUCE = 4;
+constexpr int WARP_SIZE = 32;
+
+template <typename T>
+inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
+                                                const T& epsilon, const T* dy, const T* x, const T* mean,
+                                                const T* var, T* dg, T* db) {
+  int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int row = NUM_PER_THREAD_REDUCE * i + j;
+      if (row >= row_dim) {
+        return;
+      }
+
+      int pos = row * col_dim + col;
+      dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+      db[0] += dy[pos];
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) {
+  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
+    dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta);
+    db[0] += __shfl_down_sync(0xffffffff, db[0], delta);
+  }
+}
+
+template <typename T>
+inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr,
+                                               T* db_addr) {
+  if (threadIdx.x >= row_dim) {
+    return;
+  }
+
+  // load data to shared memory
+  // threads (0, 32, 64, 96, ...) keep the data
+  extern __shared__ T share_mem[];
+  if (threadIdx.x % WARP_SIZE == 0) {
+    int offset = threadIdx.x / WARP_SIZE * 2;
+    share_mem[offset] = dg[0];
+    share_mem[offset + 1] = db[0];
+  }
+  __syncthreads();
+
+  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
+    if (threadIdx.x < stride) {
+      int offset = (threadIdx.x + stride) * 2;
+      share_mem[threadIdx.x * 2] += share_mem[offset];
+      share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1];
+    }
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    dg_addr[col] = share_mem[0];
+    db_addr[col] = share_mem[1];
+  }
+}
+
+template <typename T>
+__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy,
+                                       const T* x, const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) {
+  // row: [0:param_axis]
+  // col: [param_axis:]
+  // dg[i][j] = dy[i][j] * pow(var[i] + epsilon, -0.5) * (x[i][j] - mean[i])
+  // dg[j] = \Sigma_{i}dg[i][j]
+  for (int col = blockIdx.x; col < col_dim; col += gridDim.x) {
+    T dg = 0;
+    T db = 0;
+    GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db);
+    GammaAndBetaWarpReduce(&dg, &db);
+    GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr);
+  }
+}
+
+template <typename T>
+inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
+                                         T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean,
+                                         const T* var, const T* gamma) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int col = NUM_PER_THREAD_REDUCE * i + j;
+      if (col >= col_dim) {
+        return;
+      }
+
+      int pos = row * col_dim + col;
+      int gamma_offset = pos % param_dim;
+      T v1 = dy[pos] * gamma[gamma_offset];
+      T v2 = x[pos] - mean[row];
+
+      sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5);
+      sum2[0] += v1;
+      sum3[0] += -2.0 * v2;
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
+  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
+    sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta);
+    sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta);
+    sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta);
+  }
+}
+
+template <typename T>
+inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) {
+  if (threadIdx.x >= col_dim) {
+    return;
+  }
+
+  // load data to shared memory
+  // threads (0, 32, 64, 96, ...) keep the data
+  if (threadIdx.x % WARP_SIZE == 0) {
+    int offset = threadIdx.x / WARP_SIZE * 3;
+    share_mem[offset] = sum1[0];
+    share_mem[offset + 1] = sum2[0];
+    share_mem[offset + 2] = sum3[0];
+  }
+  __syncthreads();
+
+  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
+    if (threadIdx.x < stride) {
+      int offset = (threadIdx.x + stride) * 3;
+      share_mem[threadIdx.x * 3] += share_mem[offset];
+      share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1];
+      share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2];
+    }
+  }
+  __syncthreads();
+}
+
+template <typename T>
+inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon,
+                                 const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx,
+                                 const T* share_mem) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = (row * col_dim + col);
+    int gamma_offset = pos % param_dim;
+    T v1 = dy[pos] * gamma[gamma_offset];
+    T v2 = x[pos] - mean[row];
+    T v3 = pow(var[row] + epsilon, -0.5);
+    dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 +
+              (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim);
+  }
+}
+
+template <typename T>
+__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
+                                const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
+  for (int row = blockIdx.x; row < row_dim; row += gridDim.x) {
+    T sum1 = 0;
+    T sum2 = 0;
+    T sum3 = 0;
+    extern __shared__ T share_mem[];
+    InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
+    InputWarpReduce(&sum1, &sum2, &sum3);
+    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem);
+    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem);
+  }
+}
+
+template <typename T>
+void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
+                   const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db,
+                   cudaStream_t stream) {
+  // shared memory keeps the per-warp partial sums left over after the warp reduce
+  int share_mem =
+    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
+  InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
+                                                       gamma, dx);
+
+  share_mem =
+    ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
+  GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+}
+
+template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
+                            const float* dy, const float* x, const float* mean, const float* var,
+                            const float* gamma, float* dx, float* dg, float* db, cudaStream_t stream);
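
The two kernels above split LayerNormGrad into an input-gradient pass (one block per row, reducing sum1/sum2/sum3 over the normalized columns) and a gamma/beta pass (one block per column, reducing over rows). A minimal NumPy sketch of the same math, assuming a 2-D input normalized over axis 1 (not part of the patch; it mirrors the LayerNormGradReference helper in the tests below):

import numpy as np

def layer_norm_grad_sketch(dy, x, gamma, epsilon=10e-12):
    # one "row" per sample; columns are the normalized dimension
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    num = x.shape[1]
    # the three block-reduced sums computed by InputPropKernel
    sum1 = np.sum(-0.5 * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=1, keepdims=True)
    sum2 = np.sum(dy * gamma, axis=1, keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=1, keepdims=True)
    dx = (dy * gamma * np.power(var + epsilon, -0.5)
          + sum1 * (2.0 / num) * (x - mean)
          + (-np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num))
    # GammaAndBetaPropKernel reduces over rows instead of columns
    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=0)
    db = np.sum(dy, axis=0)
    return dx, dg, db
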
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh
new file mode 100644
index 00000000000..9f7d57cdb98
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
+
+#include "device/gpu/cuda_common.h"
+
+template <typename T>
+void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
+                   const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db,
+                   cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
new file mode 100644
index 00000000000..db336737448
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
+
+constexpr int NUM_PER_THREAD_REDUCE = 4;
+constexpr int WARP_SIZE = 32;
+
+template <typename T>
+inline __device__ void MeanAndVarAccumulation(T* mean, T* var, T* num, const T& val) {
+  // Welford algorithm; `var` accumulates the running sum of squared deviations M2,
+  // which BlockReduce later divides by the count to obtain the variance:
+  // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1}) / k
+  // M_{2,k} = M_{2,k-1} + (x_k - \mu_{k-1}) * (x_k - \mu_k)
+  num[0]++;
+  T mean_new = mean[0] + (val - mean[0]) / num[0];
+  var[0] = var[0] + (val - mean[0]) * (val - mean_new);
+  mean[0] = mean_new;
+}
+
+template <typename T>
+inline __device__ void MeanAndVarMerge(T* m1, T* v1, T* n1, const T& m2, const T& v2, const T& n2) {
+  // parallel variance merge (Chan et al.) of two (mean, M2, count) partials
+  if (n2 == 0) {
+    return;
+  }
+
+  T count = n1[0] + n2;
+  v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count;
+  m1[0] = (n1[0] * m1[0] + n2 * m2) / count;
+  n1[0] = count;
+}
+
+template <typename T>
+inline __device__ void ThreadReduce(const int& col_dim, const T* block_addr, T* mean, T* var, T* num) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int pos = NUM_PER_THREAD_REDUCE * i + j;
+      if (pos >= col_dim) {
+        return;
+      }
+      MeanAndVarAccumulation(mean, var, num, block_addr[pos]);
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void WarpReduce(T* mean, T* var, T* num) {
+  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
+    T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta);
+    T var_other = __shfl_down_sync(0xffffffff, var[0], delta);
+    T num_other = __shfl_down_sync(0xffffffff, num[0], delta);
+    MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other);
+  }
+}
+
+template <typename T>
+inline __device__ void BlockReduce(const int& col_dim, T* mean, T* var, T* num, T* mean_addr, T* var_addr,
+                                   T* share_mem) {
+  if (threadIdx.x >= col_dim) {
+    return;
+  }
+
+  // load data to shared memory
+  // threads (0, 32, 64, 96, ...) keep the data
+  if (threadIdx.x % WARP_SIZE == 0) {
+    int offset = threadIdx.x / WARP_SIZE * 3;
+    share_mem[offset] = mean[0];
+    share_mem[offset + 1] = var[0];
+    share_mem[offset + 2] = num[0];
+  }
+  __syncthreads();
+
+  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
+    if (threadIdx.x < stride) {
+      int offset = (threadIdx.x + stride) * 3;
+      MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2],
+                      share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]);
+    }
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    mean_addr[blockIdx.x] = share_mem[0];  // todo: blockDim.x < row
+    share_mem[1] /= col_dim;               // turn the M2 sum into the variance
+    var_addr[blockIdx.x] = share_mem[1];
+  }
+}
+
+template <typename T>
+inline __device__ void LayerNorm(const int& row, const int& col_dim, const int& param_dim, const T* x,
+                                 const T* share_mem, const T* gamma, const T* beta, const T epsilon, T* y) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = row * col_dim + col;
+    int i = pos % param_dim;
+    y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
+  }
+}
+
+template <typename T>
+__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
+                                const T* x, const T* gamma, const T* beta, T* y, T* mean_addr, T* var_addr) {
+  for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) {
+    T mean = 0;
+    T var = 0;
+    T num = 0;
+    const T* block_addr = x + row * col_dim;
+    extern __shared__ T share_mem[];
+
+    ThreadReduce(col_dim, block_addr, &mean, &var, &num);
+    WarpReduce(&mean, &var, &num);
+    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem);
+
+    __syncthreads();
+    LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y);
+  }
+}
+
+template <typename T>
+void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* x,
+               const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream) {
+  const dim3 block(row_dim);
+  const dim3 thread(256);
+  // keep the mean/var/num of every warp after the warp reduce
+  int share_mem =
+    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
+  LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y,
+                                                        mean, var);
+}
+
+template void LayerNorm(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
+                        const float* x, const float* gamma, const float* beta, float* y, float* mean, float* var,
+                        cudaStream_t stream);
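
MeanAndVarAccumulation/MeanAndVarMerge above implement Welford's streaming mean/variance and its pairwise merge, with `var` holding the running sum of squared deviations (M2) until BlockReduce divides it by the count. A minimal Python sketch of the scheme, checked against NumPy (not part of the patch):

import numpy as np

def welford_accumulate(mean, m2, n, val):
    n += 1
    mean_new = mean + (val - mean) / n
    m2 += (val - mean) * (val - mean_new)
    return mean_new, m2, n

def welford_merge(m_a, v_a, n_a, m_b, v_b, n_b):
    # merge two (mean, M2, count) partials, as MeanAndVarMerge does
    if n_b == 0:
        return m_a, v_a, n_a
    n = n_a + n_b
    v = v_a + v_b + (m_a - m_b) ** 2 * n_a * n_b / n
    return (n_a * m_a + n_b * m_b) / n, v, n

x = np.random.randn(1024)
mean1 = m2_1 = mean2 = m2_2 = 0.0
n1 = n2 = 0
for v in x[:500]:
    mean1, m2_1, n1 = welford_accumulate(mean1, m2_1, n1, v)
for v in x[500:]:
    mean2, m2_2, n2 = welford_accumulate(mean2, m2_2, n2, v)
mean, m2, n = welford_merge(mean1, m2_1, n1, mean2, m2_2, n2)
assert np.isclose(mean, x.mean()) and np.isclose(m2 / n, x.var())
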
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
new file mode 100644
index 00000000000..4832b087467
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
+
+#include "device/gpu/cuda_common.h"
+
+template <typename T>
+void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x,
+               const T* gamma, const T* beta, T* y, T* mean, T* var, cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
new file mode 100644
index 00000000000..e67b745ab36
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/layer_norm_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LayerNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LayerNormGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
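
Both GPU kernel wrappers below flatten the input tensor around the two axis attributes: dimensions before begin_norm_axis become independent rows, everything from begin_norm_axis on is reduced per row, and everything from begin_params_axis on is the broadcast period of gamma/beta. A hedged Python sketch of that mapping (helper name is illustrative, not part of the patch):

from functools import reduce
from operator import mul

def collapse_axes(shape, begin_norm_axis, begin_params_axis):
    # negative axes wrap around, matching the `+= input_shape.size()` branches in Init()
    begin_norm_axis %= len(shape)
    begin_params_axis %= len(shape)
    input_row = reduce(mul, shape[:begin_norm_axis], 1)    # rows normalized independently
    input_col = reduce(mul, shape[begin_norm_axis:], 1)    # elements reduced per row
    param_dim = reduce(mul, shape[begin_params_axis:], 1)  # gamma/beta broadcast period
    return input_row, input_col, param_dim

assert collapse_axes((32, 128, 768), -1, -1) == (32 * 128, 768, 768)
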
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h
new file mode 100644
index 00000000000..e80cd091e5b
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class LayerNormGpuKernel : public GpuKernel {
+ public:
+  LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {}
+  ~LayerNormGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
+    auto x = GetDeviceAddress<T>(inputs, 0);
+    auto gamma = GetDeviceAddress<T>(inputs, 1);
+    auto beta = GetDeviceAddress<T>(inputs, 2);
+    auto y = GetDeviceAddress<T>(outputs, 0);
+    auto mean = GetDeviceAddress<T>(outputs, 1);
+    auto variance = GetDeviceAddress<T>(outputs, 2);
+
+    T epsilon = 10e-12;
+    LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance,
+              reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis");
+    int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis");
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (begin_norm_axis < 0) {
+      begin_norm_axis += input_shape.size();
+    }
+
+    if (begin_params_axis < 0) {
+      begin_params_axis += input_shape.size();
+    }
+
+    for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
+      input_row_ *= input_shape[i];
+    }
+
+    for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
+      input_col_ *= input_shape[i];
+    }
+
+    for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
+      param_dim_ *= input_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
+    input_size_list_.push_back(param_dim_ * sizeof(T));
+    input_size_list_.push_back(param_dim_ * sizeof(T));
+
+    output_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
+    output_size_list_.push_back(input_row_ * sizeof(T));
+    output_size_list_.push_back(input_row_ * sizeof(T));
+    return;
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int input_row_;
+  int input_col_;
+  int param_dim_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..e268161349e
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/layer_norm_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LayerNormGradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h
new file mode 100644
index 00000000000..84049206dbf
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class LayerNormGradGpuKernel : public GpuKernel {
+ public:
+  LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {}
+  ~LayerNormGradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
+    auto dy = GetDeviceAddress<T>(inputs, 0);
+    auto x = GetDeviceAddress<T>(inputs, 1);
+    auto var = GetDeviceAddress<T>(inputs, 2);
+    auto mean = GetDeviceAddress<T>(inputs, 3);
+    auto gamma = GetDeviceAddress<T>(inputs, 4);
+    auto dx = GetDeviceAddress<T>(outputs, 0);
+    auto dg = GetDeviceAddress<T>(outputs, 1);
+    auto db = GetDeviceAddress<T>(outputs, 2);
+
+    T epsilon = 10e-12;
+    LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db,
+                  reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis");
+    int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis");
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (begin_norm_axis < 0) {
+      begin_norm_axis += input_shape.size();
+    }
+
+    if (begin_params_axis < 0) {
+      begin_params_axis += input_shape.size();
+    }
+
+    for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
+      input_row_ *= input_shape[i];
+    }
+
+    for (size_t i = begin_norm_axis; i < input_shape.size(); i++) {
+      input_col_ *= input_shape[i];
+    }
+
+    for (size_t i = begin_params_axis; i < input_shape.size(); i++) {
+      param_dim_ *= input_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
+    input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
+    input_size_list_.push_back(input_row_ * sizeof(T));
+    input_size_list_.push_back(input_row_ * sizeof(T));
+    input_size_list_.push_back(param_dim_ * sizeof(T));
+
+    output_size_list_.push_back(input_row_ * input_col_ * sizeof(T));
+    output_size_list_.push_back(param_dim_ * sizeof(T));
+    output_size_list_.push_back(param_dim_ * sizeof(T));
+    return;
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int input_row_;
+  int input_col_;
+  int param_dim_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_layer_norm_grad_op.py b/tests/st/ops/gpu/test_layer_norm_grad_op.py
new file mode 100644
index 00000000000..0cef113d7cd
--- /dev/null
+++ b/tests/st/ops/gpu/test_layer_norm_grad_op.py
@@ -0,0 +1,140 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops import composite as C
+import mindspore.nn as nn
+import mindspore.context as context
+
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class LayerNormGradNet(nn.Cell):
+    def __init__(self, begin_norm_axis, begin_params_axis):
+        super(LayerNormGradNet, self).__init__()
+        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
+
+    def construct(self, dy, x, var, mean, gamma):
+        return self.norm(dy, x, var, mean, gamma)
+
+
+def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
+    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
+    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
+
+    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
+    param_axis = [i for i in range(0, begin_params_axis)]
+    num = 1
+    for i in range(begin_norm_axis, len(x.shape)):
+        num *= x.shape[i]
+
+    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
+    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
+
+    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
+    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
+    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)
+
+    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5),
+                  axis=tuple(norm_axis), keepdims=True)
+    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
+    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)
+
+    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
+    dx2 = sum1 * 2.0 / num * (x - mean)
+    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
+    dx = dx1 + dx2 + dx3
+    return dx, dg, db, mean, var
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad0():
+    begin_norm_axis = 1
+    begin_params_axis = 1
+    x_np = np.random.randn(4096, 3072).astype(np.float32)
+    dy_np = np.random.randn(4096, 3072).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon,
+                                                                  begin_norm_axis, begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
+
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
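+# A hedged, optional sanity check (not part of the original test suite): verify the
+# analytic dx from LayerNormGradReference against a central finite difference of
+# d/dx sum(dy * layernorm(x)) on a tiny float64 input. The helper name is illustrative.
+def layernormgrad_finite_difference_check(epsilon=10e-12, delta=1e-3):
+    x = np.random.randn(3, 4)
+    dy = np.random.randn(3, 4)
+    gamma = np.random.randn(4)
+
+    def loss(x_in):
+        mean = np.mean(x_in, axis=1, keepdims=True)
+        var = np.var(x_in, axis=1, keepdims=True)
+        return np.sum(dy * ((x_in - mean) / np.sqrt(var + epsilon) * gamma))
+
+    dx, _, _, _, _ = LayerNormGradReference(x, dy, gamma, epsilon, 1, 1)
+    dx_num = np.zeros_like(x)
+    for idx in np.ndindex(*x.shape):
+        x_pos, x_neg = x.copy(), x.copy()
+        x_pos[idx] += delta
+        x_neg[idx] -= delta
+        dx_num[idx] = (loss(x_pos) - loss(x_neg)) / (2 * delta)
+    assert np.allclose(dx, dx_num, rtol=1e-4, atol=1e-4)
+
+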
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad1():
+    begin_norm_axis = 1
+    begin_params_axis = 1
+    x_np = np.random.randn(640, 768).astype(np.float32)
+    dy_np = np.random.randn(640, 768).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon,
+                                                                  begin_norm_axis, begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
+
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernormgrad2():
+    begin_norm_axis = -1
+    begin_params_axis = -1
+    x_np = np.random.randn(32, 128, 768).astype(np.float32)
+    dy_np = np.random.randn(32, 128, 768).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    epsilon = 10e-12
+    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon,
+                                                                  begin_norm_axis, begin_params_axis)
+
+    dy_ms = Tensor(dy_np)
+    x_ms = Tensor(x_np)
+    var_ms = Tensor(var_np)
+    mean_ms = Tensor(mean_np)
+    gamma_ms = Tensor(gamma_np)
+
+    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
+    dx_ms, dg_ms, db_ms = net(dy_ms, x_ms, var_ms, mean_ms, gamma_ms)
+
+    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
+    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
\ No newline at end of file
diff --git a/tests/st/ops/gpu/test_layer_norm_op.py b/tests/st/ops/gpu/test_layer_norm_op.py
new file mode 100644
index 00000000000..a281cd0f5ff
--- /dev/null
+++ b/tests/st/ops/gpu/test_layer_norm_op.py
@@ -0,0 +1,134 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+import mindspore.context as context
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class LayerNormNet(nn.Cell):
+    def __init__(self, begin_norm_axis, begin_params_axis):
+        super(LayerNormNet, self).__init__()
+        self.norm = P.LayerNorm(begin_norm_axis, begin_params_axis)
+
+    def construct(self, x, gamma, beta):
+        return self.norm(x, gamma, beta)
+
+
+def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta):
+    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
+    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
+
+    axis = [i for i in range(begin_norm_axis, len(x.shape))]
+    mean = np.mean(x, axis=tuple(axis), keepdims=True)
+    var = np.var(x, axis=tuple(axis), keepdims=True)
+
+    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
+    beta = beta.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
+    y = np.subtract(x, mean) / np.sqrt(var + 1e-12) * gamma + beta
+    return y, mean, var
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm0():
+    begin_norm_axis = 1
+    begin_params_axis = 1
+    x_np = np.random.randn(4096, 3072).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+
+    assert np.allclose(y_ms.asnumpy(), y_np, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm1():
+    begin_norm_axis = 1
+    begin_params_axis = 1
+    x_np = np.random.randn(640, 768).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm3d_1():
+    begin_norm_axis = -1
+    begin_params_axis = -1
+    x_np = np.random.randn(32, 128, 768).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_layernorm3d_2():
+    begin_norm_axis = -1
+    begin_params_axis = 1
+    x_np = np.random.randn(32, 128, 768).astype(np.float32)
+    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    beta_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
+    y_np, mean_np, var_np = LayerNormReference(begin_norm_axis, begin_params_axis, x_np, gamma_np, beta_np)
+
+    x_ms = Tensor(x_np)
+    gamma_ms = Tensor(gamma_np)
+    beta_ms = Tensor(beta_np)
+    net = LayerNormNet(begin_norm_axis, begin_params_axis)
+    y_ms, mean_ms, var_ms = net(x_ms, gamma_ms, beta_ms)
+
+    assert np.allclose(y_ms.asnumpy(), y_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(mean_ms.asnumpy(), mean_np, rtol=1e-6, atol=1e-6)
+    assert np.allclose(var_ms.asnumpy(), var_np, rtol=1e-6, atol=1e-6)