forked from mindspore-Ecosystem/mindspore
Recompute the tuple_getitem which is in the bprop function
This commit is contained in:
parent
2510e9f3d3
commit
db9a1f4f24
|
@ -76,7 +76,10 @@ bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) {
|
|||
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
|
||||
}
|
||||
|
||||
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); }
|
||||
bool IsCandidateRecomputedNode(const CNodePtr &node) {
|
||||
// The tuple_getitem in the bprop function should also be recomputed.
|
||||
return (!IsBpropNode(node) || IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) && IsSetRecomputeCNodeAttr(node);
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
|
||||
const std::vector<CNodePtr> &cnodes) {
|
||||
|
@ -176,17 +179,8 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
|
|||
for (const auto &node_index_set : output_set_iter->second) {
|
||||
auto output_node = node_index_set.first;
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
if (!IsBpropNode(output_node)) {
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_getitem_users = node_users.find(output_node);
|
||||
if (tuple_getitem_users == node_users.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto &user : tuple_getitem_users->second) {
|
||||
target_nodes->insert(user.first->cast<CNodePtr>());
|
||||
}
|
||||
// The tuple_getitem to be recomputed can be in the bprop function.
|
||||
if (!IsBpropNode(output_node) || IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
target_nodes->insert(output_node->cast<CNodePtr>());
|
||||
|
|
Loading…
Reference in New Issue