run static_analysis with broaded args if all inputs are tensor

This commit is contained in:
zhousiyi 2021-07-09 03:26:26 +00:00
parent dcd9d18411
commit 5574687c63
1 changed files with 14 additions and 1 deletions

View File

@ -51,6 +51,19 @@ bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
}
return fg == parent;
}
bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
if (abs_base->isa<AbstractTensor>()) {
return true;
} else if (abs_base->isa<AbstractSequeue>()) {
const auto &abs_seq = abs_base->cast<AbstractSequeuePtr>();
MS_EXCEPTION_IF_NULL(abs_seq);
const auto &elements = abs_seq->elements();
return std::all_of(elements.cbegin(), elements.cend(), [](const auto &v) { return CheckAbstractTensor(v); });
} else {
return false;
}
}
} // namespace
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
@ -586,7 +599,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
return std::make_pair(joined_argvals, joined_eval_result->abstract());
} else {
bool all_args_tensor = std::all_of(broaded_argvals.cbegin(), broaded_argvals.cend(),
[](const AbstractBasePtr &v) { return v->isa<AbstractTensor>(); });
[](const AbstractBasePtr &v) { return CheckAbstractTensor(v); });
if (all_args_tensor) {
ConfigPtrList args_conf_list;
(void)std::transform(broaded_argvals.cbegin(), broaded_argvals.cend(), std ::back_inserter(args_conf_list),