forked from mindspore-Ecosystem/mindspore
!19113 [AutoParallel] Optimize pipeline
Merge pull request !19113 from lichen/optimize_pipeline
This commit is contained in:
commit
b876d16847
|
@ -23,6 +23,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -41,10 +42,15 @@ class ReplaceApplicator : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) ||
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_input()) ||
|
||||
*(fg->switch_layer_input())) {
|
||||
return nullptr;
|
||||
}
|
||||
// Defer inlining in the case of pipeline.
|
||||
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||
if (fg->stage() != -1 && stage_num > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
|
||||
if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||
return nullptr;
|
||||
|
@ -101,7 +107,12 @@ class InlinerBase : public AnfVisitor {
|
|||
auto &inputs = cnode->inputs();
|
||||
// G
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
|
||||
if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Defer inlining in the case of pipeline.
|
||||
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||
if (fg->stage() != -1 && stage_num > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
// Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
|
||||
|
|
|
@ -442,6 +442,21 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
|
|||
}
|
||||
}
|
||||
|
||||
ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map) {
|
||||
if (cnode->HasPrimalAttr(MICRO)) {
|
||||
return cnode->GetPrimalAttr(MICRO);
|
||||
}
|
||||
auto node_users = (*node_users_map)[cnode];
|
||||
for (auto &node_pair : node_users) {
|
||||
auto user_node = node_pair.first->cast<CNodePtr>();
|
||||
auto micro = Micro(user_node, node_users_map);
|
||||
if (micro) {
|
||||
return micro;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
|
||||
auto node_users_map = manager->node_users();
|
||||
for (auto &node : all_nodes) {
|
||||
|
@ -449,18 +464,16 @@ void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
|
|||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!cnode->HasPrimalAttr(MICRO)) {
|
||||
continue;
|
||||
}
|
||||
auto micro = cnode->GetPrimalAttr(MICRO);
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
if (prim && prim->HasAttr(PARAMETER_START)) {
|
||||
auto micro = Micro(cnode, &node_users_map);
|
||||
OperatorAttrs attrs_;
|
||||
auto op = CreatOpInstance(attrs_, prim->name(), "");
|
||||
auto value_node = NewValueNode(op);
|
||||
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
|
||||
new_prim->SetAttrs(prim->attrs());
|
||||
manager->SetEdge(cnode, 0, value_node);
|
||||
cnode->AddPrimalAttr(MICRO, micro);
|
||||
cnode->AddPrimalAttr(PARAMETER_START, micro);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ AnfNodePtr GetPreNode(const AnfNodePtr &node);
|
|||
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
|
||||
void SetStridedSliceStrategy(const AnfNodePtr &node);
|
||||
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
|
||||
ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map);
|
||||
void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
|
||||
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, size_t micro_size);
|
||||
} // namespace parallel
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "frontend/parallel/graph_util/node_info.h"
|
||||
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -86,11 +87,47 @@ ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micr
|
|||
return MakeValue(micro);
|
||||
}
|
||||
|
||||
bool PipelineTransformer::NeedGrad(const CNodePtr &cnode) {
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (input->isa<Parameter>() && ParameterRequireGrad(input)) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
auto load = input->cast<CNodePtr>();
|
||||
if (load->input(1)->isa<Parameter>() && ParameterRequireGrad(load->input(1))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph) {
|
||||
auto orders = graph->GetOrderedCnodes();
|
||||
for (auto &node : orders) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (IsValueNode<FuncGraph>(cnode->input(0))) {
|
||||
auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||
return LabelParameterStart(sub_graph);
|
||||
}
|
||||
if (!IsPipelineCareNode(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (NeedGrad(cnode)) {
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
prim->AddAttr(PARAMETER_START, MakeValue(0));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PipelineTransformer::LabelMicroBatch() {
|
||||
if (!root_->has_flag(TRAINING)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(main_graph_);
|
||||
LabelParameterStart(main_graph_);
|
||||
MS_EXCEPTION_IF_NULL(virtual_dataset_);
|
||||
auto node_user_map = manager_->node_users();
|
||||
auto node_users = node_user_map[virtual_dataset_];
|
||||
|
|
|
@ -75,6 +75,8 @@ class PipelineTransformer {
|
|||
std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
|
||||
std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
|
||||
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode, int tuple_index);
|
||||
void LabelParameterStart(const FuncGraphPtr &graph);
|
||||
bool NeedGrad(const CNodePtr &cnode);
|
||||
CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index);
|
||||
bool IsPipelineCareNode(const CNodePtr &cnode);
|
||||
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
|
||||
|
|
Loading…
Reference in New Issue