!22933 Tensor support complex scalar number input
Merge pull request !22933 from zhouyaqiang0/complex_ops
This commit is contained in:
commit
5ed105ec01
|
@ -37,6 +37,8 @@ static const std::map<int32_t, int32_t> kMsProtoDataTypeMap = {
|
|||
{mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32},
|
||||
{mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32},
|
||||
{mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
|
||||
{mindspore::TypeId::kNumberTypeComplex64, mindspore::DataType::MS_COMPLEX64},
|
||||
{mindspore::TypeId::kNumberTypeComplex128, mindspore::DataType::MS_COMPLEX128},
|
||||
};
|
||||
|
||||
static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
|
||||
|
@ -53,6 +55,8 @@ static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
|
|||
{mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
|
||||
{mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
|
||||
{mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
|
||||
{mindspore::DataType::MS_COMPLEX64, mindspore::TypeId::kNumberTypeComplex64},
|
||||
{mindspore::DataType::MS_COMPLEX128, mindspore::TypeId::kNumberTypeComplex128},
|
||||
};
|
||||
|
||||
int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
|
||||
|
|
|
@ -456,6 +456,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
return TensorPy::MakeTensor(py::array(input), type_ptr);
|
||||
}),
|
||||
py::arg("input"), py::arg("dtype") = nullptr)
|
||||
// We only suppot array/bool_/int_/float_/list/tuple/complex pybind objects as tensor input,
|
||||
// and array/bool_/int_/float_/list/tuple init will be matched above, other pybind objects
|
||||
// input will raise error except complex data type.
|
||||
.def(py::init([](const py::object &input, const TypePtr &type_ptr) {
|
||||
if (!PyComplex_CheckExact(input.ptr())) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported tensor type: " << input.get_type();
|
||||
}
|
||||
return TensorPy::MakeTensor(py::array(input), type_ptr);
|
||||
}),
|
||||
py::arg("input"), py::arg("dtype") = nullptr)
|
||||
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
|
||||
.def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter(
|
||||
Get the tensor's data type.
|
||||
|
|
|
@ -89,8 +89,8 @@ class Tensor(Tensor_):
|
|||
|
||||
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
|
||||
if init is None:
|
||||
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
|
||||
'Tensor')
|
||||
validator.check_value_type('input_data', input_data,
|
||||
(Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor')
|
||||
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
||||
|
|
|
@ -94,6 +94,14 @@ def test_tensor_type_complex64_user_define():
|
|||
|
||||
|
||||
def test_tensor_type_complex128():
|
||||
#complex python object
|
||||
py_input = 1 + 2.22222222j
|
||||
t_complex128 = ms.Tensor(py_input)
|
||||
assert t_complex128.shape == ()
|
||||
assert t_complex128.dtype == ms.complex128
|
||||
assert np.all(t_complex128.asnumpy() == py_input)
|
||||
|
||||
#complex in numpy array
|
||||
np_input = np.array(
|
||||
[[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128)
|
||||
t_complex128 = ms.Tensor(np_input)
|
||||
|
@ -101,10 +109,19 @@ def test_tensor_type_complex128():
|
|||
assert t_complex128.shape == (2, 3)
|
||||
assert t_complex128.dtype == ms.complex128
|
||||
assert np.all(t_complex128.asnumpy() == np_input)
|
||||
np_input = (1, 2.22222222j, 3)
|
||||
t_complex128 = ms.Tensor(np_input)
|
||||
assert np.all(t_complex128.asnumpy() == np_input)
|
||||
|
||||
#complex in tuple
|
||||
py_input = (1, 2.22222222j, 3)
|
||||
t_complex128 = ms.Tensor(py_input)
|
||||
assert np.all(t_complex128.asnumpy() == py_input)
|
||||
|
||||
#complex in list
|
||||
py_input = [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]]
|
||||
t_complex128 = ms.Tensor(py_input)
|
||||
assert isinstance(t_complex128, ms.Tensor)
|
||||
assert t_complex128.shape == (2, 3)
|
||||
assert t_complex128.dtype == ms.complex128
|
||||
assert np.all(t_complex128.asnumpy() == py_input)
|
||||
|
||||
def test_tensor_type_complex128_user_define():
|
||||
np_input = np.zeros([1, 2, 3])
|
||||
|
|
Loading…
Reference in New Issue