forked from mindspore-Ecosystem/mindspore
check call/switch/switchlayer in EliminateUpdateStateForPureNode
This commit is contained in:
parent
06e86e34de
commit
c3bb0b91b4
|
@ -85,19 +85,40 @@ bool OnlyUsedByTwoNode(const AnfNodePtr &be_used_node, const AnfNodePtr &first_n
|
||||||
(first_user == second_node && second_user == first_node);
|
(first_user == second_node && second_user == first_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine whether there is a monad in the inputs of the node.
|
||||||
|
bool CheckHasMonadInput(const CNodePtr &cnode) {
|
||||||
|
// If the last input is a monad, means the attach node has side-effect and
|
||||||
|
// we should keep UpdateState; otherwise, we will remove the UpdateState.
|
||||||
|
if (cnode->size() > 1 && HasAbstractMonad(cnode->inputs().back())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the inputs of Call/Switch/SwitchLayer.
|
||||||
|
auto first_input_node = cnode->input(kFirstInputIndex);
|
||||||
|
if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
|
||||||
|
IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
|
||||||
|
for (auto &input : first_input_node->cast<CNodePtr>()->inputs()) {
|
||||||
|
if (HasAbstractMonad(input)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto input_cnode = dyn_cast<CNode>(input);
|
||||||
|
if (input_cnode != nullptr && input_cnode->size() > 1 && HasAbstractMonad(input_cnode->inputs().back())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
|
AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) {
|
||||||
auto cnode = dyn_cast<CNode>(attach);
|
auto cnode = dyn_cast<CNode>(attach);
|
||||||
if (cnode == nullptr) {
|
if (cnode == nullptr) {
|
||||||
// Skip value node or parameter.
|
// Skip value node or parameter.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (cnode->size() > 1) {
|
if (CheckHasMonadInput(cnode)) {
|
||||||
// If the last input is a monad, means the attach node has side-effect and
|
|
||||||
// we should keep UpdateState; otherwise, we will remove the UpdateState.
|
|
||||||
if (HasAbstractMonad(cnode->inputs().back())) {
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Remove UpdateState by replace it with its input monad.
|
// Remove UpdateState by replace it with its input monad.
|
||||||
return update_state->input(kInputIndex);
|
return update_state->input(kInputIndex);
|
||||||
|
|
Loading…
Reference in New Issue