math binary op C++ infer type supports complex data type

This commit is contained in:
looop5 2022-12-08 17:23:46 +08:00
parent 813187cf04
commit 492223da92
7 changed files with 71 additions and 8 deletions

View File

@ -40,7 +40,7 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
auto output_type = CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, prim_name);
if (output_shape->IsDimZero()) {
output_type = input_args[0]->BuildType();
}

View File

@ -54,8 +54,7 @@ class DivInfer : public abstract::OpInferBase {
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
return input_args[0]->BuildType();
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types_with_complex, op_name);
}
};

View File

@ -49,8 +49,7 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[0]->BuildType();
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name);
}
} // namespace

View File

@ -49,8 +49,7 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[0]->BuildType();
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name);
}
} // namespace

View File

@ -40,7 +40,7 @@ TypePtr SubInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, valid_types, prim->name());
}
} // namespace

View File

@ -21,6 +21,8 @@
#include <algorithm>
#include <typeinfo>
#include <functional>
#include <set>
#include <map>
#include "abstract/abstract_value.h"
#include "ops/op_utils.h"
@ -554,6 +556,67 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
return CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
}
TypePtr CheckAndConvertUtils::CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
const std::set<TypePtr> &check_list,
const std::string &prim_name) {
constexpr size_t n = 2;
if (types.size() != n) {
MS_EXCEPTION(ArgumentError) << "For primitive[" << prim_name << "], the size of types to check must be " << n
<< ", but got " << types.size();
}
// Check Input type is tensor type
std::vector<TypeId> type_ids;
std::vector<TypePtr> type_ptr;
bool has_complex = false;
for (const auto &item : types) {
MS_EXCEPTION_IF_NULL(item.second);
if (!item.second->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input arguments[" << item.first
<< "] must be Tensor, but got " << item.second->ToString();
}
auto tensor_type = item.second->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto type_id = element->type_id();
if (!has_complex && (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128)) {
has_complex = true;
}
type_ids.push_back(type_id);
type_ptr.push_back(item.second);
}
// Deal with complex data type
if (has_complex) {
static std::map<std::pair<TypeId, TypeId>, TypeId> type_infer_dict = {
{{kNumberTypeComplex64, kNumberTypeComplex64}, kNumberTypeComplex64},
{{kNumberTypeComplex64, kNumberTypeFloat32}, kNumberTypeComplex64},
{{kNumberTypeFloat32, kNumberTypeComplex64}, kNumberTypeComplex64},
{{kNumberTypeComplex128, kNumberTypeComplex128}, kNumberTypeComplex128},
{{kNumberTypeComplex128, kNumberTypeFloat64}, kNumberTypeComplex128},
{{kNumberTypeFloat64, kNumberTypeComplex128}, kNumberTypeComplex128}};
std::pair<TypeId, TypeId> type_info(type_ids[0], type_ids[1]);
auto iter = type_infer_dict.find(type_info);
if (iter != type_infer_dict.end()) {
return type_ids[0] == iter->second ? type_ptr[0] : type_ptr[1];
}
std::ostringstream buffer;
buffer << "For primitive[" << prim_name << "], complex math binary op expecting Tensor";
for (const auto &items : type_infer_dict) {
buffer << "[" << TypeIdToString(items.first.first) << ", " << TypeIdToString(items.first.second) << "], ";
}
buffer << "but got Tensor[" << TypeIdToString(type_ids[0]) << ", " << TypeIdToString(type_ids[1]) << "]";
MS_EXCEPTION(TypeError) << buffer.str();
}
// Deal with non-complex data type
if (type_ids[0] != type_ids[1]) {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the input arguments must have same data type, but got Tensor["
<< TypeIdToString(type_ids[0]) << "] and Tensor[" << TypeIdToString(type_ids[1]) << "]";
}
(void)CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
return types.begin()->second;
}
ShapeVector CheckAndConvertUtils::CheckTensorShapeSame(const std::map<std::string, BaseShapePtr> &shapes,
const std::vector<int64_t> &check_shape,
const std::string &prim_name) {

View File

@ -289,6 +289,9 @@ class MS_CORE_API CheckAndConvertUtils {
const std::vector<int64_t> &check_shape, const std::string &prim_name);
static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list,
const std::string &prim_name);
// Return Tensor type
static TypePtr CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
const std::set<TypePtr> &check_list, const std::string &prim_name);
static ShapeVector CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
const std::string &prim_name);
static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,