forked from mindspore-Ecosystem/mindspore
math binary op C++ infer type supports complex data type
This commit is contained in:
parent
813187cf04
commit
492223da92
|
@ -40,7 +40,7 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
auto output_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
|
||||
auto output_type = CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, prim_name);
|
||||
if (output_shape->IsDimZero()) {
|
||||
output_type = input_args[0]->BuildType();
|
||||
}
|
||||
|
|
|
@ -54,8 +54,7 @@ class DivInfer : public abstract::OpInferBase {
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types_with_complex, op_name);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -49,8 +49,7 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -49,8 +49,7 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ TypePtr SubInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
|
|||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return CheckAndConvertUtils::CheckMathBinaryOpTensorType(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include <algorithm>
|
||||
#include <typeinfo>
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
@ -554,6 +556,67 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
|||
return CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
|
||||
const std::set<TypePtr> &check_list,
|
||||
const std::string &prim_name) {
|
||||
constexpr size_t n = 2;
|
||||
if (types.size() != n) {
|
||||
MS_EXCEPTION(ArgumentError) << "For primitive[" << prim_name << "], the size of types to check must be " << n
|
||||
<< ", but got " << types.size();
|
||||
}
|
||||
// Check Input type is tensor type
|
||||
std::vector<TypeId> type_ids;
|
||||
std::vector<TypePtr> type_ptr;
|
||||
bool has_complex = false;
|
||||
for (const auto &item : types) {
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
if (!item.second->isa<TensorType>()) {
|
||||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input arguments[" << item.first
|
||||
<< "] must be Tensor, but got " << item.second->ToString();
|
||||
}
|
||||
auto tensor_type = item.second->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto element = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
auto type_id = element->type_id();
|
||||
if (!has_complex && (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128)) {
|
||||
has_complex = true;
|
||||
}
|
||||
type_ids.push_back(type_id);
|
||||
type_ptr.push_back(item.second);
|
||||
}
|
||||
// Deal with complex data type
|
||||
if (has_complex) {
|
||||
static std::map<std::pair<TypeId, TypeId>, TypeId> type_infer_dict = {
|
||||
{{kNumberTypeComplex64, kNumberTypeComplex64}, kNumberTypeComplex64},
|
||||
{{kNumberTypeComplex64, kNumberTypeFloat32}, kNumberTypeComplex64},
|
||||
{{kNumberTypeFloat32, kNumberTypeComplex64}, kNumberTypeComplex64},
|
||||
{{kNumberTypeComplex128, kNumberTypeComplex128}, kNumberTypeComplex128},
|
||||
{{kNumberTypeComplex128, kNumberTypeFloat64}, kNumberTypeComplex128},
|
||||
{{kNumberTypeFloat64, kNumberTypeComplex128}, kNumberTypeComplex128}};
|
||||
std::pair<TypeId, TypeId> type_info(type_ids[0], type_ids[1]);
|
||||
auto iter = type_infer_dict.find(type_info);
|
||||
if (iter != type_infer_dict.end()) {
|
||||
return type_ids[0] == iter->second ? type_ptr[0] : type_ptr[1];
|
||||
}
|
||||
std::ostringstream buffer;
|
||||
buffer << "For primitive[" << prim_name << "], complex math binary op expecting Tensor";
|
||||
for (const auto &items : type_infer_dict) {
|
||||
buffer << "[" << TypeIdToString(items.first.first) << ", " << TypeIdToString(items.first.second) << "], ";
|
||||
}
|
||||
buffer << "but got Tensor[" << TypeIdToString(type_ids[0]) << ", " << TypeIdToString(type_ids[1]) << "]";
|
||||
MS_EXCEPTION(TypeError) << buffer.str();
|
||||
}
|
||||
// Deal with non-complex data type
|
||||
if (type_ids[0] != type_ids[1]) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the input arguments must have same data type, but got Tensor["
|
||||
<< TypeIdToString(type_ids[0]) << "] and Tensor[" << TypeIdToString(type_ids[1]) << "]";
|
||||
}
|
||||
(void)CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
|
||||
return types.begin()->second;
|
||||
}
|
||||
|
||||
ShapeVector CheckAndConvertUtils::CheckTensorShapeSame(const std::map<std::string, BaseShapePtr> &shapes,
|
||||
const std::vector<int64_t> &check_shape,
|
||||
const std::string &prim_name) {
|
||||
|
|
|
@ -289,6 +289,9 @@ class MS_CORE_API CheckAndConvertUtils {
|
|||
const std::vector<int64_t> &check_shape, const std::string &prim_name);
|
||||
static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list,
|
||||
const std::string &prim_name);
|
||||
// Return Tensor type
|
||||
static TypePtr CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
|
||||
const std::set<TypePtr> &check_list, const std::string &prim_name);
|
||||
static ShapeVector CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
||||
const std::string &prim_name);
|
||||
static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
|
|
Loading…
Reference in New Issue