!21861 check switch/switchlayer in EliminateUpdateStateForPureNode

Merge pull request !21861 from huangbingjian/pure_node_eliminate
This commit is contained in:
i-robot 2021-08-17 02:03:35 +00:00 committed by Gitee
commit 62a780f20f
1 changed files with 27 additions and 6 deletions

View File

@ -85,18 +85,39 @@ bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_n
(first_user == second_node && second_user == first_node);
}
// Determine whether there is a monad in the inputs of the node.
bool CheckHasMonadInput(const CNodePtr &cnode) {
// If the last input is a monad, means the attach node has side-effect and
// we should keep UpdateState; otherwise, we will remove the UpdateState.
if (cnode->size() > 1 && HasAbstractMonad(cnode->inputs().back())) {
return true;
}
// Check the inputs of Call/Switch/SwitchLayer.
auto first_input_node = cnode->input(kFirstInputIndex);
if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
for (auto &input : first_input_node->cast<CNodePtr>()->inputs()) {
if (HasAbstractMonad(input)) {
return true;
}
auto input_cnode = dyn_cast<CNode>(input);
if (input_cnode != nullptr && input_cnode->size() > 1 && HasAbstractMonad(input_cnode->inputs().back())) {
return true;
}
}
}
return false;
}
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
auto cnode = dyn_cast<CNode>(attach);
if (cnode == nullptr) {
// Skip value node or parameter.
return nullptr;
}
if (cnode->size() > 1) {
// If the last input is a monad, means the attach node has side-effect and
// we should keep UpdateState; otherwise, we will remove the UpdateState.
if (HasAbstractMonad(cnode->inputs().back())) {
return nullptr;
}
if (CheckHasMonadInput(cnode)) {
return nullptr;
}
// Remove UpdateState by replace it with its input monad.