From 492223da92dd90da1379164016c5acec5daaa7aa Mon Sep 17 00:00:00 2001 From: looop5 Date: Thu, 8 Dec 2022 17:23:46 +0800 Subject: [PATCH] math binary op C++ infer type supports complex data type --- mindspore/core/ops/add.cc | 2 +- mindspore/core/ops/div.cc | 3 +- mindspore/core/ops/mul.cc | 3 +- mindspore/core/ops/real_div.cc | 3 +- mindspore/core/ops/sub.cc | 2 +- mindspore/core/utils/check_convert_utils.cc | 63 +++++++++++++++++++++ mindspore/core/utils/check_convert_utils.h | 3 + 7 files changed, 71 insertions(+), 8 deletions(-) diff --git a/mindspore/core/ops/add.cc b/mindspore/core/ops/add.cc index 4b39c382e77..00108fb5015 100644 --- a/mindspore/core/ops/add.cc +++ b/mindspore/core/ops/add.cc @@ -40,7 +40,7 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name); + auto output_type = CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, prim_name); if (output_shape->IsDimZero()) { output_type = input_args[0]->BuildType(); } diff --git a/mindspore/core/ops/div.cc b/mindspore/core/ops/div.cc index db14378a53e..6d2d82232e3 100644 --- a/mindspore/core/ops/div.cc +++ b/mindspore/core/ops/div.cc @@ -54,8 +54,7 @@ class DivInfer : public abstract::OpInferBase { std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name()); - return input_args[0]->BuildType(); + return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types_with_complex, op_name); } }; diff --git a/mindspore/core/ops/mul.cc b/mindspore/core/ops/mul.cc index f781fd4b693..e22904d7543 100644 --- a/mindspore/core/ops/mul.cc +++ b/mindspore/core/ops/mul.cc @@ -49,8 +49,7 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); - return input_args[0]->BuildType(); + return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name); } } // namespace diff --git a/mindspore/core/ops/real_div.cc b/mindspore/core/ops/real_div.cc index 417e8136e83..739ffb7c901 100644 --- a/mindspore/core/ops/real_div.cc +++ b/mindspore/core/ops/real_div.cc @@ -49,8 +49,7 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); - return input_args[0]->BuildType(); + return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name); } } // namespace diff --git a/mindspore/core/ops/sub.cc b/mindspore/core/ops/sub.cc index 6376df11134..2b81c558347 100644 --- a/mindspore/core/ops/sub.cc +++ b/mindspore/core/ops/sub.cc @@ -40,7 +40,7 @@ TypePtr SubInferType(const PrimitivePtr &prim, const std::vectorBuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, valid_types, prim->name()); } } // namespace diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index dd6c164f9c8..350355a1494 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include "abstract/abstract_value.h" #include "ops/op_utils.h" @@ -554,6 +556,67 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::mapfirst, types.begin()->second, check_list, prim_name); } +TypePtr CheckAndConvertUtils::CheckMathBinaryOpTensorType(const std::map &types, + const std::set &check_list, + const std::string &prim_name) { + constexpr size_t n = 2; + if (types.size() != n) { + MS_EXCEPTION(ArgumentError) << "For primitive[" << prim_name << "], the size of types to check must be " << n + << ", but got " << types.size(); + } + // Check Input type is tensor type + std::vector type_ids; + std::vector type_ptr; + bool has_complex = false; + for (const auto &item : types) { + MS_EXCEPTION_IF_NULL(item.second); + if (!item.second->isa()) { + MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input arguments[" << item.first + << "] must be Tensor, but got " << item.second->ToString(); + } + auto tensor_type = item.second->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto element = tensor_type->element(); + MS_EXCEPTION_IF_NULL(element); + auto type_id = element->type_id(); + if (!has_complex && (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128)) { + has_complex = true; + } + type_ids.push_back(type_id); + type_ptr.push_back(item.second); + } + // Deal with complex data type + if (has_complex) { + static std::map, TypeId> type_infer_dict = { + {{kNumberTypeComplex64, kNumberTypeComplex64}, kNumberTypeComplex64}, + {{kNumberTypeComplex64, kNumberTypeFloat32}, kNumberTypeComplex64}, + {{kNumberTypeFloat32, kNumberTypeComplex64}, kNumberTypeComplex64}, + {{kNumberTypeComplex128, kNumberTypeComplex128}, kNumberTypeComplex128}, + {{kNumberTypeComplex128, kNumberTypeFloat64}, kNumberTypeComplex128}, + {{kNumberTypeFloat64, kNumberTypeComplex128}, kNumberTypeComplex128}}; + std::pair type_info(type_ids[0], type_ids[1]); + auto iter = type_infer_dict.find(type_info); + if (iter != type_infer_dict.end()) { + return type_ids[0] == iter->second ? type_ptr[0] : type_ptr[1]; + } + std::ostringstream buffer; + buffer << "For primitive[" << prim_name << "], complex math binary op expecting Tensor"; + for (const auto &items : type_infer_dict) { + buffer << "[" << TypeIdToString(items.first.first) << ", " << TypeIdToString(items.first.second) << "], "; + } + buffer << "but got Tensor[" << TypeIdToString(type_ids[0]) << ", " << TypeIdToString(type_ids[1]) << "]"; + MS_EXCEPTION(TypeError) << buffer.str(); + } + // Deal with non-complex data type + if (type_ids[0] != type_ids[1]) { + MS_EXCEPTION(TypeError) << "For primitive[" << prim_name + << "], the input arguments must have same data type, but got Tensor[" + << TypeIdToString(type_ids[0]) << "] and Tensor[" << TypeIdToString(type_ids[1]) << "]"; + } + (void)CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name); + return types.begin()->second; +} + ShapeVector CheckAndConvertUtils::CheckTensorShapeSame(const std::map &shapes, const std::vector &check_shape, const std::string &prim_name) { diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index 71af16c538e..800157681e6 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -289,6 +289,9 @@ class MS_CORE_API CheckAndConvertUtils { const std::vector &check_shape, const std::string &prim_name); static TypePtr CheckTensorTypeSame(const std::map &types, const std::set &check_list, const std::string &prim_name); + // Return Tensor type + static TypePtr CheckMathBinaryOpTensorType(const std::map &types, + const std::set &check_list, const std::string &prim_name); static ShapeVector CheckTensorIntValue(const std::string &type_name, const ValuePtr &value, const std::string &prim_name); static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,