From 7257cf129a746ba979f6a33dd7177ad358d64fde Mon Sep 17 00:00:00 2001 From: z00512249 Date: Fri, 21 Jan 2022 14:40:13 +0800 Subject: [PATCH] add complex functions for cpu backend. --- .../cpu/unary_op_cpu_kernel.cc | 96 ++++++++++++++ .../kernel_compiler/cpu/unary_op_cpu_kernel.h | 120 +++++++++++++++++ .../mindspore/ops/operations/math_ops.py | 27 ++-- tests/st/ops/cpu/test_unary_op.py | 123 ++++++++++++++++++ 4 files changed, 353 insertions(+), 13 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_unary_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.cc new file mode 100644 index 00000000000..9067009823e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/unary_op_cpu_kernel.h" +#include +#include +#include +namespace mindspore { +namespace kernel { +template +void Real(const T *input, S *output, size_t start, size_t end) { + if (!std::is_same::value && !std::is_same::value) { + for (size_t i = start; i < end; ++i) { + output[i] = static_cast(std::real(input[i])); + } + } else { + MS_LOG(EXCEPTION) << "For Real, it's output data type only support these types: float or double"; + } +} + +template +void Imag(const T *input, S *output, size_t start, size_t end) { + if constexpr (!std::is_same>::value && !std::is_same>::value) { + for (size_t i = start; i < end; ++i) { + output[i] = static_cast(std::imag(input[i])); + } + } else { + MS_LOG(EXCEPTION) << "For Imag, it's output data type only support these types: float or double"; + } +} + +template +void Conj(const T *input, S *output, size_t start, size_t end) { + if constexpr (std::is_same::value && + (std::is_same::value || std::is_same::value)) { + for (size_t i = start; i < end; ++i) { + output[i] = static_cast(std::conj(input[i])); + } + } else { + MS_LOG(EXCEPTION) << "For Conj, it's output data type only support these types: complex or complex"; + } +} + +template +void UnaryOpCPUKernel::GetUnaryOpFunc() { + if constexpr (std::is_same::value || std::is_same::value) { + static std::map kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real}, + {prim::kPrimImag->name(), &Imag}, + {prim::kPrimConj->name(), &Conj}}; + auto iter = kComplexSupportedTypeMap.find(kernel_name_); + if (iter != kComplexSupportedTypeMap.end()) { + unary_op_func_ = iter->second; + return; + } + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", only support these types: Real, Imag, Conj currently, but got " + << kernel_name_; + } +} + +template +void UnaryOpCPUKernel::InitKernel(const CNodePtr &kernel_node) { + kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); + GetUnaryOpFunc(); +} + +template +bool UnaryOpCPUKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + auto input = inputs.front(); + auto output = outputs.front(); + const auto input_addr = reinterpret_cast(input->addr); + auto output_addr = reinterpret_cast(output->addr); + if (unary_op_func_ != nullptr) { + ParallelLaunchAutoSearch( + std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2), + output->size / sizeof(S), this, ¶llel_search_info_); + } else { + (void)memcpy_s(output_addr, output->size, input_addr, input->size); + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.h new file mode 100644 index 00000000000..c4780fce573 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unary_op_cpu_kernel.h @@ -0,0 +1,120 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +using complex64 = std::complex; +using complex128 = std::complex; +namespace mindspore { +namespace kernel { +template +class UnaryOpCPUKernel : public CPUKernel { + public: + UnaryOpCPUKernel() = default; + ~UnaryOpCPUKernel() override = default; + using UnaryOpFunc = std::function; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void GetUnaryOpFunc(); + UnaryOpFunc unary_op_func_{nullptr}; +}; + +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64), + UnaryOpCPUKernel, complex128, double) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32), + UnaryOpCPUKernel, complex64, float) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel, + char, char) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + UnaryOpCPUKernel, int16_t, int16_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnaryOpCPUKernel, int32_t, int32_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + UnaryOpCPUKernel, int64_t, int64_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + UnaryOpCPUKernel, uint16_t, uint16_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + UnaryOpCPUKernel, uint32_t, uint32_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + UnaryOpCPUKernel, uint64_t, uint64_t) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpCPUKernel, float, float) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpCPUKernel, double, double) +MS_REG_CPU_KERNEL_T_S(Real, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel, + bool, bool) + +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64), + UnaryOpCPUKernel, complex128, double) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32), + UnaryOpCPUKernel, complex64, float) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel, + char, char) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + UnaryOpCPUKernel, int16_t, int16_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnaryOpCPUKernel, int32_t, int32_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + UnaryOpCPUKernel, int64_t, int64_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + UnaryOpCPUKernel, uint16_t, uint16_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + UnaryOpCPUKernel, uint32_t, uint32_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + UnaryOpCPUKernel, uint64_t, uint64_t) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpCPUKernel, float, float) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpCPUKernel, double, double) +MS_REG_CPU_KERNEL_T_S(Imag, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel, + bool, bool) + +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + UnaryOpCPUKernel, complex128, complex128) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), + UnaryOpCPUKernel, complex64, complex64) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), UnaryOpCPUKernel, + char, char) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + UnaryOpCPUKernel, int16_t, int16_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnaryOpCPUKernel, int32_t, int32_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + UnaryOpCPUKernel, int64_t, int64_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + UnaryOpCPUKernel, uint16_t, uint16_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + UnaryOpCPUKernel, uint32_t, uint32_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + UnaryOpCPUKernel, uint64_t, uint64_t) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpCPUKernel, float, float) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpCPUKernel, double, double) +MS_REG_CPU_KERNEL_T_S(Conj, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), UnaryOpCPUKernel, + bool, bool) +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNARY_OP_CPU_KERNEL_H_ diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index f166815e345..b873cf512ed 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -2278,6 +2278,7 @@ class Exp(PrimitiveWithInfer): return Tensor(out) return None + class Einsum(Primitive): """ This operator uses equation to represent a tuple of tensors operations, @@ -2360,6 +2361,7 @@ class Einsum(Primitive): >>> print(output) [[2., 4., 1.], [4., 8., 2.], [6., 12., 3.]] """ + @prim_attr_register def __init__(self, equation): if not isinstance(equation, str): @@ -2370,6 +2372,7 @@ class Einsum(Primitive): self.add_prim_attr('equation', equation) self.init_prim_io_names(inputs=['inputs'], outputs=['output']) + class Expm1(Primitive): r""" Returns exponential then minus 1 of a tensor element-wise. @@ -2406,7 +2409,6 @@ class Expm1(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['y']) - class HistogramFixedWidth(PrimitiveWithInfer): """ Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal @@ -2626,6 +2628,7 @@ class Erfc(Primitive): """Initialize Erfc""" self.init_prim_io_names(inputs=['x'], outputs=['y']) + class Minimum(_MathBinaryOp): r""" Computes the minimum of input tensors element-wise. @@ -2899,7 +2902,6 @@ class DivNoNan(_MathBinaryOp): self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) - class MulNoNan(_MathBinaryOp): r""" Computes `x` * `y` element-wise. If `y` is zero, no matter what `x` is, it will return 0. @@ -3017,6 +3019,7 @@ class FloorDiv(Primitive): """Initialize FloorDiv.""" self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) + class TruncateDiv(Primitive): """ Divides the first input tensor by the second input tensor element-wise for integer types, negative numbers will @@ -3116,6 +3119,7 @@ class TruncateMod(Primitive): """Initialize TruncateMod.""" self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) + class Mod(_MathBinaryOp): r""" Computes the remainder of dividing the first input tensor by the second input tensor element-wise. @@ -3339,6 +3343,7 @@ class Xdivy(Primitive): """Initialize Xdivy.""" self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) + class Xlogy(Primitive): r""" Computes the first input tensor multiplied by the logarithm of second input tensor element-wise. @@ -3431,6 +3436,7 @@ class Acosh(Primitive): """Initialize Acosh""" self.init_prim_io_names(inputs=['x'], outputs=['y']) + class Cosh(Primitive): r""" Computes hyperbolic cosine of input element-wise. @@ -3534,7 +3540,6 @@ class Sinh(Primitive): """Initialize Sinh""" - class _LogicBinaryOp(_BinaryOp): """ Define logic binary operators. @@ -4044,7 +4049,6 @@ class LogicalNot(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) - class LogicalAnd(_LogicBinaryOp): r""" Computes the "logical AND" of two tensors element-wise. @@ -4129,7 +4133,6 @@ class LogicalOr(_LogicBinaryOp): """ - class IsNan(Primitive): r""" Determines which elements are NaN for each position. @@ -4585,7 +4588,6 @@ class Sin(Primitive): """Initialize Sin.""" - class Asin(Primitive): r""" Computes arcsine of input tensors element-wise. @@ -4838,7 +4840,6 @@ class Tan(Primitive): """Initialize Tan""" - class Atan(Primitive): r""" Computes the trigonometric inverse tangent of the input element-wise. @@ -4913,7 +4914,6 @@ class Atanh(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) - class Atan2(_MathBinaryOp): r""" Returns arctangent of x/y element-wise. @@ -5262,7 +5262,6 @@ class BesselI1e(Primitive): self.init_prim_io_names(inputs=['x'], outputs='output') - class Inv(Primitive): r""" Computes Reciprocal of input tensor element-wise. @@ -5297,7 +5296,6 @@ class Inv(Primitive): pass - class Invert(Primitive): r""" Flips all bits of input tensor element-wise. @@ -5680,7 +5678,7 @@ class Conj(PrimitiveWithInfer): TypeError: If the input is not a Tensor. Supported Platforms: - ``GPU`` + ``CPU`` ``GPU`` Examples: >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64) @@ -5711,7 +5709,7 @@ class Real(PrimitiveWithInfer): TypeError: If the input is not a Tensor. Supported Platforms: - ``GPU`` + ``CPU`` ``GPU`` Examples: >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64) @@ -5775,7 +5773,7 @@ class Imag(PrimitiveWithInfer): TypeError: If the input is not a Tensor. Supported Platforms: - ``GPU`` + ``CPU`` ``GPU`` Examples: >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64) @@ -5865,6 +5863,7 @@ class IsClose(Primitive): >>> print(output) [true false false false true] """ + @prim_attr_register def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=True): """Initialize IsClose""" @@ -5924,6 +5923,7 @@ class LuSolve(Primitive): [-1.4000001] [ 0.6 ]] """ + @prim_attr_register def __init__(self): pass @@ -5972,6 +5972,7 @@ class CholeskyInverse(Primitive): [-2.625 1.25 -0.25 ] [ 0.625 -0.25 0.25 ]] """ + @prim_attr_register def __init__(self, upper=False): """Initialize CholeskyInverse""" diff --git a/tests/st/ops/cpu/test_unary_op.py b/tests/st/ops/cpu/test_unary_op.py new file mode 100644 index 00000000000..2935350dea1 --- /dev/null +++ b/tests/st/ops/cpu/test_unary_op.py @@ -0,0 +1,123 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P +from mindspore import dtype as mstype + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_graph_conj(): + """ + Feature: ALL TO ALL + Description: test cases for conj in graph mode cpu backend. + Expectation: the result match numpy conj + """ + context.set_context(mode=context.GRAPH_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Conj()(ms_x) + expect = np.conj(x) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_pynative_conj(): + """ + Feature: ALL TO ALL + Description: test cases for conj in pynative mode cpu backend. + Expectation: the result match numpy conj + """ + context.set_context(mode=context.PYNATIVE_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Conj()(ms_x) + expect = np.conj(x) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_graph_real(): + """ + Feature: ALL TO ALL + Description: test cases for real in graph mode cpu backend. + Expectation: the result match numpy real + """ + context.set_context(mode=context.GRAPH_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Real()(ms_x) + expect = np.real(x) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_pynative_real(): + """ + Feature: ALL TO ALL + Description: test cases for real in pynative mode cpu backend. + Expectation: the result match numpy real + """ + context.set_context(mode=context.GRAPH_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Real()(ms_x) + expect = np.real(x) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_graph_imag(): + """ + Feature: ALL TO ALL + Description: test cases for image in graph mode cpu backend. + Expectation: the result match numpy conj + """ + context.set_context(mode=context.GRAPH_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Imag()(ms_x) + expect = np.imag(x) + assert np.allclose(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_pynative_imag(): + """ + Feature: ALL TO ALL + Description: test cases for image in pynative mode cpu backend. + Expectation: the result match numpy image + """ + context.set_context(mode=context.GRAPH_MODE) + x = np.asarray(np.complex(1.3 + 0.4j), dtype=np.complex64) + ms_x = Tensor(x, mstype.complex64) + output = P.Imag()(ms_x) + expect = np.imag(x) + assert np.allclose(output.asnumpy(), expect)