!48683 Fix closure check rules

Merge pull request !48683 from chenfei_mindspore/fix-closure-check
This commit is contained in:
i-robot 2023-02-10 09:34:13 +00:00 committed by Gitee
commit 1aafb03fd1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 19 additions and 1 deletions

View File

@ -977,6 +977,18 @@ bool SupportInlinePartial(const AnfNodePtr &input0) {
return false;
}
bool HasAbstractFunction(const AbstractBasePtr &abs) {
if (abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractSparseTensor>()) {
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
if (abs_seq->dynamic_len()) {
return HasAbstractFunction(abs_seq->dynamic_len_element_abs());
}
return std::any_of(abs_seq->elements().cbegin(), abs_seq->elements().cend(), HasAbstractFunction);
}
// if abs it not AbstractSequence.
return abs->isa<abstract::AbstractFunction>();
}
bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
for (const auto &node : all_nodes) {
if (node == nullptr || !node->isa<CNode>()) {
@ -1019,7 +1031,13 @@ bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
auto input0 = cnode->input(0);
if (IsPrimitiveCNode(input0, prim::kPrimSwitch) || IsPrimitiveCNode(input0, prim::kPrimSwitchLayer) ||
IsValueNode<FuncGraph>(input0)) {
continue;
auto func_graphs = abstract::GetFuncGraphsFromAbs(input0);
auto graph_has_function_output = [](const FuncGraphPtr &fg) {
return HasAbstractFunction(fg->output()->abstract());
};
if (std::all_of(func_graphs.cbegin(), func_graphs.cend(), std::not_fn(graph_has_function_output))) {
continue;
}
}
if (SupportInlinePartial(input0)) {
continue;