forked from mindspore-Ecosystem/mindspore
run static_analysis with broaded args if all inputs are tensor
This commit is contained in:
parent
dcd9d18411
commit
5574687c63
|
@ -51,6 +51,19 @@ bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
|
|||
}
|
||||
return fg == parent;
|
||||
}
|
||||
|
||||
bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
|
||||
if (abs_base->isa<AbstractTensor>()) {
|
||||
return true;
|
||||
} else if (abs_base->isa<AbstractSequeue>()) {
|
||||
const auto &abs_seq = abs_base->cast<AbstractSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_seq);
|
||||
const auto &elements = abs_seq->elements();
|
||||
return std::all_of(elements.cbegin(), elements.cend(), [](const auto &v) { return CheckAbstractTensor(v); });
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
|
||||
|
@ -586,7 +599,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
return std::make_pair(joined_argvals, joined_eval_result->abstract());
|
||||
} else {
|
||||
bool all_args_tensor = std::all_of(broaded_argvals.cbegin(), broaded_argvals.cend(),
|
||||
[](const AbstractBasePtr &v) { return v->isa<AbstractTensor>(); });
|
||||
[](const AbstractBasePtr &v) { return CheckAbstractTensor(v); });
|
||||
if (all_args_tensor) {
|
||||
ConfigPtrList args_conf_list;
|
||||
(void)std::transform(broaded_argvals.cbegin(), broaded_argvals.cend(), std ::back_inserter(args_conf_list),
|
||||
|
|
Loading…
Reference in New Issue