Fix recompute update_state

This commit is contained in:
yujianfeng 2021-05-31 12:45:23 +08:00
parent e8ad10ea26
commit 53fb15339c
2 changed files with 11 additions and 1 deletions

View File

@ -102,6 +102,10 @@ void GetMaxSubGraph(const FuncGraphManagerPtr &mng, std::unordered_set<CNodePtr>
auto current_node = nodes_to_visit.front();
nodes_to_visit.pop();
recomputed_nodes->insert(current_node);
// No need to find nodes through side-effect dependency.
if (IsPrimitiveCNode(current_node, prim::kPrimUpdateState)) {
continue;
}
if (get_inputs) {
for (const auto &input : current_node->inputs()) {
MS_EXCEPTION_IF_NULL(input);
@ -327,7 +331,13 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
}
auto input_cnode = input->cast<CNodePtr>();
if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
new_inputs.emplace_back(input);
if (IsPrimitiveCNode(input_cnode, prim::kPrimUpdateState)) {
auto u = NewValueNode(kUMonad);
u->set_abstract(kUMonad->ToAbstract());
new_inputs.emplace_back(u);
} else {
new_inputs.emplace_back(input);
}
} else {
has_recomputed_inputs = true;
new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,