!39855 [assistant][ops] Add Angle

Merge pull request !39855 from QingZhai/angle
This commit is contained in:
i-robot 2022-11-23 06:45:03 +00:00 committed by Gitee
commit ad6ef7adcb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 301 additions and 1 deletions

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "angle_impl.cuh"
#include <math.h>
/**
 * Element-wise phase angle of a complex array.
 *
 * For every pos in [0, size) this writes atan2(imag, real) of input[pos] into output[pos].
 * Each thread starts at its flat 1-D global index and advances by the total number of launched
 * threads (gridDim.x * blockDim.x) until it passes `size`.
 *
 * @param size   Number of elements to process.
 * @param input  Complex input values; indexed up to size - 1.
 * @param output Destination for the computed angles; indexed up to size - 1.
 */
template <typename S>
__global__ void Angle(const size_t size, const Complex<S> *input, S *output) {
  // blockIdx.x, blockDim.x and gridDim.x are 32-bit unsigned values. Widen one operand to size_t
  // before multiplying so the product cannot wrap around for very large launches.
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    output[pos] = atan2(input[pos].imag(), input[pos].real());
  }
}
/**
 * Host-side launcher for the Angle kernel.
 *
 * The launch configuration is taken from CUDA_BLOCKS / CUDA_THREADS for the given device, and the
 * kernel is enqueued on `cuda_stream` with no dynamic shared memory. No synchronization is done here.
 *
 * @param size        Number of elements to process.
 * @param input       Device pointer to the complex input values.
 * @param output      Device pointer receiving the angles.
 * @param device_id   Device identifier forwarded to the launch-configuration macros.
 * @param cuda_stream Stream on which the kernel is enqueued.
 */
template <typename T, typename S>
void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id, cudaStream_t cuda_stream) {
  const auto block_count = CUDA_BLOCKS(device_id, size);
  const auto thread_count = CUDA_THREADS(device_id);
  Angle<<<block_count, thread_count, 0, cuda_stream>>>(size, input, output);
}
// Explicit instantiations of CalAngle exported from this file:
//   Complex<float>  input -> float  output
//   Complex<double> input -> double output
template CUDA_LIB_EXPORT void CalAngle<Complex<float>, float>(const size_t size, Complex<float> *input, float *output,
const uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalAngle<Complex<double>, double>(const size_t size, Complex<double> *input,
double *output, const uint32_t device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,28 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
/**
 * Launches the Angle computation on the GPU.
 *
 * @tparam T Input element type.
 * @tparam S Output element type.
 * @param size        Number of elements to process.
 * @param input       Pointer to the input values.
 * @param output      Pointer receiving the results.
 * @param device_id   Identifier of the target device.
 * @param cuda_stream CUDA stream used for the launch.
 */
template <typename T, typename S>
CUDA_LIB_EXPORT void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_

View File

@ -0,0 +1,84 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/angle_gpu_kernel.h"
#include <utility>
#include <functional>
#include <string>
#include <algorithm>
#include <memory>
#include "abstract/utils.h"
namespace mindspore {
namespace kernel {
// Returns the kernel's bookkeeping to a clean state: empties the input, output and workspace
// size lists and clears the null-input flag.
void AngleGpuKernelMod::ResetResource() noexcept {
  workspace_size_list_.clear();
  output_size_list_.clear();
  input_size_list_.clear();
  is_null_input_ = false;
}
std::vector<KernelAttr> AngleGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, AngleFunc> &pair) { return pair.first; });
return support_list;
}
bool AngleGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "'support complex64 or complex128, but got " << kernel_attr;
return false;
}
input_dtype_ = inputs[0]->GetDtype();
kernel_func_ = func_list_[index].second;
return true;
}
// Shape-change hook: performs no Angle-specific work and simply forwards to KernelMod::Resize,
// returning its status code unchanged.
int AngleGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs,
                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  return KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
}
template <typename T, typename S>
bool AngleGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *input_ptr = GetDeviceAddress<T>(inputs, kIndex0);
S *output_ptr = GetDeviceAddress<S>(outputs, kIndex0);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
output_size = outputs[0]->size / sizeof(S);
CalAngle(output_size, input_ptr, output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
return true;
}
// Short local name for the project's complex number type.
template <typename T>
using Complex = mindspore::utils::Complex<T>;
// Dispatch table pairing each supported (input dtype -> output dtype) attribute with the matching
// LaunchKernel instantiation:
//   complex64  -> float32 : LaunchKernel<Complex<float>, float>
//   complex128 -> float64 : LaunchKernel<Complex<double>, double>
// Init() indexes into this table with the index returned by MatchKernelAttr, and GetOpSupport()
// reports the attributes in this same order, so entries must not be reordered independently.
std::vector<std::pair<KernelAttr, AngleGpuKernelMod::AngleFunc>> AngleGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
&AngleGpuKernelMod::LaunchKernel<Complex<float>, float>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
&AngleGpuKernelMod::LaunchKernel<Complex<double>, double>}};
// Registers AngleGpuKernelMod with the NativeGpuKernelMod factory under the name Angle.
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Angle, AngleGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_
#include <memory>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
#include <string>
#include <map>
#include "ops/complex.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
namespace mindspore {
namespace kernel {
// Placeholder used as the default value of AngleGpuKernelMod::kernel_name_.
constexpr auto kUnknown = "Unknown";
// GPU kernel module for the Angle operator.
//
// Init() selects a typed launch function according to the tensors' dtypes; Launch() then forwards
// every call to that function.
class AngleGpuKernelMod : public NativeGpuKernelMod {
 public:
  AngleGpuKernelMod() = default;
  ~AngleGpuKernelMod() override = default;

  // Runs the launch function chosen by Init() and returns its result.
  // Init() must have succeeded first; otherwise kernel_func_ is still empty.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }

  // Picks the typed implementation matching the given tensors; returns false if none matches.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  // Called when tensor shapes change; returns a status code.
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  // Returns the kernel attributes (dtype combinations) this module supports.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  // Clears the size lists and the null-input flag.
  void ResetResource() noexcept;

  // Typed launch routine; T is the input element type and S the output element type.
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);

  // Signature shared by all LaunchKernel instantiations when stored in the dispatch table.
  using AngleFunc = std::function<bool(AngleGpuKernelMod *, const std::vector<AddressPtr> &,
                                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;

  bool is_null_input_{false};
  std::string kernel_name_{kUnknown};
  TypeId input_dtype_ = kNumberTypeComplex64;
  // Number of output elements; zero-initialized so it never holds an indeterminate value before
  // the first launch writes it. The name has no trailing underscore because the .cc file uses it as is.
  size_t output_size{0};
  // Launch function selected by Init(); empty until then.
  AngleFunc kernel_func_;
  // Supported attribute -> launch function table, defined in the .cc file.
  static std::vector<std::pair<KernelAttr, AngleFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_

View File

@ -5939,7 +5939,7 @@ class Angle(Primitive):
TypeError: If the dtype of input is not one of: complex64, complex128.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input = Tensor([-1.5 + 7.8j, 3 + 5.75j], mindspore.complex64)

View File

@ -0,0 +1,77 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import math_ops as P
class NetAngle(nn.Cell):
    """Minimal cell that applies the Angle primitive to its single input."""

    def __init__(self):
        super(NetAngle, self).__init__()
        self.angle = P.Angle()

    def construct(self, a):
        result = self.angle(a)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_angle_pynative():
    """
    Feature: Angle
    Description: Run Angle on GPU in PyNative mode with complex64 and complex128 inputs.
    Expectation: The output matches numpy.angle within the per-dtype tolerance.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    samples = [-2.25 + 4.75j, 3.25 + 5.75j]
    # Each case is (input dtype, tolerance); the tolerance is used for both rtol and atol.
    for dtype, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
        x_np = np.array(samples).astype(dtype)
        output = NetAngle()(Tensor(x_np))
        expect = np.angle(x_np)
        assert np.allclose(output.asnumpy(), expect, tol, tol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_angle_graph():
    """
    Feature: Angle
    Description: Run Angle on GPU in graph mode with complex64 and complex128 inputs.
    Expectation: The output matches numpy.angle within the per-dtype tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    samples = [-2.25 + 4.75j, 3.25 + 5.75j]
    # Each case is (input dtype, tolerance); the tolerance is used for both rtol and atol.
    for dtype, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
        x_np = np.array(samples).astype(dtype)
        output = NetAngle()(Tensor(x_np))
        expect = np.angle(x_np)
        assert np.allclose(output.asnumpy(), expect, tol, tol)