!18522 No need to recompute load node

Merge pull request !18522 from YuJianfeng/recompute1
This commit is contained in:
i-robot 2021-06-22 14:00:05 +00:00 committed by Gitee
commit e78380965e
3 changed files with 36 additions and 18 deletions

View File

@ -130,7 +130,8 @@ class IncorporateGetitem : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
@ -205,7 +206,8 @@ class IncorporateGetitemDepend : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}

View File

@ -45,8 +45,8 @@ class SetCellOutputNoRecompute : public AnfVisitor {
std::unordered_set<CNodePtr> real_outputs;
GetRealOutputNodes(output, &real_outputs);
for (const auto &real_output : real_outputs) {
auto prim = GetValueNode<PrimitivePtr>(real_output->input(0));
prim->set_attr(kAttrRecompute, MakeValue(false));
// Set the attr of cnode in case of shared primitives.
real_output->AddAttr(kAttrRecompute, MakeValue(false));
}
}
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);

View File

@ -33,7 +33,8 @@ namespace {
constexpr auto kGradientsFlag = "Gradients";
bool CanNotRecomputed(const CNodePtr &node) {
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask};
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask,
prim::kPrimLoad, prim::kPrimTupleGetItem};
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
@ -56,16 +57,26 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
return full_name_with_scope.find(kAttrRecompute) == 0;
}
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
return nullptr;
}
auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
return cnode->GetAttr(kAttrRecompute);
}
// Returns true only when the value obtained from GetRecomputeCNodeAttr(node) is a non-null
// BoolImm whose boolean value is false; a missing or non-boolean value yields false.
bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) {
  const auto attr = GetRecomputeCNodeAttr(node);
  if (attr == nullptr || !attr->isa<BoolImm>()) {
    return false;
  }
  return !GetValue<bool>(attr);
}
bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) {
auto cnode_recompute_val = GetRecomputeCNodeAttr(node);
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
}
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); }
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); }
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
const std::vector<CNodePtr> &cnodes) {
@ -213,12 +224,16 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool>
return false;
}
const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
})) {
for (size_t i = 0; i < inputs.size(); ++i) {
// For the pipeline split case, the forward pass may depend on the backward pass.
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) {
continue;
}
if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) {
has_grad_inputs_map->insert(std::make_pair(node, true));
return true;
}
}
has_grad_inputs_map->insert(std::make_pair(node, false));
return false;
}
@ -265,6 +280,10 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
for (const auto &node : origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(node);
// The node may already have been marked as non-recomputed earlier, e.g. the cell outputs.
if (IsSetNoRecomputeCNodeAttr(node)) {
continue;
}
if (IsBpropNode(node)) {
continue;
}
@ -272,9 +291,6 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
if (CanNotRecomputed(node)) {
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
continue;
}
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
continue;
}
@ -293,7 +309,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) {
cnode->AddAttr(kAttrRecompute, MakeValue(true));
}
if (!HasRecomputeCNodeAttr(node)) {
if (!IsSetRecomputeCNodeAttr(node)) {
continue;
}
// Set attr for the tuple_getitem outputs.