From 8f4cd76582ca2c6a33bdf6f431ddf94eddf6c983 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Thu, 18 Jun 2020 19:33:27 +0800 Subject: [PATCH] gpu Gelu kernel support fp16 --- .../ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu | 101 +++++++++++++++--- .../ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc | 7 ++ mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc | 2 + tests/st/ops/gpu/test_gelu_grad_op.py | 32 +++++- tests/st/ops/gpu/test_gelu_op.py | 13 +++ 5 files changed, 139 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu index bb476179d54..e460caec9e4 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu @@ -14,32 +14,62 @@ * limitations under the License. */ - #include "kernel/gpu/cuda_impl/gelu_impl.cuh" #include "device/gpu/cuda_common.h" -template -__global__ void GeluKernel(size_t size, T* input_addr, T* output_addr) { +template +__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) { // formula: // gelu(x) = 0.5 * x * (1.0 + tanh(y)) // tanh(y) = 2 / (1 + exp(-2y)) - 1) // y = sqrt(2/pi) * (x + 0.044715 * x^3) - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { float x = input_addr[pos]; float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); output_addr[pos] = 0.5 * x * (1.0 + tanh_res); } } -template -void Gelu(size_t size, T* input_addr, T* output_addr, cudaStream_t cuda_stream) { +template <> +__global__ void GeluKernel(size_t size, half *input_addr, half *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half x = input_addr[pos]; + float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x))); + output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res)); + } +} + +template <> +__global__ void GeluKernel(size_t size, half2 *input_addr, half2 *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half2 x = input_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res)); + } +} + +template +void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { GeluKernel<<>>(size, input_addr, output_addr); return; } +template <> +void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluKernel<<>>( + size / 2, reinterpret_cast(input_addr), reinterpret_cast(output_addr)); + } else { + GeluKernel<<>>(size, input_addr, output_addr); + } + return; +} -template -__global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) { +template +__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) { // formula: // dx = dy * y' // y' = 0.5 * (1 + tanh(tanh_para)) + @@ -48,18 +78,59 @@ __global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) { // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)) for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { T x = x_addr[pos]; - T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); - T mul_right = 0.7978845608 + 0.1070322244 * x * x; - T y_res = 0.5 * (1 + tanh_res) + 0.5 * x * (1 - tanh_res * tanh_res) * mul_right; + T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); + T mul_right = 0.7978845608 + 0.1070322244 * x * x; + T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right; dx_addr[pos] = dy_addr[pos] * y_res; } } -template -void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream) { +template +__global__ void GeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, half2 *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half2 x = x_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + half2 tanh_res_half = __float22half2_rn(tanh_res); + half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x; + half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) + + half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +__global__ void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half x = x_addr[pos]; + half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x); + half tanh_res = __float2half_rn(tanh(__half2float(tanh_param))); + half mul_right = half(0.7978845608) + half(0.1070322244) * x * x; + half y_res = half(0.5) * (half(1.0) + tanh_res) + half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); } +template <> +void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluGradKernel<<>>( + size / 2, reinterpret_cast(dy_addr), reinterpret_cast(x_addr), + reinterpret_cast(dx_addr)); + } else { + GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); + } + return; +} -template void Gelu(size_t size, float* input_addr, float* output_addr, cudaStream_t cuda_stream); -template void GeluGradKernel(size_t size, float* dy_addr, float* x_addr, float* dx_addr, cudaStream_t cuda_stream); +template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); +template void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc index 2b6c53aa28c..32d91be80a6 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc @@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_ONE(GeluGrad, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), GeLUGpuGradKernel, float) +MS_REG_GPU_KERNEL_ONE(GeluGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + GeLUGpuGradKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc index 604dee04c4d..ca54ff68ad8 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc @@ -20,5 +20,7 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), GeluGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + GeluGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/gpu/test_gelu_grad_op.py b/tests/st/ops/gpu/test_gelu_grad_op.py index 24137c241dc..82145b9d3ff 100644 --- a/tests/st/ops/gpu/test_gelu_grad_op.py +++ b/tests/st/ops/gpu/test_gelu_grad_op.py @@ -58,7 +58,37 @@ def test_gelugrad(): grad = Grad(net) output = grad(x_ms, dy_ms) - print(output) expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032, 0.0352667, 0.34266686, 0.57757664, 0.04707306, 0.51536125] assert np.allclose(output[0].asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelugrad_fp16(): + np.random.seed(42) + x_np = np.random.randn(5, 3, 6).astype(np.float16) + dy_np = np.random.randn(5, 3, 6).astype(np.float16) + net = GeluNet() + grad = Grad(net) + output = grad(Tensor(x_np), Tensor(dy_np)) + expect = [[[8.4045e-02, 3.7817e-01, -6.6748e-01, -3.6914e-01, -1.2415e-01, -4.6362e-01], + [3.3301e-01, 2.6270e-01, 7.7534e-04, -2.0947e-01, -2.2021e-01, -6.4880e-02], + [-2.3633e-01, 7.6538e-02, 1.8280e-02, 3.8635e-02, -1.6235e-01, 1.2964e-01]], + + [[-1.4801e-02, 9.6130e-03, -2.1660e+00, -8.5602e-03, 3.3356e-02, -3.1885e-01], + [-2.0355e-02, 1.7737e-01, 3.8719e-03, -9.1895e-01, 8.4717e-02, 2.0593e-01], + [5.8350e-02, -1.0020e+00, 6.8652e-01, 1.3428e-01, 6.0352e-01, -2.6270e-01]], + + [[-6.5820e-01, 5.1147e-02, -1.2650e-02, -3.2983e-01, -1.5410e+00, 4.3518e-02], + [-4.3359e-01, 1.2659e-01, 1.1792e-01, 2.2705e-02, -1.2329e-01, -3.5278e-01], + [6.2109e-01, 1.3611e-01, 1.7041e-01, 2.7124e-01, -5.5908e-02, 1.7212e-01]], + + [[2.8320e-01, 8.3252e-01, 4.2480e-02, -3.4473e-01, 3.9429e-01, 3.1958e-01], + [3.6499e-02, 1.2250e-01, 7.1350e-02, -2.7267e-02, 3.0029e-01, -8.0566e-01], + [8.2617e-01, 5.1367e-01, -9.2480e-01, 3.3203e-02, -7.5684e-01, 8.8623e-01]], + + [[5.4590e-01, -9.2383e-01, -2.8107e-02, 4.2432e-01, 4.6826e-01, 5.0879e-01], + [-1.4062e-01, 6.6284e-02, -2.9126e-01, -6.3086e-01, -8.6975e-02, 4.1504e-02], + [-6.3171e-03, 1.0852e-01, 1.3779e-02, 1.0947e+00, -3.0334e-02, 2.3828e+00]]] + assert np.allclose(output[0].asnumpy(), expect, rtol=1e-2) diff --git a/tests/st/ops/gpu/test_gelu_op.py b/tests/st/ops/gpu/test_gelu_op.py index d56f3e662d0..ec8e0041db8 100644 --- a/tests/st/ops/gpu/test_gelu_op.py +++ b/tests/st/ops/gpu/test_gelu_op.py @@ -91,3 +91,16 @@ def test_gelu_neg(): y_ms = net(x_ms) assert np.allclose(y_np, y_ms.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gelu_4d_fp16(): + x_np = np.random.random((32, 3, 224, 224)).astype(np.float16) + y_np = GeluCompute(x_np) + + x_ms = Tensor(x_np) + net = GeluNet() + y_ms = net(x_ms) + + assert np.allclose(y_np, y_ms.asnumpy(), rtol=1e-3)