forked from mindspore-Ecosystem/mindspore
Do not do replacement processing when the value of the parameter may be modified in the attach node of updatestate
This commit is contained in:
parent
2151b927ba
commit
dc1b5cf1bb
|
@ -236,6 +236,27 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
|
|||
return monad;
|
||||
}
|
||||
|
||||
bool MayModifyParameter(const AnfNodePtr &update_state, const AnfNodePtr &load) {
|
||||
MS_EXCEPTION_IF_NULL(update_state);
|
||||
MS_EXCEPTION_IF_NULL(load);
|
||||
auto update_state_cnode = update_state->cast<CNodePtr>();
|
||||
auto load_cnode = load->cast<CNodePtr>();
|
||||
constexpr size_t attach_index = 2;
|
||||
auto attach = update_state_cnode->input(attach_index);
|
||||
if (!attach->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
if (IsValueNode<FuncGraph>(attach->cast<CNodePtr>()->input(0))) {
|
||||
return true;
|
||||
}
|
||||
auto inputs = attach->cast<CNodePtr>()->inputs();
|
||||
bool exist_param_or_load = std::any_of(inputs.begin(), inputs.end(), [&load_cnode](const AnfNodePtr &input) {
|
||||
auto parameter = load_cnode->input(1);
|
||||
return input == load_cnode || input == parameter;
|
||||
});
|
||||
return exist_param_or_load;
|
||||
}
|
||||
|
||||
// Replace UpdateStates with U for first load.
|
||||
// Covert:
|
||||
// u1 = UpdateState(u, c)
|
||||
|
@ -258,6 +279,9 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
|
|||
if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
if (MayModifyParameter(update_state, load_node)) {
|
||||
continue;
|
||||
}
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
mgr->SetEdge(load_node, second_input_index, monad);
|
||||
|
|
Loading…
Reference in New Issue