From fdd2d8209fc5093739c129c05d512617bdefed7f Mon Sep 17 00:00:00 2001
From: peixu_ren
Date: Wed, 9 Sep 2020 20:51:43 -0400
Subject: [PATCH] Support erf and erfc on GPU backend

---
 .../kernel_compiler/gpu/cuda_impl/erf_impl.cu | 32 +++++++
 .../gpu/cuda_impl/erf_impl.cuh                | 25 +++++
 .../gpu/cuda_impl/erfc_impl.cu                | 32 +++++++
 .../gpu/cuda_impl/erfc_impl.cuh               | 25 +++++
 .../gpu/math/erf_gpu_kernel.cc                | 24 +++++
 .../kernel_compiler/gpu/math/erf_gpu_kernel.h | 92 +++++++++++++++++++
 .../gpu/math/erfc_gpu_kernel.cc               | 24 +++++
 .../gpu/math/erfc_gpu_kernel.h                | 92 +++++++++++++++++++
 tests/st/ops/gpu/test_erf_op.py               | 46 ++++++++++
 tests/st/ops/gpu/test_erfc_op.py              | 46 ++++++++++
 10 files changed, 438 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_erf_op.py
 create mode 100644 tests/st/ops/gpu/test_erfc_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu
new file mode 100644
index 00000000000..257c8503b76
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "erf_impl.cuh"
+template <typename T>
+__global__ void ErfKernel(T *input, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = (T)erf(input[i]);
+  }
+  return;
+}
+
+template <typename T>
+void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  ErfKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+
+template void Erf<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh
new file mode 100644
index 00000000000..f7c476a30eb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
+
+#include <math.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu
new file mode 100644
index 00000000000..6b20cd55376
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "erfc_impl.cuh"
+template <typename T>
+__global__ void ErfcKernel(T *input, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = (T)erfc(input[i]);
+  }
+  return;
+}
+
+template <typename T>
+void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  ErfcKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+
+template void Erfc<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh
new file mode 100644
index 00000000000..8fccc368617
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFCIMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFCIMPL_H_
+
+#include <math.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ERFCIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc
new file mode 100644
index 00000000000..3531e9cccab
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/erf_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ErfGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.h
new file mode 100644
index 00000000000..88cc4eb95cb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
+
+#include <cuda_runtime_api.h>
+#include <vector>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/erf_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ErfGpuKernel : public GpuKernel {
+ public:
+  ErfGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {}
+  ~ErfGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    Erf(input_addr, output_addr, outputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but erf needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but erf needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      output_size_ *= output_shape[i];
+    }
+    if (input_size_ != output_size_) {
+      MS_LOG(ERROR) << "Input size and output size should be equal for Erf.";
+      return false;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERF_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc
new file mode 100644
index 00000000000..cb63ed6f7f9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ErfcGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h
new file mode 100644
index 00000000000..9d46c792eb2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERFC_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERFC_GPU_KERNEL_H_
+
+#include <cuda_runtime_api.h>
+#include <vector>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ErfcGpuKernel : public GpuKernel {
+ public:
+  ErfcGpuKernel() : input_size_(sizeof(T)), output_size_(sizeof(T)) {}
+  ~ErfcGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    Erfc(input_addr, output_addr, outputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but erfc needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but erfc needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      output_size_ *= output_shape[i];
+    }
+    if (input_size_ != output_size_) {
+      MS_LOG(ERROR) << "Input size and output size should be equal for Erfc.";
+      return false;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ERFC_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_erf_op.py b/tests/st/ops/gpu/test_erf_op.py
new file mode 100644
index 00000000000..98c2085137e
--- /dev/null
+++ b/tests/st/ops/gpu/test_erf_op.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+from scipy import special
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+class NetErf(nn.Cell):
+    def __init__(self):
+        super(NetErf, self).__init__()
+        self.erf = P.Erf()
+
+    def construct(self, x):
+        return self.erf(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_erf():
+    erf = NetErf()
+    x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
+    output = erf(Tensor(x, dtype=dtype.float32))
+    expect = special.erf(x)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
diff --git a/tests/st/ops/gpu/test_erfc_op.py b/tests/st/ops/gpu/test_erfc_op.py
new file mode 100644
index 00000000000..38be92bf9ce
--- /dev/null
+++ b/tests/st/ops/gpu/test_erfc_op.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+from scipy import special
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+class NetErfc(nn.Cell):
+    def __init__(self):
+        super(NetErfc, self).__init__()
+        self.erfc = P.Erfc()
+
+    def construct(self, x):
+        return self.erfc(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_erfc():
+    erfc = NetErfc()
+    x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
+    output = erfc(Tensor(x, dtype=dtype.float32))
+    expect = special.erfc(x)
+    tol = 1e-6
+    assert (np.abs(output.asnumpy() - expect) < tol).all()
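
Usage sketch: the patch only registers float32 kernels (MS_REG_GPU_KERNEL_ONE ... float), so on the GPU target P.Erf and P.Erfc expect float32 inputs. The snippet below is illustrative only and assumes a MindSpore build that contains this patch plus a visible CUDA device; the ErfNet class name and input values are made up for the example, mirroring the system tests above.

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor, dtype
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class ErfNet(nn.Cell):
        def __init__(self):
            super(ErfNet, self).__init__()
            self.erf = P.Erf()    # dispatched to ErfGpuKernel (float32 only)
            self.erfc = P.Erfc()  # dispatched to ErfcGpuKernel (float32 only)

        def construct(self, x):
            return self.erf(x), self.erfc(x)

    x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32), dtype=dtype.float32)
    erf_out, erfc_out = ErfNet()(x)  # compare against scipy.special.erf / erfc as in the tests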