forked from mindspore-Ecosystem/mindspore
!21861 check switch/switchlayer in EliminateUpdateStateForPureNode
Merge pull request !21861 from huangbingjian/pure_node_eliminate
This commit is contained in:
commit
62a780f20f
|
@ -85,18 +85,39 @@ bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_n
|
|||
(first_user == second_node && second_user == first_node);
|
||||
}
|
||||
|
||||
// Determine whether there is a monad in the inputs of the node.
|
||||
bool CheckHasMonadInput(const CNodePtr &cnode) {
|
||||
// If the last input is a monad, means the attach node has side-effect and
|
||||
// we should keep UpdateState; otherwise, we will remove the UpdateState.
|
||||
if (cnode->size() > 1 && HasAbstractMonad(cnode->inputs().back())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check the inputs of Call/Switch/SwitchLayer.
|
||||
auto first_input_node = cnode->input(kFirstInputIndex);
|
||||
if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
|
||||
IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
|
||||
for (auto &input : first_input_node->cast<CNodePtr>()->inputs()) {
|
||||
if (HasAbstractMonad(input)) {
|
||||
return true;
|
||||
}
|
||||
auto input_cnode = dyn_cast<CNode>(input);
|
||||
if (input_cnode != nullptr && input_cnode->size() > 1 && HasAbstractMonad(input_cnode->inputs().back())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
|
||||
auto cnode = dyn_cast<CNode>(attach);
|
||||
if (cnode == nullptr) {
|
||||
// Skip value node or parameter.
|
||||
return nullptr;
|
||||
}
|
||||
if (cnode->size() > 1) {
|
||||
// If the last input is a monad, means the attach node has side-effect and
|
||||
// we should keep UpdateState; otherwise, we will remove the UpdateState.
|
||||
if (HasAbstractMonad(cnode->inputs().back())) {
|
||||
return nullptr;
|
||||
}
|
||||
if (CheckHasMonadInput(cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Remove UpdateState by replace it with its input monad.
|
||||
|
|
Loading…
Reference in New Issue