forked from mindspore-Ecosystem/mindspore
!25641 fix subclass check error
Merge pull request !25641 from lianliguang/fix-bug-error
This commit is contained in:
commit
010cc7a435
|
@ -540,8 +540,9 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa
|
|||
<< res_spec->ToString();
|
||||
} else if (res_spec->isa<AbstractTuple>() &&
|
||||
(res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
|
||||
MS_LOG(EXCEPTION) << "Custom primitive[" << prim->ToString() << "]'s attribute[output_num]:" << output_num
|
||||
<< " not matches the infer result " << res_spec->ToString();
|
||||
MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
|
||||
<< "]'s attribute[output_num]:" << output_num << " not matches the infer result "
|
||||
<< res_spec->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -304,6 +304,11 @@ ValuePtr L2NormalizeAttrConversion(ValuePtr attr) {
|
|||
|
||||
std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", L2NormalizeAttrConversion}}},
|
||||
{"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
|
||||
inline bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types) {
|
||||
return std::any_of(template_types.begin(), template_types.end(), [&check_type](const TypePtr &accept) -> bool {
|
||||
return IsIdentidityOrSubclass(check_type, accept);
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
|
||||
|
@ -457,6 +462,7 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
|||
if (types.empty()) {
|
||||
MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
|
||||
}
|
||||
// Check Input type is tensor type
|
||||
for (const auto &item : types) {
|
||||
auto type = item.second;
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
@ -483,12 +489,12 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
|||
}
|
||||
}
|
||||
auto check_type = _CheckTypeSame(types, prim_name, false);
|
||||
std::string input_names = "";
|
||||
std::string input_names;
|
||||
for (const auto &item : types) {
|
||||
(void)input_names.append(item.first);
|
||||
(void)input_names.append(", ");
|
||||
}
|
||||
return CheckSubClass(input_names, check_type, check_list, prim_name);
|
||||
return CheckTensorSubClass(input_names, check_type, check_list, prim_name);
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
|
@ -509,7 +515,7 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
|
|||
}
|
||||
}
|
||||
}
|
||||
return CheckSubClass(type_name, type, check_list, prim_name);
|
||||
return CheckTensorSubClass(type_name, element, check_list, prim_name);
|
||||
}
|
||||
|
||||
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
||||
|
@ -545,38 +551,50 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
|||
return tensor_value;
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const string &prim_name) {
|
||||
if (CheckType(type, template_types)) {
|
||||
return type;
|
||||
}
|
||||
std::ostringstream buffer;
|
||||
buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of {";
|
||||
for (const auto &item : template_types) {
|
||||
if (item->isa<TensorType>()) {
|
||||
buffer << item->ToString();
|
||||
continue;
|
||||
}
|
||||
buffer << " Tensor[" << item->ToString() << "],";
|
||||
}
|
||||
buffer << "}, but got " << type->ToString();
|
||||
buffer << ".";
|
||||
MS_EXCEPTION(TypeError) << buffer.str();
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const std::string &prim_name) {
|
||||
auto check_type = type;
|
||||
bool ok = std::any_of(template_types.begin(), template_types.end(), [check_type](const TypePtr &accept) -> bool {
|
||||
return IsIdentidityOrSubclass(check_type, accept);
|
||||
});
|
||||
if (ok) {
|
||||
return check_type;
|
||||
if (CheckType(type, template_types)) {
|
||||
return type;
|
||||
}
|
||||
if (type->isa<TensorType>()) {
|
||||
auto tensor_type = type->cast<TensorTypePtr>();
|
||||
check_type = tensor_type->element();
|
||||
}
|
||||
ok = std::any_of(template_types.begin(), template_types.end(),
|
||||
[check_type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(check_type, accept); });
|
||||
if (ok) {
|
||||
return check_type;
|
||||
} else {
|
||||
std::string type_str = type->ToString();
|
||||
std::ostringstream buffer;
|
||||
buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of ";
|
||||
buffer << GetErrorTypeString(template_types, type) << ", but got " << type->ToString();
|
||||
buffer << ".";
|
||||
MS_EXCEPTION(TypeError) << buffer.str();
|
||||
std::ostringstream buffer;
|
||||
buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of {";
|
||||
for (const auto &item : template_types) {
|
||||
buffer << " " << item->ToString() << ",";
|
||||
}
|
||||
buffer << "}, but got " << type->ToString();
|
||||
buffer << ".";
|
||||
MS_EXCEPTION(TypeError) << buffer.str();
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
|
||||
const std::set<TypePtr> &valid_values,
|
||||
const std::string &prim_name, const bool allow_mix) {
|
||||
auto arg_ = _CheckTypeSame(args, prim_name, allow_mix);
|
||||
return CheckTypeValid(args.begin()->first, arg_, valid_values, prim_name);
|
||||
std::string input_names;
|
||||
for (const auto &item : args) {
|
||||
(void)input_names.append(item.first);
|
||||
(void)input_names.append(", ");
|
||||
}
|
||||
return CheckMixSubClass(input_names, arg_, valid_values, prim_name);
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
||||
|
@ -607,13 +625,12 @@ TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr
|
|||
auto tensor_type = type->cast<TensorTypePtr>();
|
||||
auto element = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
if (!allow_mix) {
|
||||
return_type = element;
|
||||
} else {
|
||||
return_type = tensor_type;
|
||||
}
|
||||
return_type = element;
|
||||
(void)types_id.emplace(element->type_id());
|
||||
} else {
|
||||
if (return_type->isa<TensorType>()) {
|
||||
return_type = type;
|
||||
}
|
||||
(void)types_id.emplace(type->type_id());
|
||||
}
|
||||
if (types_id.size() > 1) {
|
||||
|
@ -774,42 +791,25 @@ bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_l
|
|||
return false;
|
||||
}
|
||||
|
||||
std::string CheckAndConvertUtils::GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type) {
|
||||
TypePtr CheckAndConvertUtils::CheckMixSubClass(const string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const string &prim_name) {
|
||||
if (CheckType(type, template_types)) {
|
||||
return type;
|
||||
}
|
||||
std::ostringstream buffer;
|
||||
buffer << "{";
|
||||
// got tensor type list
|
||||
for (const auto &item : check_list) {
|
||||
buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of {";
|
||||
for (const auto &item : template_types) {
|
||||
if (item->isa<TensorType>()) {
|
||||
buffer << item->ToString();
|
||||
buffer << ", ";
|
||||
continue;
|
||||
}
|
||||
buffer << "Tensor[" << item->ToString() << "], ";
|
||||
buffer << " Tensor[" << item->ToString() << "],";
|
||||
}
|
||||
if (check_type->isa<TensorType>()) {
|
||||
buffer << "}";
|
||||
return buffer.str();
|
||||
for (const auto &item : template_types) {
|
||||
buffer << " " << item->ToString() << "],";
|
||||
}
|
||||
// got python type
|
||||
std::set<std::string> type_string;
|
||||
for (const auto &item : check_list) {
|
||||
if (item->isa<Float>()) {
|
||||
type_string.emplace("Float");
|
||||
}
|
||||
if (item->isa<Int>()) {
|
||||
type_string.emplace("Int");
|
||||
}
|
||||
if (item->isa<Bool>()) {
|
||||
type_string.emplace("Bool");
|
||||
}
|
||||
if (item->isa<UInt>()) {
|
||||
type_string.emplace("UInt");
|
||||
}
|
||||
}
|
||||
for (const auto &item : type_string) {
|
||||
buffer << item << ",";
|
||||
}
|
||||
buffer << "}";
|
||||
return buffer.str();
|
||||
buffer << "}, but got " << type->ToString();
|
||||
buffer << ".";
|
||||
MS_EXCEPTION(TypeError) << buffer.str();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -320,7 +320,10 @@ class CheckAndConvertUtils {
|
|||
private:
|
||||
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
||||
const bool allow_mix);
|
||||
static std::string GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type);
|
||||
static TypePtr CheckTensorSubClass(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const std::string &prim_name);
|
||||
static TypePtr CheckMixSubClass(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const std::string &prim_name);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
|
||||
|
|
Loading…
Reference in New Issue