forked from mindspore-Ecosystem/mindspore
add complex functions for cpu backend.
This commit is contained in:
parent
9d20bce3b4
commit
7257cf129a
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/unary_op_cpu_kernel.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
void Real(const T *input, S *output, size_t start, size_t end) {
|
||||
if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output[i] = static_cast<S>(std::real(input[i]));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For Real, it's output data type only support these types: float or double";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void Imag(const T *input, S *output, size_t start, size_t end) {
|
||||
if constexpr (!std::is_same<S, std::complex<float>>::value && !std::is_same<S, std::complex<double>>::value) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output[i] = static_cast<S>(std::imag(input[i]));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For Imag, it's output data type only support these types: float or double";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void Conj(const T *input, S *output, size_t start, size_t end) {
|
||||
if constexpr (std::is_same<T, S>::value &&
|
||||
(std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output[i] = static_cast<S>(std::conj(input[i]));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For Conj, it's output data type only support these types: complex<float> or complex<double>";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void UnaryOpCPUKernel<T, S>::GetUnaryOpFunc() {
|
||||
if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
|
||||
static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
|
||||
{prim::kPrimImag->name(), &Imag<T, S>},
|
||||
{prim::kPrimConj->name(), &Conj<T, S>}};
|
||||
auto iter = kComplexSupportedTypeMap.find(kernel_name_);
|
||||
if (iter != kComplexSupportedTypeMap.end()) {
|
||||
unary_op_func_ = iter->second;
|
||||
return;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", only support these types: Real, Imag, Conj currently, but got "
|
||||
<< kernel_name_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void UnaryOpCPUKernel<T, S>::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
GetUnaryOpFunc();
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
bool UnaryOpCPUKernel<T, S>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input = inputs.front();
|
||||
auto output = outputs.front();
|
||||
const auto input_addr = reinterpret_cast<T *>(input->addr);
|
||||
auto output_addr = reinterpret_cast<S *>(output->addr);
|
||||
if (unary_op_func_ != nullptr) {
|
||||
ParallelLaunchAutoSearch(
|
||||
std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
|
||||
output->size / sizeof(S), this, ¶llel_search_info_);
|
||||
} else {
|
||||
(void)memcpy_s(output_addr, output->size, input_addr, input->size);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,120 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_
|
||||
|
||||
#include <complex>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class UnaryOpCPUKernel : public CPUKernel {
|
||||
public:
|
||||
UnaryOpCPUKernel() = default;
|
||||
~UnaryOpCPUKernel() override = default;
|
||||
using UnaryOpFunc = std::function<void(const T *, S *, size_t, size_t)>;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void GetUnaryOpFunc();
|
||||
UnaryOpFunc unary_op_func_{nullptr};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpCPUKernel, complex128, double)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpCPUKernel, complex64, float)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
|
||||
char, char)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
UnaryOpCPUKernel, int16_t, int16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnaryOpCPUKernel, int32_t, int32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
UnaryOpCPUKernel, int64_t, int64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
UnaryOpCPUKernel, uint16_t, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
UnaryOpCPUKernel, uint32_t, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
UnaryOpCPUKernel, uint64_t, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpCPUKernel, float, float)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpCPUKernel, double, double)
|
||||
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
|
||||
bool, bool)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpCPUKernel, complex128, double)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpCPUKernel, complex64, float)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
|
||||
char, char)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
UnaryOpCPUKernel, int16_t, int16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnaryOpCPUKernel, int32_t, int32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
UnaryOpCPUKernel, int64_t, int64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
UnaryOpCPUKernel, uint16_t, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
UnaryOpCPUKernel, uint32_t, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
UnaryOpCPUKernel, uint64_t, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpCPUKernel, float, float)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpCPUKernel, double, double)
|
||||
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
|
||||
bool, bool)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
UnaryOpCPUKernel, complex128, complex128)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
UnaryOpCPUKernel, complex64, complex64)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
|
||||
char, char)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
UnaryOpCPUKernel, int16_t, int16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnaryOpCPUKernel, int32_t, int32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
UnaryOpCPUKernel, int64_t, int64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
UnaryOpCPUKernel, uint16_t, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
UnaryOpCPUKernel, uint32_t, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
UnaryOpCPUKernel, uint64_t, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpCPUKernel, float, float)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpCPUKernel, double, double)
|
||||
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
|
||||
bool, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_
|
|
@ -2278,6 +2278,7 @@ class Exp(PrimitiveWithInfer):
|
|||
return Tensor(out)
|
||||
return None
|
||||
|
||||
|
||||
class Einsum(Primitive):
|
||||
"""
|
||||
This operator uses equation to represent a tuple of tensors operations,
|
||||
|
@ -2360,6 +2361,7 @@ class Einsum(Primitive):
|
|||
>>> print(output)
|
||||
[[2., 4., 1.], [4., 8., 2.], [6., 12., 3.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, equation):
|
||||
if not isinstance(equation, str):
|
||||
|
@ -2370,6 +2372,7 @@ class Einsum(Primitive):
|
|||
self.add_prim_attr('equation', equation)
|
||||
self.init_prim_io_names(inputs=['inputs'], outputs=['output'])
|
||||
|
||||
|
||||
class Expm1(Primitive):
|
||||
r"""
|
||||
Returns exponential then minus 1 of a tensor element-wise.
|
||||
|
@ -2406,7 +2409,6 @@ class Expm1(Primitive):
|
|||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
|
||||
class HistogramFixedWidth(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal
|
||||
|
@ -2626,6 +2628,7 @@ class Erfc(Primitive):
|
|||
"""Initialize Erfc"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
class Minimum(_MathBinaryOp):
|
||||
r"""
|
||||
Computes the minimum of input tensors element-wise.
|
||||
|
@ -2899,7 +2902,6 @@ class DivNoNan(_MathBinaryOp):
|
|||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
|
||||
class MulNoNan(_MathBinaryOp):
|
||||
r"""
|
||||
Computes `x` * `y` element-wise. If `y` is zero, no matter what `x` is, it will return 0.
|
||||
|
@ -3017,6 +3019,7 @@ class FloorDiv(Primitive):
|
|||
"""Initialize FloorDiv."""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
class TruncateDiv(Primitive):
|
||||
"""
|
||||
Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will
|
||||
|
@ -3116,6 +3119,7 @@ class TruncateMod(Primitive):
|
|||
"""Initialize TruncateMod."""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
class Mod(_MathBinaryOp):
|
||||
r"""
|
||||
Computes the remainder of dividing the first input tensor by the second input tensor element-wise.
|
||||
|
@ -3339,6 +3343,7 @@ class Xdivy(Primitive):
|
|||
"""Initialize Xdivy."""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
|
||||
class Xlogy(Primitive):
|
||||
r"""
|
||||
Computes the first input tensor multiplied by the logarithm of second input tensor element-wise.
|
||||
|
@ -3431,6 +3436,7 @@ class Acosh(Primitive):
|
|||
"""Initialize Acosh"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
class Cosh(Primitive):
|
||||
r"""
|
||||
Computes hyperbolic cosine of input element-wise.
|
||||
|
@ -3534,7 +3540,6 @@ class Sinh(Primitive):
|
|||
"""Initialize Sinh"""
|
||||
|
||||
|
||||
|
||||
class _LogicBinaryOp(_BinaryOp):
|
||||
"""
|
||||
Define logic binary operators.
|
||||
|
@ -4044,7 +4049,6 @@ class LogicalNot(Primitive):
|
|||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
|
||||
|
||||
class LogicalAnd(_LogicBinaryOp):
|
||||
r"""
|
||||
Computes the "logical AND" of two tensors element-wise.
|
||||
|
@ -4129,7 +4133,6 @@ class LogicalOr(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
|
||||
|
||||
class IsNan(Primitive):
|
||||
r"""
|
||||
Determines which elements are NaN for each position.
|
||||
|
@ -4585,7 +4588,6 @@ class Sin(Primitive):
|
|||
"""Initialize Sin."""
|
||||
|
||||
|
||||
|
||||
class Asin(Primitive):
|
||||
r"""
|
||||
Computes arcsine of input tensors element-wise.
|
||||
|
@ -4838,7 +4840,6 @@ class Tan(Primitive):
|
|||
"""Initialize Tan"""
|
||||
|
||||
|
||||
|
||||
class Atan(Primitive):
|
||||
r"""
|
||||
Computes the trigonometric inverse tangent of the input element-wise.
|
||||
|
@ -4913,7 +4914,6 @@ class Atanh(Primitive):
|
|||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
|
||||
|
||||
class Atan2(_MathBinaryOp):
|
||||
r"""
|
||||
Returns arctangent of x/y element-wise.
|
||||
|
@ -5262,7 +5262,6 @@ class BesselI1e(Primitive):
|
|||
self.init_prim_io_names(inputs=['x'], outputs='output')
|
||||
|
||||
|
||||
|
||||
class Inv(Primitive):
|
||||
r"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
|
@ -5297,7 +5296,6 @@ class Inv(Primitive):
|
|||
pass
|
||||
|
||||
|
||||
|
||||
class Invert(Primitive):
|
||||
r"""
|
||||
Flips all bits of input tensor element-wise.
|
||||
|
@ -5680,7 +5678,7 @@ class Conj(PrimitiveWithInfer):
|
|||
TypeError: If the input is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
|
||||
|
@ -5711,7 +5709,7 @@ class Real(PrimitiveWithInfer):
|
|||
TypeError: If the input is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
|
||||
|
@ -5775,7 +5773,7 @@ class Imag(PrimitiveWithInfer):
|
|||
TypeError: If the input is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
|
||||
|
@ -5865,6 +5863,7 @@ class IsClose(Primitive):
|
|||
>>> print(output)
|
||||
[true false false false true]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=True):
|
||||
"""Initialize IsClose"""
|
||||
|
@ -5924,6 +5923,7 @@ class LuSolve(Primitive):
|
|||
[-1.4000001]
|
||||
[ 0.6 ]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -5972,6 +5972,7 @@ class CholeskyInverse(Primitive):
|
|||
[-2.625 1.25 -0.25 ]
|
||||
[ 0.625 -0.25 0.25 ]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, upper=False):
|
||||
"""Initialize CholeskyInverse"""
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_graph_conj():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for conj in graph mode cpu backend.
|
||||
Expectation: the result match numpy conj
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Conj()(ms_x)
|
||||
expect = np.conj(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_conj():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for conj in pynative mode cpu backend.
|
||||
Expectation: the result match numpy conj
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Conj()(ms_x)
|
||||
expect = np.conj(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_graph_real():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for real in graph mode cpu backend.
|
||||
Expectation: the result match numpy real
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Real()(ms_x)
|
||||
expect = np.real(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_real():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for real in pynative mode cpu backend.
|
||||
Expectation: the result match numpy real
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Real()(ms_x)
|
||||
expect = np.real(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_graph_imag():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for image in graph mode cpu backend.
|
||||
Expectation: the result match numpy conj
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Imag()(ms_x)
|
||||
expect = np.imag(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_imag():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for image in pynative mode cpu backend.
|
||||
Expectation: the result match numpy image
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64)
|
||||
ms_x = Tensor(x, mstype.complex64)
|
||||
output = P.Imag()(ms_x)
|
||||
expect = np.imag(x)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
Loading…
Reference in New Issue