!19797 fix getfuncgraphspecializer inner error

Merge pull request !19797 from xychow/fix-getfuncgraphspecializer-inner-error
This commit is contained in:
i-robot 2021-07-09 10:58:49 +00:00 committed by Gitee
commit b71271b9bc
1 changed files with 14 additions and 1 deletions

View File

@ -51,6 +51,19 @@ bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
}
return fg == parent;
}
bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
if (abs_base->isa<AbstractTensor>()) {
return true;
} else if (abs_base->isa<AbstractSequeue>()) {
const auto &abs_seq = abs_base->cast<AbstractSequeuePtr>();
MS_EXCEPTION_IF_NULL(abs_seq);
const auto &elements = abs_seq->elements();
return std::all_of(elements.cbegin(), elements.cend(), [](const auto &v) { return CheckAbstractTensor(v); });
} else {
return false;
}
}
} // namespace
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
@ -586,7 +599,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
return std::make_pair(joined_argvals, joined_eval_result->abstract());
} else {
bool all_args_tensor = std::all_of(broaded_argvals.cbegin(), broaded_argvals.cend(),
[](const AbstractBasePtr &v) { return v->isa<AbstractTensor>(); });
[](const AbstractBasePtr &v) { return CheckAbstractTensor(v); });
if (all_args_tensor) {
ConfigPtrList args_conf_list;
(void)std::transform(broaded_argvals.cbegin(), broaded_argvals.cend(), std ::back_inserter(args_conf_list),