!18522 No need to recompute load node
Merge pull request !18522 from YuJianfeng/recompute1
This commit is contained in:
commit
e78380965e
|
@ -130,7 +130,8 @@ class IncorporateGetitem : public AnfVisitor {
|
|||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
|
||||
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
|
||||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -205,7 +206,8 @@ class IncorporateGetitemDepend : public AnfVisitor {
|
|||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
|
||||
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
|
||||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
|||
std::unordered_set<CNodePtr> real_outputs;
|
||||
GetRealOutputNodes(output, &real_outputs);
|
||||
for (const auto &real_output : real_outputs) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(real_output->input(0));
|
||||
prim->set_attr(kAttrRecompute, MakeValue(false));
|
||||
// Set the attr of cnode in case of shared primitives.
|
||||
real_output->AddAttr(kAttrRecompute, MakeValue(false));
|
||||
}
|
||||
}
|
||||
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
|
||||
|
|
|
@ -33,7 +33,8 @@ namespace {
|
|||
constexpr auto kGradientsFlag = "Gradients";
|
||||
|
||||
bool CanNotRecomputed(const CNodePtr &node) {
|
||||
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask};
|
||||
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask,
|
||||
prim::kPrimLoad, prim::kPrimTupleGetItem};
|
||||
|
||||
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
|
@ -56,16 +57,26 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
|
|||
return full_name_with_scope.find(kAttrRecompute) == 0;
|
||||
}
|
||||
|
||||
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
ValuePtr GetRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode_recompute_val = cnode->GetAttr(kAttrRecompute);
|
||||
return cnode->GetAttr(kAttrRecompute);
|
||||
}
|
||||
|
||||
// Returns true only when the value obtained from GetRecomputeCNodeAttr(node) is non-null, is a BoolImm,
// and holds false. A missing attribute or a non-bool value yields false.
bool IsSetNoRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
auto cnode_recompute_val = GetRecomputeCNodeAttr(node);
|
||||
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && !GetValue<bool>(cnode_recompute_val);
|
||||
}
|
||||
|
||||
// Returns true only when the value obtained from GetRecomputeCNodeAttr(node) is non-null, is a BoolImm,
// and holds true. A missing attribute or a non-bool value yields false.
bool IsSetRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
auto cnode_recompute_val = GetRecomputeCNodeAttr(node);
|
||||
return cnode_recompute_val != nullptr && cnode_recompute_val->isa<BoolImm>() && GetValue<bool>(cnode_recompute_val);
|
||||
}
|
||||
|
||||
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && HasRecomputeCNodeAttr(node); }
|
||||
// A node is a recompute candidate when IsBpropNode(node) is false and IsSetRecomputeCNodeAttr(node) is true.
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsBpropNode(node) && IsSetRecomputeCNodeAttr(node); }
|
||||
|
||||
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
|
||||
const std::vector<CNodePtr> &cnodes) {
|
||||
|
@ -213,11 +224,15 @@ bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool>
|
|||
return false;
|
||||
}
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
|
||||
return IsBpropNode(input) || HasGradInputs(input, has_grad_inputs_map);
|
||||
})) {
|
||||
has_grad_inputs_map->insert(std::make_pair(node, true));
|
||||
return true;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// For the pipeline split case, the forward pass may depend on the backward pass.
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && i == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
if (IsBpropNode(inputs[i]) || HasGradInputs(inputs[i], has_grad_inputs_map)) {
|
||||
has_grad_inputs_map->insert(std::make_pair(node, true));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
has_grad_inputs_map->insert(std::make_pair(node, false));
|
||||
return false;
|
||||
|
@ -265,6 +280,10 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
|
|||
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
|
||||
for (const auto &node : origin_nodes_topological) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// The node may already have been marked as non-recomputed, e.g. when it is a cell output.
|
||||
if (IsSetNoRecomputeCNodeAttr(node)) {
|
||||
continue;
|
||||
}
|
||||
if (IsBpropNode(node)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -272,9 +291,6 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
|
|||
if (CanNotRecomputed(node)) {
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -293,7 +309,7 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
|
|||
if ((SetRecomputedScope(cnode) && prim_recompute_val != 0) || prim_recompute_val == 1) {
|
||||
cnode->AddAttr(kAttrRecompute, MakeValue(true));
|
||||
}
|
||||
if (!HasRecomputeCNodeAttr(node)) {
|
||||
if (!IsSetRecomputeCNodeAttr(node)) {
|
||||
continue;
|
||||
}
|
||||
// Set attr for the tuple_getitem outputs.
|
||||
|
|
Loading…
Reference in New Issue