From d8c02dca6d05cf6ba9d16fc4b0814d3e871bb720 Mon Sep 17 00:00:00 2001 From: zwy9901 Date: Mon, 14 Nov 2022 20:16:16 +0800 Subject: [PATCH] add Angle --- .../kernel/cuda_impl/cuda_ops/angle_impl.cu | 37 ++++++++ .../kernel/cuda_impl/cuda_ops/angle_impl.cuh | 28 +++++++ .../gpu/kernel/math/angle_gpu_kernel.cc | 84 +++++++++++++++++++ .../device/gpu/kernel/math/angle_gpu_kernel.h | 74 ++++++++++++++++ .../mindspore/ops/operations/math_ops.py | 2 +- tests/st/ops/gpu/test_angle_op.py | 77 +++++++++++++++++ 6 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_angle_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cu new file mode 100644 index 00000000000..c174efdf240 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cu @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "angle_impl.cuh" +#include + +template +__global__ void Angle(const size_t size, const Complex *input, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += gridDim.x * blockDim.x) { + output[pos] = atan2(input[pos].imag(), input[pos].real()); + } + return; +} +template +void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id, cudaStream_t cuda_stream) { + Angle<<>>(size, input, output); + return; +} + +template CUDA_LIB_EXPORT void CalAngle, float>(const size_t size, Complex *input, float *output, + const uint32_t device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalAngle, double>(const size_t size, Complex *input, + double *output, const uint32_t device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh new file mode 100644 index 00000000000..38f86535edb --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +template +CUDA_LIB_EXPORT void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_Angle_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.cc new file mode 100644 index 00000000000..4590df41f9d --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/angle_gpu_kernel.h" +#include +#include +#include +#include +#include +#include "abstract/utils.h" + +namespace mindspore { +namespace kernel { +void AngleGpuKernelMod::ResetResource() noexcept { + is_null_input_ = false; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +std::vector AngleGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +bool AngleGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + kernel_name_ = base_operator->name(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "'support complex64 or complex128, but got " << kernel_attr; + return false; + } + input_dtype_ = inputs[0]->GetDtype(); + kernel_func_ = func_list_[index].second; + return true; +} + +int AngleGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); + return ret; +} + +template +bool AngleGpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + T *input_ptr = GetDeviceAddress(inputs, kIndex0); + S *output_ptr = GetDeviceAddress(outputs, kIndex0); + auto cuda_stream = reinterpret_cast(stream_ptr); + output_size = outputs[0]->size / sizeof(S); + CalAngle(output_size, input_ptr, output_ptr, device_id_, reinterpret_cast(cuda_stream)); + return true; +} + +template +using Complex = mindspore::utils::Complex; +std::vector> AngleGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32), + &AngleGpuKernelMod::LaunchKernel, float>}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64), + &AngleGpuKernelMod::LaunchKernel, double>}}; + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Angle, AngleGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.h new file mode 100644 index 00000000000..c4bc6294aa0 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/angle_gpu_kernel.h @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "ops/complex.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh" +#include "plugin/device/gpu/kernel/kernel_constants.h" + +namespace mindspore { +namespace kernel { +constexpr auto kUnknown = "Unknown"; +class AngleGpuKernelMod : public NativeGpuKernelMod { + public: + AngleGpuKernelMod() = default; + ~AngleGpuKernelMod() override = default; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + return kernel_func_(this, inputs, workspace, outputs, stream_ptr); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize( + const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost = std::map()) override; + std::vector GetOpSupport() override; + + private: + void ResetResource() noexcept; + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + + using AngleFunc = std::function &, + const std::vector &, const std::vector &, void *)>; + + private: + bool is_null_input_{false}; + std::string kernel_name_{kUnknown}; + TypeId input_dtype_ = kNumberTypeComplex64; + size_t output_size; + AngleFunc kernel_func_; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_ diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index 437329069ba..83702a8a55e 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -5859,7 +5859,7 @@ class Angle(Primitive): TypeError: If the dtype of input is not one of: complex64, complex128. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input = Tensor([-1.5 + 7.8j, 3 + 5.75j], mindspore.complex64) diff --git a/tests/st/ops/gpu/test_angle_op.py b/tests/st/ops/gpu/test_angle_op.py new file mode 100644 index 00000000000..2bd2b9517fa --- /dev/null +++ b/tests/st/ops/gpu/test_angle_op.py @@ -0,0 +1,77 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context + +from mindspore import Tensor +from mindspore.ops.operations import math_ops as P + + +class NetAngle(nn.Cell): + def __init__(self): + super().__init__() + self.angle = P.Angle() + + def construct(self, a): + return self.angle(a) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_angle_pynative(): + """ + Feature: Angle + Description: The input tensor. types: complex64, complex128 + Expectation: success: return a Tensor, has the float32 or float64 type and the same shape as input. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex64) + net = NetAngle() + output = net(Tensor(x_np)) + expect = np.angle(x_np) + assert np.allclose(output.asnumpy(), expect, 1e-4, 1e-4) + + x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex128) + net = NetAngle() + output = net(Tensor(x_np)) + expect = np.angle(x_np) + assert np.allclose(output.asnumpy(), expect, 1e-5, 1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_angle_graph(): + """ + Feature: Angle + Description: The input tensor. types: complex64, complex128 + Expectation: success: return a Tensor, has the float32 or float64 type and the same shape as input. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex64) + net = NetAngle() + output = net(Tensor(x_np)) + expect = np.angle(x_np) + assert np.allclose(output.asnumpy(), expect, 1e-4, 1e-4) + + x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex128) + net = NetAngle() + output = net(Tensor(x_np)) + expect = np.angle(x_np) + assert np.allclose(output.asnumpy(), expect, 1e-5, 1e-5)