!39855 [assistant][ops] Add Angle
Merge pull request !39855 from QingZhai/angle
This commit is contained in:
commit
ad6ef7adcb
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "angle_impl.cuh"
|
||||
#include <math.h>
|
||||
|
||||
template <typename S>
|
||||
__global__ void Angle(const size_t size, const Complex<S> *input, S *output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += gridDim.x * blockDim.x) {
|
||||
output[pos] = atan2(input[pos].imag(), input[pos].real());
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S>
|
||||
void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id, cudaStream_t cuda_stream) {
|
||||
Angle<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalAngle<Complex<float>, float>(const size_t size, Complex<float> *input, float *output,
|
||||
const uint32_t device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalAngle<Complex<double>, double>(const size_t size, Complex<double> *input,
|
||||
double *output, const uint32_t device_id,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ANGLE_IMPL_CUH_
|
||||
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
CUDA_LIB_EXPORT void CalAngle(const size_t size, T *input, S *output, const uint32_t device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_Angle_IMPL_CUH_
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/angle_gpu_kernel.h"
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "abstract/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void AngleGpuKernelMod::ResetResource() noexcept {
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> AngleGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, AngleFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
bool AngleGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "'support complex64 or complex128, but got " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
input_dtype_ = inputs[0]->GetDtype();
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int AngleGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
bool AngleGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
T *input_ptr = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
S *output_ptr = GetDeviceAddress<S>(outputs, kIndex0);
|
||||
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
output_size = outputs[0]->size / sizeof(S);
|
||||
CalAngle(output_size, input_ptr, output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
std::vector<std::pair<KernelAttr, AngleGpuKernelMod::AngleFunc>> AngleGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
|
||||
&AngleGpuKernelMod::LaunchKernel<Complex<float>, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
|
||||
&AngleGpuKernelMod::LaunchKernel<Complex<double>, double>}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Angle, AngleGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "ops/complex.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/angle_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
class AngleGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
AngleGpuKernelMod() = default;
|
||||
~AngleGpuKernelMod() override = default;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void ResetResource() noexcept;
|
||||
|
||||
template <typename T, typename S>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
|
||||
using AngleFunc = std::function<bool(AngleGpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
|
||||
|
||||
private:
|
||||
bool is_null_input_{false};
|
||||
std::string kernel_name_{kUnknown};
|
||||
TypeId input_dtype_ = kNumberTypeComplex64;
|
||||
size_t output_size;
|
||||
AngleFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, AngleFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_ANGLE_GPU_KERNEL_H_
|
|
@ -5939,7 +5939,7 @@ class Angle(Primitive):
|
|||
TypeError: If the dtype of input is not one of: complex64, complex128.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input = Tensor([-1.5 + 7.8j, 3 + 5.75j], mindspore.complex64)
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import math_ops as P
|
||||
|
||||
|
||||
class NetAngle(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.angle = P.Angle()
|
||||
|
||||
def construct(self, a):
|
||||
return self.angle(a)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_angle_pynative():
|
||||
"""
|
||||
Feature: Angle
|
||||
Description: The input tensor. types: complex64, complex128
|
||||
Expectation: success: return a Tensor, has the float32 or float64 type and the same shape as input.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex64)
|
||||
net = NetAngle()
|
||||
output = net(Tensor(x_np))
|
||||
expect = np.angle(x_np)
|
||||
assert np.allclose(output.asnumpy(), expect, 1e-4, 1e-4)
|
||||
|
||||
x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex128)
|
||||
net = NetAngle()
|
||||
output = net(Tensor(x_np))
|
||||
expect = np.angle(x_np)
|
||||
assert np.allclose(output.asnumpy(), expect, 1e-5, 1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_angle_graph():
|
||||
"""
|
||||
Feature: Angle
|
||||
Description: The input tensor. types: complex64, complex128
|
||||
Expectation: success: return a Tensor, has the float32 or float64 type and the same shape as input.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex64)
|
||||
net = NetAngle()
|
||||
output = net(Tensor(x_np))
|
||||
expect = np.angle(x_np)
|
||||
assert np.allclose(output.asnumpy(), expect, 1e-4, 1e-4)
|
||||
|
||||
x_np = np.array([-2.25 + 4.75j, 3.25 + 5.75j]).astype(np.complex128)
|
||||
net = NetAngle()
|
||||
output = net(Tensor(x_np))
|
||||
expect = np.angle(x_np)
|
||||
assert np.allclose(output.asnumpy(), expect, 1e-5, 1e-5)
|
Loading…
Reference in New Issue