Recompute the tuple_getitem which is in the bprop function

This commit is contained in:
yujianfeng 2021-12-17 10:49:27 +08:00
parent 2510e9f3d3
commit db9a1f4f24
1 changed files with 6 additions and 12 deletions

View File

@ -76,7 +76,10 @@ bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) {
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
}
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); }
bool IsCandidateRecomputedNode(const CNodePtr &node) {
// The tuple_getitem in the bprop function should also be recomputed.
return (!IsBpropNode(node) || IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) && IsSetRecomputeCNodeAttr(node);
}
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
const std::vector<CNodePtr> &cnodes) {
@ -176,17 +179,8 @@ void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
for (const auto &node_index_set : output_set_iter->second) {
auto output_node = node_index_set.first;
MS_EXCEPTION_IF_NULL(output_node);
if (!IsBpropNode(output_node)) {
continue;
}
if (IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) {
auto tuple_getitem_users = node_users.find(output_node);
if (tuple_getitem_users == node_users.end()) {
continue;
}
for (const auto &user : tuple_getitem_users->second) {
target_nodes->insert(user.first->cast<CNodePtr>());
}
// The tuple_getitem to be recomputed can be in the bprop function.
if (!IsBpropNode(output_node) || IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem)) {
continue;
}
target_nodes->insert(output_node->cast<CNodePtr>());