Tensor support complex input
This commit is contained in:
parent
44775dca4b
commit
4e07873378
|
@ -37,6 +37,8 @@ static const std::map<int32_t, int32_t> kMsProtoDataTypeMap = {
|
||||||
{mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32},
|
{mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32},
|
||||||
{mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32},
|
{mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32},
|
||||||
{mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
|
{mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
|
||||||
|
{mindspore::TypeId::kNumberTypeComplex64, mindspore::DataType::MS_COMPLEX64},
|
||||||
|
{mindspore::TypeId::kNumberTypeComplex128, mindspore::DataType::MS_COMPLEX128},
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
|
static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
|
||||||
|
@ -53,6 +55,8 @@ static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
|
||||||
{mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
|
{mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
|
||||||
{mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
|
{mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
|
||||||
{mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
|
{mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
|
||||||
|
{mindspore::DataType::MS_COMPLEX64, mindspore::TypeId::kNumberTypeComplex64},
|
||||||
|
{mindspore::DataType::MS_COMPLEX128, mindspore::TypeId::kNumberTypeComplex128},
|
||||||
};
|
};
|
||||||
|
|
||||||
int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
|
int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
|
||||||
|
|
|
@ -456,6 +456,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
||||||
return TensorPy::MakeTensor(py::array(input), type_ptr);
|
return TensorPy::MakeTensor(py::array(input), type_ptr);
|
||||||
}),
|
}),
|
||||||
py::arg("input"), py::arg("dtype") = nullptr)
|
py::arg("input"), py::arg("dtype") = nullptr)
|
||||||
|
// We only suppot array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input,
|
||||||
|
// and array/bool_/int_/float_/list/tuple init will be matched above, other pybind objects
|
||||||
|
// input will raise error except complex data type.
|
||||||
|
.def(py::init([](const py::object &input, const TypePtr &type_ptr) {
|
||||||
|
if (!PyComplex_CheckExact(input.ptr())) {
|
||||||
|
MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type();
|
||||||
|
}
|
||||||
|
return TensorPy::MakeTensor(py::array(input), type_ptr);
|
||||||
|
}),
|
||||||
|
py::arg("input"), py::arg("dtype") = nullptr)
|
||||||
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
|
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
|
||||||
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
|
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
|
||||||
Get the tensor's data type.
|
Get the tensor's data type.
|
||||||
|
|
|
@ -89,8 +89,8 @@ class Tensor(Tensor_):
|
||||||
|
|
||||||
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
|
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
|
||||||
if init is None:
|
if init is None:
|
||||||
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
|
validator.check_value_type('input_data', input_data,
|
||||||
'Tensor')
|
(Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor')
|
||||||
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||||
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
||||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
||||||
|
|
|
@ -94,6 +94,14 @@ def test_tensor_type_complex64_user_define():
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_type_complex128():
|
def test_tensor_type_complex128():
|
||||||
|
#complex python object
|
||||||
|
py_input = 1 + 2.22222222j
|
||||||
|
t_complex128 = ms.Tensor(py_input)
|
||||||
|
assert t_complex128.shape == ()
|
||||||
|
assert t_complex128.dtype == ms.complex128
|
||||||
|
assert np.all(t_complex128.asnumpy() == py_input)
|
||||||
|
|
||||||
|
#complex in numpy array
|
||||||
np_input = np.array(
|
np_input = np.array(
|
||||||
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128)
|
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128)
|
||||||
t_complex128 = ms.Tensor(np_input)
|
t_complex128 = ms.Tensor(np_input)
|
||||||
|
@ -101,10 +109,19 @@ def test_tensor_type_complex128():
|
||||||
assert t_complex128.shape == (2, 3)
|
assert t_complex128.shape == (2, 3)
|
||||||
assert t_complex128.dtype == ms.complex128
|
assert t_complex128.dtype == ms.complex128
|
||||||
assert np.all(t_complex128.asnumpy() == np_input)
|
assert np.all(t_complex128.asnumpy() == np_input)
|
||||||
np_input = (1, 2.22222222j, 3)
|
|
||||||
t_complex128 = ms.Tensor(np_input)
|
|
||||||
assert np.all(t_complex128.asnumpy() == np_input)
|
|
||||||
|
|
||||||
|
#complex in tuple
|
||||||
|
py_input = (1, 2.22222222j, 3)
|
||||||
|
t_complex128 = ms.Tensor(py_input)
|
||||||
|
assert np.all(t_complex128.asnumpy() == py_input)
|
||||||
|
|
||||||
|
#complex in list
|
||||||
|
py_input = [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]]
|
||||||
|
t_complex128 = ms.Tensor(py_input)
|
||||||
|
assert isinstance(t_complex128, ms.Tensor)
|
||||||
|
assert t_complex128.shape == (2, 3)
|
||||||
|
assert t_complex128.dtype == ms.complex128
|
||||||
|
assert np.all(t_complex128.asnumpy() == py_input)
|
||||||
|
|
||||||
def test_tensor_type_complex128_user_define():
|
def test_tensor_type_complex128_user_define():
|
||||||
np_input = np.zeros([1, 2, 3])
|
np_input = np.zeros([1, 2, 3])
|
||||||
|
|
Loading…
Reference in New Issue