!19113 [AutoParallel] Optimize pipeline

Merge pull request !19113 from lichen/optimize_pipeline
This commit is contained in:
i-robot 2021-07-02 01:31:55 +00:00 committed by Gitee
commit b876d16847
5 changed files with 70 additions and 6 deletions

View File

@ -23,6 +23,7 @@
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "frontend/parallel/context.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
@ -41,10 +42,15 @@ class ReplaceApplicator : public AnfVisitor {
}
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) ||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_input()) ||
*(fg->switch_layer_input())) {
return nullptr;
}
// Defer inlining in the case of pipeline.
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
if (fg->stage() != -1 && stage_num > 1) {
return nullptr;
}
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
@ -101,7 +107,12 @@ class InlinerBase : public AnfVisitor {
auto &inputs = cnode->inputs();
// G
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
// Defer inlining in the case of pipeline.
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
if (fg->stage() != -1 && stage_num > 1) {
return nullptr;
}
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.

View File

@ -442,6 +442,21 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
}
}
// Looks up the MICRO primal attribute for `cnode`.
//
// If `cnode` itself carries the MICRO primal attribute, that value is returned.
// Otherwise the users of `cnode` recorded in `node_users_map` are searched
// depth-first, and the first MICRO value found on any transitive user is returned.
//
// Params:
//   cnode          - node whose micro label is requested; a null node yields nullptr.
//   node_users_map - map from a node to its users; a null map yields nullptr.
// Returns:
//   The MICRO primal attribute value, or nullptr when neither the node nor any
//   of its transitive users carries one.
ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map) {
  // Guard the two pointers that are dereferenced unconditionally below.
  if (cnode == nullptr || node_users_map == nullptr) {
    return nullptr;
  }
  if (cnode->HasPrimalAttr(MICRO)) {
    return cnode->GetPrimalAttr(MICRO);
  }
  // Work on a copy of the user set: the recursive calls below index the same map.
  auto node_users = (*node_users_map)[cnode];
  for (auto &node_pair : node_users) {
    if (node_pair.first == nullptr) {
      continue;
    }
    auto user_node = node_pair.first->cast<CNodePtr>();
    // cast<CNodePtr>() yields null for a user that is not a CNode; such a user
    // cannot carry a primal attribute, so skip it instead of dereferencing null.
    if (user_node == nullptr) {
      continue;
    }
    auto micro = Micro(user_node, node_users_map);
    if (micro) {
      return micro;
    }
  }
  return nullptr;
}
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) {
@ -449,18 +464,16 @@ void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!cnode->HasPrimalAttr(MICRO)) {
continue;
}
auto micro = cnode->GetPrimalAttr(MICRO);
auto prim = GetCNodePrimitive(node);
if (prim && prim->HasAttr(PARAMETER_START)) {
auto micro = Micro(cnode, &node_users_map);
OperatorAttrs attrs_;
auto op = CreatOpInstance(attrs_, prim->name(), "");
auto value_node = NewValueNode(op);
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
new_prim->SetAttrs(prim->attrs());
manager->SetEdge(cnode, 0, value_node);
cnode->AddPrimalAttr(MICRO, micro);
cnode->AddPrimalAttr(PARAMETER_START, micro);
}
}

View File

@ -60,6 +60,7 @@ AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void SetStridedSliceStrategy(const AnfNodePtr &node);
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
// Returns the MICRO primal attribute of cnode, or of the first transitive user of cnode
// (looked up through node_users_map) that carries one; returns nullptr when none is found.
ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map);
void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, size_t micro_size);
} // namespace parallel

View File

@ -31,6 +31,7 @@
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "base/core_ops.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
@ -86,11 +87,47 @@ ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micr
return MakeValue(micro);
}
// Reports whether `cnode` consumes a parameter for which ParameterRequireGrad is true.
//
// An input qualifies when it is such a Parameter itself, or when it is a Load
// primitive node whose input(1) is such a Parameter.
// Returns true on the first qualifying input, false when no input qualifies.
bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
  for (auto &candidate : cnode->inputs()) {
    // Case 1: the parameter is wired directly into the node.
    const bool direct_hit = candidate->isa<Parameter>() && ParameterRequireGrad(candidate);
    if (direct_hit) {
      return true;
    }
    // Case 2: the parameter is read through a Load node; anything else is irrelevant.
    if (!IsPrimitiveCNode(candidate, prim::kPrimLoad)) {
      continue;
    }
    auto load_cnode = candidate->cast<CNodePtr>();
    auto loaded = load_cnode->input(1);
    if (loaded->isa<Parameter>() && ParameterRequireGrad(loaded)) {
      return true;
    }
  }
  return false;
}
// Attaches the PARAMETER_START attribute to the primitive of one node in `graph`.
//
// The ordered CNodes of `graph` are scanned front to back:
//   - on the first node whose input(0) is a FuncGraph value node, the search
//     continues inside that sub-graph and the scan of `graph` ends;
//   - on the first node accepted by IsPipelineCareNode and NeedGrad, its
//     primitive receives PARAMETER_START (value 0) and the scan ends.
// Nothing is labelled when no node matches.
void PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph) {
  auto ordered_cnodes = graph->GetOrderedCnodes();
  for (auto &ordered_node : ordered_cnodes) {
    auto cnode = ordered_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsValueNode<FuncGraph>(cnode->input(0))) {
      // Call into a sub-graph: delegate the search to it and stop scanning here.
      LabelParameterStart(GetValueNode<FuncGraphPtr>(cnode->input(0)));
      return;
    }
    if (IsPipelineCareNode(cnode) && NeedGrad(cnode)) {
      auto care_prim = GetCNodePrimitive(cnode);
      care_prim->AddAttr(PARAMETER_START, MakeValue(0));
      return;
    }
  }
}
void PipelineTransformer::LabelMicroBatch() {
if (!root_->has_flag(TRAINING)) {
return;
}
MS_EXCEPTION_IF_NULL(main_graph_);
LabelParameterStart(main_graph_);
MS_EXCEPTION_IF_NULL(virtual_dataset_);
auto node_user_map = manager_->node_users();
auto node_users = node_user_map[virtual_dataset_];

View File

@ -75,6 +75,8 @@ class PipelineTransformer {
std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode, int tuple_index);
// Adds the PARAMETER_START attribute to the primitive of the first pipeline-care node in
// graph (descending into a called sub-graph if one is met first) for which NeedGrad is true.
void LabelParameterStart(const FuncGraphPtr &graph);
// Returns true when an input of cnode is a Parameter accepted by ParameterRequireGrad,
// either directly or as input(1) of a Load node.
bool NeedGrad(const CNodePtr &cnode);
CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index);
bool IsPipelineCareNode(const CNodePtr &cnode);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();