forked from mindspore-Ecosystem/mindspore
Fix recompute update_state
This commit is contained in:
parent
e8ad10ea26
commit
53fb15339c
|
@ -102,6 +102,10 @@ void GetMaxSubGraph(const FuncGraphManagerPtr &mng, std::unordered_set<CNodePtr>
|
|||
auto current_node = nodes_to_visit.front();
|
||||
nodes_to_visit.pop();
|
||||
recomputed_nodes->insert(current_node);
|
||||
// No need to find nodes through side-effect dependency.
|
||||
if (IsPrimitiveCNode(current_node, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
if (get_inputs) {
|
||||
for (const auto &input : current_node->inputs()) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@ -327,7 +331,13 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
|
|||
}
|
||||
auto input_cnode = input->cast<CNodePtr>();
|
||||
if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
|
||||
new_inputs.emplace_back(input);
|
||||
if (IsPrimitiveCNode(input_cnode, prim::kPrimUpdateState)) {
|
||||
auto u = NewValueNode(kUMonad);
|
||||
u->set_abstract(kUMonad->ToAbstract());
|
||||
new_inputs.emplace_back(u);
|
||||
} else {
|
||||
new_inputs.emplace_back(input);
|
||||
}
|
||||
} else {
|
||||
has_recomputed_inputs = true;
|
||||
new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,
|
||||
|
|
Loading…
Reference in New Issue