forked from mindspore-Ecosystem/mindspore
optimizer_pipline_split
parent 1cbef74372, commit 9595502278
@ -69,6 +69,8 @@ namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
constexpr size_t kReturnDataIndex = 1;
constexpr char SR_TAG[] = "sr_tag";
constexpr char BACKWARD[] = "backward";
namespace {
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
  MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
@ -460,6 +462,90 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
  return graph_id;
}

bool IsBackward(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  return prim->HasAttr(BACKWARD);
}

// Compare the sr_tag values of two Send/Receive nodes.
bool comp(const CNodePtr &node1, const CNodePtr &node2) {
  auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0));
  MS_EXCEPTION_IF_NULL(prim1);
  auto prim2 = GetValueNode<PrimitivePtr>(node2->input(0));
  MS_EXCEPTION_IF_NULL(prim2);
  auto sr_tag_value1 = prim1->GetAttr(SR_TAG);
  MS_EXCEPTION_IF_NULL(sr_tag_value1);
  auto sr_tag_value2 = prim2->GetAttr(SR_TAG);
  MS_EXCEPTION_IF_NULL(sr_tag_value2);
  auto sr_tag1 = GetValue<int64_t>(sr_tag_value1);
  auto sr_tag2 = GetValue<int64_t>(sr_tag_value2);
  return sr_tag1 < sr_tag2;
}

// Reorder the execution order of Send nodes.
void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  auto last_node = op_v.back();
  for (auto &node : op_v) {
    if (node == last_node) {
      continue;
    }
    auto node_iter = std::find(execution_order->begin(), execution_order->end(), node);
    (void)execution_order->erase(node_iter);
  }
  std::sort(op_v.begin(), op_v.end(), comp);
  auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node);
  auto node_iter = execution_order->erase(last_node_iter);
  // All Send nodes are inserted at the position of the last Send node.
  execution_order->insert(node_iter, op_v.begin(), op_v.end());
}

// Reorder the execution order of Receive nodes.
void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  auto begin_node = op_v.front();
  for (auto &node : op_v) {
    if (node == begin_node) {
      continue;
    }
    auto node_iter = std::find(execution_order->begin(), execution_order->end(), node);
    (void)execution_order->erase(node_iter);
  }
  std::sort(op_v.begin(), op_v.end(), comp);
  auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node);
  auto node_iter = execution_order->erase(begin_node_iter);
  // All Receive nodes are inserted at the position of the first Receive node.
  execution_order->insert(node_iter, op_v.begin(), op_v.end());
}

void ReorderSendRecv(std::vector<CNodePtr> *execution_order) {
  std::vector<CNodePtr> forward_send, forward_recv, backward_send, backward_recv;
  for (auto &cnode : *execution_order) {
    if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) {
      backward_send.push_back(cnode);
      continue;
    } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
      forward_send.push_back(cnode);
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) {
      backward_recv.push_back(cnode);
    } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
      forward_recv.push_back(cnode);
    }
  }
  if (!forward_send.empty()) {
    ReorderSend(execution_order, forward_send);
  }
  if (!backward_send.empty()) {
    ReorderSend(execution_order, backward_send);
  }
  if (!forward_recv.empty()) {
    ReorderRecv(execution_order, forward_recv);
  }
  if (!backward_recv.empty()) {
    ReorderRecv(execution_order, backward_recv);
  }
}
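ReorderSend and ReorderRecv pull every Send (or Receive) node out of the execution order, sort the group by its sr_tag attribute via comp, and splice the sorted group back in as one contiguous block: Sends land at the position of the last Send, Receives at the position of the first Receive. A minimal Python sketch of the same list manipulation, using (name, sr_tag) tuples in place of CNodes; the node names and tags are invented for illustration:

def reorder_send(order, sends):
    # Remove every Send except the last one from the execution order.
    anchor = sends[-1]
    for node in sends:
        if node is not anchor:
            order.remove(node)
    # Sort the Sends by sr_tag and splice them in at the anchor's position.
    sends = sorted(sends, key=lambda n: n[1])  # n = (name, sr_tag)
    i = order.index(anchor)
    order[i:i + 1] = sends
    return order

order = [("Send", 2), ("MatMul", None), ("Send", 0), ("ReLU", None), ("Send", 1)]
sends = [n for n in order if n[0] == "Send"]
print(reorder_send(order, sends))
# [('MatMul', None), ('ReLU', None), ('Send', 0), ('Send', 1), ('Send', 2)]

Anchoring the block at the last Send (or the first Receive) appears to keep the communication ops after (or before) the compute nodes they were originally ordered against, while making their relative order deterministic by tag.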
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  MS_LOG(INFO) << "Start";
  std::vector<KernelGraphPtr> all_graphs;
@ -520,6 +606,11 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {

  // adjust kernel
  AdjustKernel(root_graph);

  // reorder send/recv
  auto execution_order = root_graph->execution_order();
  ReorderSendRecv(&execution_order);
  root_graph->set_execution_order(execution_order);
#if ENABLE_CPU && ENABLE_D
  InitPsWorker(root_graph);
#endif
@ -28,6 +28,7 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "ir/anf.h"
#include "base/core_ops.h"
#include "utils/comm_manager.h"
@ -51,12 +52,37 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
  return false;
}

static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) {
  auto node_users = node_users_map[node];
  for (auto &user_pair : node_users) {
    auto user_node = user_pair.first;
    if (!user_node->grad()) {
      user_node->set_grad(true);
      SetGradTag(user_node, node_users_map);
    }
  }
}

void PipelineTransformer::LabelRequiredGradCNode() {
  auto parameters = root_->parameters();
  auto node_users_map = manager_->node_users();
  for (auto parameter : parameters) {
    if (!ParameterRequireGrad(parameter)) {
      continue;
    }
    SetGradTag(parameter, node_users_map);
  }
}
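LabelRequiredGradCNode walks forward from every parameter that requires a gradient and recursively marks each user node with set_grad(true), so later passes can tell which nodes (and which Send/Receive pairs) sit on the backward path. A rough Python sketch of that propagation over a plain users map; the node names are invented:

def set_grad_tag(node, node_users, need_grad):
    # Walk forward through the users map, marking every reachable node once.
    for user in node_users.get(node, []):
        if user not in need_grad:
            need_grad.add(user)
            set_grad_tag(user, node_users, need_grad)

# Toy graph: param -> matmul -> relu, const -> relu
node_users = {"param": ["matmul"], "matmul": ["relu"], "const": ["relu"]}
need_grad = set()
set_grad_tag("param", node_users, need_grad)
print(need_grad)  # {'matmul', 'relu'}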

void PipelineTransformer::Coloring() {
  auto need_coloring = true;
  std::set<int64_t> stage_set;
  while (need_coloring) {
    need_coloring = false;
    for (auto &fg : manager_->func_graphs()) {
      if (fg == root_) {
        continue;
      }
      auto value_nodes = fg->value_nodes();
      for (auto &value_pair : value_nodes) {
        auto node = value_pair.first;
@ -64,10 +90,12 @@ void PipelineTransformer::Coloring() {
          continue;
        }
        auto graph = GetValueNode<FuncGraphPtr>(node);
        auto need_grad = graph->get_return()->grad();
        auto node_users = manager_->node_users()[node];
        for (auto &user_pair : node_users) {
          auto user_node = user_pair.first->cast<CNodePtr>();
          user_node->set_stage(graph->stage());
          user_node->set_grad(need_grad);
          auto user_node_graph = user_node->func_graph();
          if (graph->stage() != -1) {
            stage_set.insert(graph->stage());
@ -90,7 +118,11 @@

void PipelineTransformer::BroadCastColoring() {
  for (auto &fg : manager_->func_graphs()) {
    if (fg == root_ || fg->stage() == -1) {
      continue;
    }
    DoBroadCast(fg);
    SetNoStageNode(fg);
  }
}
@ -190,32 +222,17 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
  while (need_coloring) {
    need_coloring = false;
    auto all_nodes = func->nodes();
    auto node_users = manager_->node_users();
    for (auto &node : all_nodes) {
      // only cnode can broadcast color.
      if (!node->isa<CNode>()) {
      if (!node->isa<CNode>() || node->stage() == -1) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      if (cnode->stage() == -1) {
        // broadcast from inputs to outputs
        for (auto &input : cnode->inputs()) {
          if (input->isa<CNode>() && input->stage() == stage_) {
            cnode->set_stage(input->stage());
            need_coloring = true;
          }
        }
      } else if (cnode->stage() == stage_) {
        // broadcast from outputs to inputs
        for (auto &input : cnode->inputs()) {
          if (input->stage() != -1 || !input->isa<CNode>()) {
            continue;
          }
          auto input_cnode = input->cast<CNodePtr>();
          auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
          if (prim != nullptr && prim->name() == VIRTUAL_DATA_SET) {
            continue;
          }
          input->set_stage(cnode->stage());
      auto stage = node->stage();
      for (auto &user_pair : node_users[node]) {
        auto user_node = user_pair.first->cast<CNodePtr>();
        auto user_node_stage = user_node->stage();
        if (IsValueNode<FuncGraph>(user_node->input(0)) && stage > user_node_stage) {
          user_node->set_stage(stage);
          need_coloring = true;
        }
      }
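The reworked DoBroadCast loop propagates a node's pipeline stage forward to its users: whenever a staged node feeds a func-graph call whose recorded stage is lower, the user's stage is raised and the pass repeats until nothing changes. A small Python sketch of that fixed-point propagation over a toy users map (names and stages are made up, and the real code additionally restricts the update to func-graph call sites):

def broadcast_stage(stages, node_users):
    # Repeatedly raise a user's stage to match its producer until stable.
    changed = True
    while changed:
        changed = False
        for node, stage in list(stages.items()):
            if stage == -1:
                continue
            for user in node_users.get(node, []):
                if stages.get(user, -1) < stage:
                    stages[user] = stage
                    changed = True
    return stages

stages = {"cell0_out": 0, "cell1_out": 1, "call_cell1": -1, "loss": -1}
node_users = {"cell0_out": ["call_cell1"], "cell1_out": ["loss"], "call_cell1": ["loss"]}
print(broadcast_stage(stages, node_users))
# {'cell0_out': 0, 'cell1_out': 1, 'call_cell1': 0, 'loss': 1}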
@ -223,6 +240,16 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
  }
}

void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) {
  auto all_nodes = func->nodes();
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>() || node->stage() != -1) {
      continue;
    }
    node->set_stage(0);
  }
}

void PipelineTransformer::HandleSharedParameter() {
  auto parameters = root_->parameters();
  for (auto &parameter : parameters) {
@ -412,7 +439,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
  if (node->isa<Parameter>()) {
    recv_input = {NewValueNode(recv_op), node};
  } else {
    recv_input = {NewValueNode(recv_op), virtual_param_};
    if (node->grad()) {
      recv_input = {NewValueNode(recv_op), virtual_param_};
    } else {
      auto param = root_->parameters()[0];
      recv_input = {NewValueNode(recv_op), param};
    }
  }
  auto recv = graph->NewCNode(recv_input);
  auto node_abstract = node->abstract();
@ -505,7 +537,11 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
    manager_->Replace(graph->output(), out_input[1]);
  }
  if (out_input.size() > 2) {
    auto out_node = graph->NewCNode(out_input);
    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end());
    auto make_tuple = graph->NewCNode(make_tuple_inputs);
    std::vector<AnfNodePtr> out_depend_inputs = {out_input[0], out_input[1], make_tuple};
    auto out_node = graph->NewCNode(out_depend_inputs);
    manager_->Replace(graph->output(), out_node);
  }
}
@ -47,6 +47,7 @@ class PipelineTransformer {
        global_rank_(global_rank),
        per_stage_rank_num_(per_stage_rank_num) {}
  virtual ~PipelineTransformer() = default;
  void LabelRequiredGradCNode();
  void Coloring();
  void BroadCastColoring();
  void HandleSharedParameter();
@ -63,6 +64,7 @@
                     int64_t node_stage);
  void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
                     int64_t user_node_stage, int64_t node_stage);
  void SetNoStageNode(const FuncGraphPtr &func);
  void CutBorder(const FuncGraphPtr &graph);
  bool IsStageNode(const CNodePtr &node);
  AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
@ -1291,7 +1291,7 @@ std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int3
    MS_EXCEPTION_IF_NULL(prim_node_anf);
    PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
    if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
      continue;
    }
    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
@ -90,6 +90,7 @@ bool PipelineSplit(const ResourcePtr &res) {
  auto transformer =
      std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
  // step1: Do color graph
  transformer->LabelRequiredGradCNode();
  transformer->Coloring();
  // step2: Do color broadcast
  transformer->BroadCastColoring();
@ -100,7 +100,8 @@ class AnfNode : public Base {
        fullname_with_scope_(""),
        hash_(std::hash<const AnfNode *>()),
        kernel_info_(nullptr),
        stage_(-1) {
        stage_(-1),
        need_grad_(false) {
    scope_ = ScopeManager::GetInstance().GetCurrentScope();
  }
@ -190,6 +191,9 @@
  int64_t stage() { return stage_; }
  void set_stage(const int &stage) { stage_ = stage; }

  bool grad() { return need_grad_; }
  void set_grad(const bool &need_grad) { need_grad_ = need_grad; }

 protected:
  // Hold a weak ref to Graph as Graph also hold ref to AnfNode.
  // Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -205,6 +209,7 @@
  KernelInfoDevicePtr kernel_info_;
  UserData user_data_;
  int64_t stage_;
  bool need_grad_;
};

// CNode represents the complex node with a set of arguments.
@ -638,7 +638,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
  (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
    MS_EXCEPTION_IF_NULL(param);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(param->debug_info()));
    (void)new_func_graph->add_parameter();
    (void)new_func_graph->add_parameter()->set_abstract(param->abstract());
  });

  Cloner cloner = Cloner();
@ -85,6 +85,7 @@ def get_bprop_send(self):
    shape = self.get_attr_dict()["shape"]
    dtype = self.get_attr_dict()["dtype"]
    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
    send_grad.add_prim_attr("backward", True)

    def bprop(x, out, dout):
        dx = send_grad()
@ -96,6 +97,7 @@ def get_bprop_send(self):
def get_bprop_receive(self):
    """Generate bprop for Receive."""
    receive_grad = Send(self.tag, self.rank, self.group)
    receive_grad.add_prim_attr("backward", True)
    depend = P.Depend()
    cast = P.Cast()
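Both bprop builders tag the gradient-path communication op with a "backward" attribute, which is exactly what the session-side IsBackward() helper reads when it splits Send/Receive into forward and backward groups before reordering. A tiny sketch of that handshake, with a plain dict standing in for a Primitive's attribute table (names are illustrative only):

def is_backward(prim_attrs):
    # Mirrors the C++ IsBackward(): a node counts as backward if its primitive carries the attr.
    return "backward" in prim_attrs

send_grad_attrs = {"sr_tag": 3, "backward": True}   # as set by get_bprop_send above
forward_send_attrs = {"sr_tag": 3}
print(is_backward(send_grad_attrs), is_backward(forward_send_attrs))  # True False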
@ -21,7 +21,7 @@ from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import get_rank, GlobalComm, _get_group
from ...communication.management import GlobalComm


class ExtractImagePatches(PrimitiveWithInfer):
@ -409,7 +409,7 @@ class Send(PrimitiveWithInfer):
    """
    @prim_attr_register
    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
        self.rank = get_rank(_get_group(group))
        self.rank = dest_rank
        self.sr_tag = sr_tag
        self.group = group
@ -465,7 +465,7 @@ class Receive(PrimitiveWithInfer):
    """
    @prim_attr_register
    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
        self.rank = get_rank(_get_group(group))
        self.rank = src_rank
        self.tag = sr_tag
        self.shape = shape
        self.dtype = dtype
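With this change the constructors take the peer rank directly (dest_rank / src_rank) instead of deriving it from the group. A minimal usage sketch based only on the signatures shown above; the tag, ranks, shape and dtype are made-up values, and an actual run would additionally need an initialized communication group:

from mindspore.ops.operations.comm_ops import Send, Receive
from mindspore.common import dtype as mstype

# Stage 0 sends an activation tagged sr_tag=0 to rank 1 ...
send = Send(sr_tag=0, dest_rank=1)
# ... and stage 1 receives it; Receive must know the shape and dtype up front.
recv = Receive(sr_tag=0, src_rank=0, shape=[32, 64], dtype=mstype.float32)

The sr_tag passed here is the attribute that the session-side comp() comparator sorts by when it reorders Send/Receive in the execution order.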