diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc index cc5a11f8bb2..4f52e119f1f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc @@ -37,6 +37,8 @@ static const std::map kMsProtoDataTypeMap = { {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, + {mindspore::TypeId::kNumberTypeComplex64, mindspore::DataType::MS_COMPLEX64}, + {mindspore::TypeId::kNumberTypeComplex128, mindspore::DataType::MS_COMPLEX128}, }; static const std::map kProtoDataTypeToMsDataTypeMap = { @@ -53,6 +55,8 @@ static const std::map kProtoDataTypeToMsDataTypeMap = { {mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16}, {mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32}, {mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64}, + {mindspore::DataType::MS_COMPLEX64, mindspore::TypeId::kNumberTypeComplex64}, + {mindspore::DataType::MS_COMPLEX128, mindspore::TypeId::kNumberTypeComplex128}, }; int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 694df059fdd..a809e284d90 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -456,6 +456,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { return TensorPy::MakeTensor(py::array(input), type_ptr); }), py::arg("input"), py::arg("dtype") = nullptr) + // We only suppot array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input, + // and array/bool_/int_/float_/list/tuple init will be matched above, other pybind objects + // input will raise error except complex data type. + .def(py::init([](const py::object &input, const TypePtr &type_ptr) { + if (!PyComplex_CheckExact(input.ptr())) { + MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type(); + } + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter( Get the tensor's data type. diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index fdca6e5cdf5..69d5caed8bb 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -89,8 +89,8 @@ class Tensor(Tensor_): # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. if init is None: - validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), - 'Tensor') + validator.check_value_type('input_data', input_data, + (Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor') valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128) if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \ diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 6f5ef0233ac..9acb4254996 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -94,6 +94,14 @@ def test_tensor_type_complex64_user_define(): def test_tensor_type_complex128(): + #complex python object + py_input = 1 + 2.22222222j + t_complex128 = ms.Tensor(py_input) + assert t_complex128.shape == () + assert t_complex128.dtype == ms.complex128 + assert np.all(t_complex128.asnumpy() == py_input) + + #complex in numpy array np_input = np.array( [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128) t_complex128 = ms.Tensor(np_input) @@ -101,10 +109,19 @@ def test_tensor_type_complex128(): assert t_complex128.shape == (2, 3) assert t_complex128.dtype == ms.complex128 assert np.all(t_complex128.asnumpy() == np_input) - np_input = (1, 2.22222222j, 3) - t_complex128 = ms.Tensor(np_input) - assert np.all(t_complex128.asnumpy() == np_input) + #complex in tuple + py_input = (1, 2.22222222j, 3) + t_complex128 = ms.Tensor(py_input) + assert np.all(t_complex128.asnumpy() == py_input) + + #complex in list + py_input = [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]] + t_complex128 = ms.Tensor(py_input) + assert isinstance(t_complex128, ms.Tensor) + assert t_complex128.shape == (2, 3) + assert t_complex128.dtype == ms.complex128 + assert np.all(t_complex128.asnumpy() == py_input) def test_tensor_type_complex128_user_define(): np_input = np.zeros([1, 2, 3])