Tensor support complex input

This commit is contained in:
zhouyaqiang 2021-09-06 09:45:27 +08:00
parent 44775dca4b
commit 4e07873378
4 changed files with 36 additions and 5 deletions

View File

@ -37,6 +37,8 @@ static const std::map<int32_t, int32_t> kMsProtoDataTypeMap = {
{mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32},
{mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32},
{mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
{mindspore::TypeId::kNumberTypeComplex64, mindspore::DataType::MS_COMPLEX64},
{mindspore::TypeId::kNumberTypeComplex128, mindspore::DataType::MS_COMPLEX128},
};
static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
@ -53,6 +55,8 @@ static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
{mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
{mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
{mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
{mindspore::DataType::MS_COMPLEX64, mindspore::TypeId::kNumberTypeComplex64},
{mindspore::DataType::MS_COMPLEX128, mindspore::TypeId::kNumberTypeComplex128},
};
int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {

View File

@ -456,6 +456,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
// We only suppot array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input,
// and array/bool_/int_/float_/list/tuple init will be matched above, other pybind objects
// input will raise error except complex data type.
.def(py::init([](const py::object &input, const TypePtr &type_ptr) {
if (!PyComplex_CheckExact(input.ptr())) {
MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type();
}
return TensorPy::MakeTensor(py::array(input), type_ptr);
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.

View File

@ -89,8 +89,8 @@ class Tensor(Tensor_):
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None:
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
'Tensor')
validator.check_value_type('input_data', input_data,
(Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \

View File

@ -94,6 +94,14 @@ def test_tensor_type_complex64_user_define():
def test_tensor_type_complex128():
#complex python object
py_input = 1 + 2.22222222j
t_complex128 = ms.Tensor(py_input)
assert t_complex128.shape == ()
assert t_complex128.dtype == ms.complex128
assert np.all(t_complex128.asnumpy() == py_input)
#complex in numpy array
np_input = np.array(
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128)
t_complex128 = ms.Tensor(np_input)
@ -101,10 +109,19 @@ def test_tensor_type_complex128():
assert t_complex128.shape == (2, 3)
assert t_complex128.dtype == ms.complex128
assert np.all(t_complex128.asnumpy() == np_input)
np_input = (1, 2.22222222j, 3)
t_complex128 = ms.Tensor(np_input)
assert np.all(t_complex128.asnumpy() == np_input)
#complex in tuple
py_input = (1, 2.22222222j, 3)
t_complex128 = ms.Tensor(py_input)
assert np.all(t_complex128.asnumpy() == py_input)
#complex in list
py_input = [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]]
t_complex128 = ms.Tensor(py_input)
assert isinstance(t_complex128, ms.Tensor)
assert t_complex128.shape == (2, 3)
assert t_complex128.dtype == ms.complex128
assert np.all(t_complex128.asnumpy() == py_input)
def test_tensor_type_complex128_user_define():
np_input = np.zeros([1, 2, 3])