forked from mindspore-Ecosystem/mindspore
Add complex functions (Real, Imag, Conj) for the CPU backend.
This commit is contained in:
parent
9d20bce3b4
commit
7257cf129a
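
For context, a minimal sketch (not part of this commit) of how the new ops are expected to be used from Python, assuming the standard mindspore.ops API; values and shapes are illustrative:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Run on the newly supported CPU backend (assumption: device_target="CPU" is available).
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

x = Tensor(np.array([1.3 + 0.4j, 2.0 - 1.0j]), mindspore.complex64)
print(P.Real()(x))  # real parts: [1.3, 2.0]
print(P.Imag()(x))  # imaginary parts: [0.4, -1.0]
print(P.Conj()(x))  # conjugates: [1.3-0.4j, 2.0+1.0j]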
@@ -0,0 +1,96 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/unary_op_cpu_kernel.h"
#include <map>
#include <string>
#include <vector>

namespace mindspore {
namespace kernel {
// Writes the real part of each input element in [start, end); valid only for non-complex outputs.
template <typename T, typename S>
void Real(const T *input, S *output, size_t start, size_t end) {
  if constexpr (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
    for (size_t i = start; i < end; ++i) {
      output[i] = static_cast<S>(std::real(input[i]));
    }
  } else {
    MS_LOG(EXCEPTION) << "For Real, the output data type only supports float or double.";
  }
}

// Writes the imaginary part of each input element in [start, end); valid only for non-complex outputs.
template <typename T, typename S>
void Imag(const T *input, S *output, size_t start, size_t end) {
  if constexpr (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
    for (size_t i = start; i < end; ++i) {
      output[i] = static_cast<S>(std::imag(input[i]));
    }
  } else {
    MS_LOG(EXCEPTION) << "For Imag, the output data type only supports float or double.";
  }
}

// Writes the complex conjugate of each input element in [start, end); input and output must be the same complex type.
template <typename T, typename S>
void Conj(const T *input, S *output, size_t start, size_t end) {
  if constexpr (std::is_same<T, S>::value &&
                (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
    for (size_t i = start; i < end; ++i) {
      output[i] = static_cast<S>(std::conj(input[i]));
    }
  } else {
    MS_LOG(EXCEPTION) << "For Conj, the data type only supports complex64 or complex128.";
  }
}

template <typename T, typename S>
void UnaryOpCPUKernel<T, S>::GetUnaryOpFunc() {
  // Only complex inputs get a specialized function; other dtypes keep unary_op_func_ as nullptr
  // and fall through to the identity copy in Launch.
  if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
    static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
                                                                          {prim::kPrimImag->name(), &Imag<T, S>},
                                                                          {prim::kPrimConj->name(), &Conj<T, S>}};
    auto iter = kComplexSupportedTypeMap.find(kernel_name_);
    if (iter != kComplexSupportedTypeMap.end()) {
      unary_op_func_ = iter->second;
      return;
    }
    MS_LOG(EXCEPTION) << "For complex inputs, only Real, Imag and Conj are supported currently, but got '"
                      << kernel_name_ << "'.";
  }
}

template <typename T, typename S>
void UnaryOpCPUKernel<T, S>::InitKernel(const CNodePtr &kernel_node) {
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  GetUnaryOpFunc();
}

template <typename T, typename S>
bool UnaryOpCPUKernel<T, S>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs) {
  auto input = inputs.front();
  auto output = outputs.front();
  const auto input_addr = reinterpret_cast<T *>(input->addr);
  auto output_addr = reinterpret_cast<S *>(output->addr);
  if (unary_op_func_ != nullptr) {
    ParallelLaunchAutoSearch(
      std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
      output->size / sizeof(S), this, &parallel_search_info_);
  } else {
    // No specialized function: the op is an identity for this dtype, so copy input to output.
    (void)memcpy_s(output_addr, output->size, input_addr, input->size);
  }
  return true;
}
} // namespace kernel
} // namespace mindspore
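
The kernels above are plain element-wise loops over a [start, end) slice, which is what lets ParallelLaunchAutoSearch split the output range across workers. A rough Python model of that contract (conj_chunk and parallel_launch are illustrative stand-ins, not MindSpore APIs):

from typing import Callable, List

def conj_chunk(inp: List[complex], out: List[complex], start: int, end: int) -> None:
    # Mirrors Conj<T, S>: each worker touches only its own [start, end) slice.
    for i in range(start, end):
        out[i] = inp[i].conjugate()

def parallel_launch(func: Callable, n: int, num_workers: int, inp, out) -> None:
    # Toy stand-in for ParallelLaunchAutoSearch: split [0, n) into contiguous chunks.
    chunk = (n + num_workers - 1) // num_workers
    for w in range(num_workers):
        start, end = w * chunk, min((w + 1) * chunk, n)
        func(inp, out, start, end)  # the real code dispatches these chunks to a thread pool

data = [1.3 + 0.4j, 2.0 - 1.0j, -0.5 + 0.5j]
result = [0j] * len(data)
parallel_launch(conj_chunk, len(data), 2, data, result)
assert result == [c.conjugate() for c in data]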
@@ -0,0 +1,120 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_

#include <complex>
#include <functional>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class UnaryOpCPUKernel : public CPUKernel {
 public:
  UnaryOpCPUKernel() = default;
  ~UnaryOpCPUKernel() override = default;
  // Element-wise function applied to the slice [start, end) of the input/output buffers.
  using UnaryOpFunc = std::function<void(const T *, S *, size_t, size_t)>;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  void GetUnaryOpFunc();
  UnaryOpFunc unary_op_func_{nullptr};
};

// Complex inputs map to their real/imaginary parts or conjugates; all other dtypes are
// registered with identical input/output types and resolve to the memcpy fallback in Launch.
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpCPUKernel, complex128, double)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpCPUKernel, complex64, float)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
                      char, char)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      UnaryOpCPUKernel, int16_t, int16_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      UnaryOpCPUKernel, int32_t, int32_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      UnaryOpCPUKernel, int64_t, int64_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      UnaryOpCPUKernel, uint16_t, uint16_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      UnaryOpCPUKernel, uint32_t, uint32_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      UnaryOpCPUKernel, uint64_t, uint64_t)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpCPUKernel, float, float)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpCPUKernel, double, double)
MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
                      bool, bool)

MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpCPUKernel, complex128, double)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpCPUKernel, complex64, float)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
                      char, char)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      UnaryOpCPUKernel, int16_t, int16_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      UnaryOpCPUKernel, int32_t, int32_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      UnaryOpCPUKernel, int64_t, int64_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      UnaryOpCPUKernel, uint16_t, uint16_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      UnaryOpCPUKernel, uint32_t, uint32_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      UnaryOpCPUKernel, uint64_t, uint64_t)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpCPUKernel, float, float)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpCPUKernel, double, double)
MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
                      bool, bool)

MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                      UnaryOpCPUKernel, complex128, complex128)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                      UnaryOpCPUKernel, complex64, complex64)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel,
                      char, char)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      UnaryOpCPUKernel, int16_t, int16_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      UnaryOpCPUKernel, int32_t, int32_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      UnaryOpCPUKernel, int64_t, int64_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      UnaryOpCPUKernel, uint16_t, uint16_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      UnaryOpCPUKernel, uint32_t, uint32_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      UnaryOpCPUKernel, uint64_t, uint64_t)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpCPUKernel, float, float)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpCPUKernel, double, double)
MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel,
                      bool, bool)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_
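
One consequence of the registration table above is worth noting: Real, Imag and Conj are also registered for every non-complex dtype with identical input and output types. For those instantiations GetUnaryOpFunc never sets unary_op_func_, so Launch falls back to memcpy_s and the op is an identity copy (mathematically correct for Real and Conj on real inputs; Imag takes the same copy path as written). A hedged Python sketch of the observable behaviour, with illustrative values:

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# Real-valued dtypes hit the memcpy_s fallback in Launch, so the output
# buffer is a byte-for-byte copy of the input buffer.
x = Tensor(np.array([1.0, -2.5], dtype=np.float32))
print(P.Real()(x))  # expected: [ 1.  -2.5] (identity)
print(P.Conj()(x))  # expected: [ 1.  -2.5] (conjugate of a real number is itself)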
@@ -2278,6 +2278,7 @@ class Exp(PrimitiveWithInfer):
            return Tensor(out)
        return None

+
class Einsum(Primitive):
    """
    This operator uses equation to represent a tuple of tensors operations,
@@ -2360,6 +2361,7 @@ class Einsum(Primitive):
        >>> print(output)
        [[2., 4., 1.], [4., 8., 2.], [6., 12., 3.]]
    """
+
    @prim_attr_register
    def __init__(self, equation):
        if not isinstance(equation, str):
@@ -2370,6 +2372,7 @@ class Einsum(Primitive):
        self.add_prim_attr('equation', equation)
        self.init_prim_io_names(inputs=['inputs'], outputs=['output'])

+
class Expm1(Primitive):
    r"""
    Returns exponential then minus 1 of a tensor element-wise.
@@ -2406,7 +2409,6 @@ class Expm1(Primitive):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])


-
class HistogramFixedWidth(PrimitiveWithInfer):
    """
    Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal
@@ -2626,6 +2628,7 @@ class Erfc(Primitive):
        """Initialize Erfc"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

+
class Minimum(_MathBinaryOp):
    r"""
    Computes the minimum of input tensors element-wise.
@@ -2899,7 +2902,6 @@ class DivNoNan(_MathBinaryOp):
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])


-
class MulNoNan(_MathBinaryOp):
    r"""
    Computes `x` * `y` element-wise. If `y` is zero, no matter what `x` is, it will return 0.
@@ -3017,6 +3019,7 @@ class FloorDiv(Primitive):
        """Initialize FloorDiv."""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

+
class TruncateDiv(Primitive):
    """
    Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will
@@ -3116,6 +3119,7 @@ class TruncateMod(Primitive):
        """Initialize TruncateMod."""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

+
class Mod(_MathBinaryOp):
    r"""
    Computes the remainder of dividing the first input tensor by the second input tensor element-wise.
@@ -3339,6 +3343,7 @@ class Xdivy(Primitive):
        """Initialize Xdivy."""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

+
class Xlogy(Primitive):
    r"""
    Computes the first input tensor multiplied by the logarithm of second input tensor element-wise.
@@ -3431,6 +3436,7 @@ class Acosh(Primitive):
        """Initialize Acosh"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

+
class Cosh(Primitive):
    r"""
    Computes hyperbolic cosine of input element-wise.
@@ -3534,7 +3540,6 @@ class Sinh(Primitive):
        """Initialize Sinh"""


-
class _LogicBinaryOp(_BinaryOp):
    """
    Define logic binary operators.
@@ -4044,7 +4049,6 @@ class LogicalNot(Primitive):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


-
class LogicalAnd(_LogicBinaryOp):
    r"""
    Computes the "logical AND" of two tensors element-wise.
@@ -4129,7 +4133,6 @@ class LogicalOr(_LogicBinaryOp):
    """


-
class IsNan(Primitive):
    r"""
    Determines which elements are NaN for each position.
@@ -4585,7 +4588,6 @@ class Sin(Primitive):
        """Initialize Sin."""


-
class Asin(Primitive):
    r"""
    Computes arcsine of input tensors element-wise.
@@ -4838,7 +4840,6 @@ class Tan(Primitive):
        """Initialize Tan"""


-
class Atan(Primitive):
    r"""
    Computes the trigonometric inverse tangent of the input element-wise.
@@ -4913,7 +4914,6 @@ class Atanh(Primitive):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


-
class Atan2(_MathBinaryOp):
    r"""
    Returns arctangent of x/y element-wise.
@@ -5262,7 +5262,6 @@ class BesselI1e(Primitive):
        self.init_prim_io_names(inputs=['x'], outputs='output')


-
class Inv(Primitive):
    r"""
    Computes Reciprocal of input tensor element-wise.
@@ -5297,7 +5296,6 @@ class Inv(Primitive):
        pass


-
class Invert(Primitive):
    r"""
    Flips all bits of input tensor element-wise.
@@ -5680,7 +5678,7 @@ class Conj(PrimitiveWithInfer):
        TypeError: If the input is not a Tensor.

    Supported Platforms:
-        ``GPU``
+        ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
@@ -5711,7 +5709,7 @@ class Real(PrimitiveWithInfer):
        TypeError: If the input is not a Tensor.

    Supported Platforms:
-        ``GPU``
+        ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
@@ -5775,7 +5773,7 @@ class Imag(PrimitiveWithInfer):
        TypeError: If the input is not a Tensor.

    Supported Platforms:
-        ``GPU``
+        ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
@@ -5865,6 +5863,7 @@ class IsClose(Primitive):
        >>> print(output)
        [true false false false true]
    """
+
    @prim_attr_register
    def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=True):
        """Initialize IsClose"""
@@ -5924,6 +5923,7 @@ class LuSolve(Primitive):
        [-1.4000001]
        [ 0.6 ]]
    """
+
    @prim_attr_register
    def __init__(self):
        pass
@@ -5972,6 +5972,7 @@ class CholeskyInverse(Primitive):
        [-2.625 1.25 -0.25 ]
        [ 0.625 -0.25 0.25 ]]
    """
+
    @prim_attr_register
    def __init__(self, upper=False):
        """Initialize CholeskyInverse"""
@@ -0,0 +1,123 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore import dtype as mstype


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_graph_conj():
    """
    Feature: ALL TO ALL
    Description: test cases for Conj in graph mode on the CPU backend.
    Expectation: the result matches numpy conj.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Conj()(ms_x)
    expect = np.conj(x)
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_conj():
    """
    Feature: ALL TO ALL
    Description: test cases for Conj in pynative mode on the CPU backend.
    Expectation: the result matches numpy conj.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Conj()(ms_x)
    expect = np.conj(x)
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_graph_real():
    """
    Feature: ALL TO ALL
    Description: test cases for Real in graph mode on the CPU backend.
    Expectation: the result matches numpy real.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Real()(ms_x)
    expect = np.real(x)
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_real():
    """
    Feature: ALL TO ALL
    Description: test cases for Real in pynative mode on the CPU backend.
    Expectation: the result matches numpy real.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Real()(ms_x)
    expect = np.real(x)
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_graph_imag():
    """
    Feature: ALL TO ALL
    Description: test cases for Imag in graph mode on the CPU backend.
    Expectation: the result matches numpy imag.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Imag()(ms_x)
    expect = np.imag(x)
    assert np.allclose(output.asnumpy(), expect)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_imag():
    """
    Feature: ALL TO ALL
    Description: test cases for Imag in pynative mode on the CPU backend.
    Expectation: the result matches numpy imag.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    ms_x = Tensor(x, mstype.complex64)
    output = P.Imag()(ms_x)
    expect = np.imag(x)
    assert np.allclose(output.asnumpy(), expect)
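
The six tests above differ only in the op under test and the execution mode, so they could be collapsed with pytest.mark.parametrize. A sketch under that assumption (not part of the commit; CI markers omitted):

import numpy as np
import pytest
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore import dtype as mstype


@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize("ms_op, np_op", [(P.Conj, np.conj), (P.Real, np.real), (P.Imag, np.imag)])
def test_complex_unary(mode, ms_op, np_op):
    """Each MindSpore op should match its NumPy counterpart in both execution modes."""
    context.set_context(mode=mode)
    x = np.asarray(complex(1.3 + 0.4j), dtype=np.complex64)
    output = ms_op()(Tensor(x, mstype.complex64))
    assert np.allclose(output.asnumpy(), np_op(x))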