Do not do replacement processing when the value of the parameter may be modified in the attach node of updatestate

This commit is contained in:
Margaret_wangrui 2021-08-14 12:27:09 +08:00
parent 2151b927ba
commit dc1b5cf1bb
1 changed files with 24 additions and 0 deletions

View File

@ -236,6 +236,27 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
return monad;
}
bool MayModifyParameter(const AnfNodePtr &update_state, const AnfNodePtr &load) {
MS_EXCEPTION_IF_NULL(update_state);
MS_EXCEPTION_IF_NULL(load);
auto update_state_cnode = update_state->cast<CNodePtr>();
auto load_cnode = load->cast<CNodePtr>();
constexpr size_t attach_index = 2;
auto attach = update_state_cnode->input(attach_index);
if (!attach->isa<CNode>()) {
return false;
}
if (IsValueNode<FuncGraph>(attach->cast<CNodePtr>()->input(0))) {
return true;
}
auto inputs = attach->cast<CNodePtr>()->inputs();
bool exist_param_or_load = std::any_of(inputs.begin(), inputs.end(), [&load_cnode](const AnfNodePtr &input) {
auto parameter = load_cnode->input(1);
return input == load_cnode || input == parameter;
});
return exist_param_or_load;
}
// Replace UpdateStates with U for first load.
// Covert:
// u1 = UpdateState(u, c)
@ -258,6 +279,9 @@ bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) {
continue;
}
if (MayModifyParameter(update_state, load_node)) {
continue;
}
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
mgr->SetEdge(load_node, second_input_index, monad);