From f9a865fd42a26c92e968e6e6dd3eade173bddd9f Mon Sep 17 00:00:00 2001 From: wukesong Date: Fri, 14 Aug 2020 16:31:45 +0800 Subject: [PATCH] add GPU operator --- .../gpu/cuda_impl/unary_op_impl.cu | 42 +++++++++++++++++++ .../gpu/cuda_impl/unary_op_impl.cuh | 4 ++ .../gpu/math/unary_op_gpu_kernel.cc | 8 ++++ .../gpu/math/unary_op_gpu_kernel.h | 12 ++++++ tests/st/ops/gpu/test_cos_op.py | 33 +++++++++++++++ tests/st/ops/gpu/test_sin_op.py | 33 +++++++++++++++ 6 files changed, 132 insertions(+) create mode 100644 tests/st/ops/gpu/test_cos_op.py create mode 100644 tests/st/ops/gpu/test_sin_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index 629c4c29dc..6b8b27a6f3 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -95,6 +95,34 @@ __global__ void RsqrtKernel(half *input, half *output, size_t count) { return; } template +__global__ void SinKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = sin(input[i]); + } + return; +} +template <> +__global__ void SinKernel(half *input, half *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = hsin(input[i]); + } + return; +} +template +__global__ void CosKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = cos(input[i]); + } + return; +} +template <> +__global__ void CosKernel(half *input, half *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = hcos(input[i]); + } + return; +} +template __global__ void ZeroslikeKernel(T *output, size_t count) { T zero = 0.0; for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { @@ -167,6 +195,16 @@ void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { return; } template +void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + SinKernel<<>>(input, output, count); + return; +} +template +void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + CosKernel<<>>(input, output, count); + return; +} +template void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) { RsqrtKernel<<>>(input, output, count); return; @@ -193,6 +231,8 @@ template void Negative(float *input, float *output, size_t count, cudaStr template void Reciprocal(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Square(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Sqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Sin(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Cos(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); template void Abs(float *input, float *output, size_t count, cudaStream_t cuda_stream); @@ -203,6 +243,8 @@ template void Negative(half *input, half *output, size_t count, cudaStream template void Reciprocal(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Square(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Sqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Sin(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Cos(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); template void Abs(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh index 4020f93df2..538db043a7 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -33,6 +33,10 @@ void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); template +void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); template void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index d646ef417c..6a3206bb18 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -46,6 +46,14 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h index 7e3f2c862e..49e4b98109 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -36,6 +36,8 @@ enum UnaryOptype { UNARY_OP_SQUARE, UNARY_OP_SQRT, UNARY_OP_RSQRT, + UNARY_OP_SIN, + UNARY_OP_COS, UNARY_OP_ABS, UNARY_OP_FLOOR, UNARY_OP_INVALID_TYPE = 255 @@ -48,6 +50,8 @@ static const std::map kUnaryOpTypeMap = {{"Exp", UNARY {"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT}, {"Rsqrt", UNARY_OP_RSQRT}, + {"Sin", UNARY_OP_SIN}, + {"Cos", UNARY_OP_COS}, {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}}; template @@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel { Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; } + case UNARY_OP_SIN: { + Sin(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_COS: { + Cos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } case UNARY_OP_ZEROSLIKE: { Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; diff --git a/tests/st/ops/gpu/test_cos_op.py b/tests/st/ops/gpu/test_cos_op.py new file mode 100644 index 0000000000..4feb1aec09 --- /dev/null +++ b/tests/st/ops/gpu/test_cos_op.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cos(): + x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + output_ms = P.Cos()(Tensor(x_np)) + output_np = np.cos(x_np) + assert np.allclose(output_ms.asnumpy(), output_np) diff --git a/tests/st/ops/gpu/test_sin_op.py b/tests/st/ops/gpu/test_sin_op.py new file mode 100644 index 0000000000..117a7a8811 --- /dev/null +++ b/tests/st/ops/gpu/test_sin_op.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sin(): + x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + output_ms = P.Sin()(Tensor(x_np)) + output_np = np.sin(x_np) + assert np.allclose(output_ms.asnumpy(), output_np)