optimizer_pipeline_split

This commit is contained in:
lichenever 2020-12-18 16:56:07 +08:00
parent 1cbef74372
commit 9595502278
9 changed files with 168 additions and 31 deletions

View File

@ -69,6 +69,8 @@ namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
constexpr size_t kReturnDataIndex = 1;
constexpr char SR_TAG[] = "sr_tag";
constexpr char BACKWARD[] = "backward";
namespace {
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
@ -460,6 +462,90 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
return graph_id;
}
// Returns true when the primitive at input(0) of cnode carries the BACKWARD attribute.
// A cnode whose input(0) does not yield a primitive is reported as "not backward"
// instead of dereferencing a null primitive pointer.
bool IsBackward(const CNodePtr &cnode) {
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  return prim != nullptr && prim->HasAttr(BACKWARD);
}
// Ordering predicate for send/recv nodes: node1 sorts before node2 when the SR_TAG
// attribute of node1's primitive is strictly smaller than that of node2's primitive.
// Raises (via MS_EXCEPTION_IF_NULL) when a node, its primitive or its SR_TAG attribute is missing.
bool comp(const CNodePtr &node1, const CNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0));
  MS_EXCEPTION_IF_NULL(prim1);
  // The second primitive must come from node2; reading it from node1 would compare a node
  // with itself and make the predicate always return false.
  auto prim2 = GetValueNode<PrimitivePtr>(node2->input(0));
  MS_EXCEPTION_IF_NULL(prim2);
  auto sr_tag_value1 = prim1->GetAttr(SR_TAG);
  MS_EXCEPTION_IF_NULL(sr_tag_value1);
  auto sr_tag_value2 = prim2->GetAttr(SR_TAG);
  MS_EXCEPTION_IF_NULL(sr_tag_value2);
  auto sr_tag1 = GetValue<int64_t>(sr_tag_value1);
  auto sr_tag2 = GetValue<int64_t>(sr_tag_value2);
  return sr_tag1 < sr_tag2;
}
// Gathers all send nodes of op_v into one contiguous run inside execution_order.
// The run takes the slot of the send that is last in op_v (as passed in) and is ordered by comp.
// op_v is taken by value because it is sorted locally; it must be non-empty and every node in it
// must be present in execution_order.
void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  const auto anchor = op_v.back();
  // Pull every send except the anchor out of the schedule; the anchor keeps marking the target slot.
  for (const auto &send : op_v) {
    if (send != anchor) {
      (void)execution_order->erase(std::find(execution_order->begin(), execution_order->end(), send));
    }
  }
  std::sort(op_v.begin(), op_v.end(), comp);
  // Replace the anchor by the whole sorted run, so all sends sit where the last send used to be.
  const auto anchor_pos = std::find(execution_order->begin(), execution_order->end(), anchor);
  const auto insert_pos = execution_order->erase(anchor_pos);
  (void)execution_order->insert(insert_pos, op_v.begin(), op_v.end());
}
// Gathers all receive nodes of op_v into one contiguous run inside execution_order.
// The run takes the slot of the receive that is first in op_v (as passed in) and is ordered by comp.
// op_v is taken by value because it is sorted locally; it must be non-empty and every node in it
// must be present in execution_order.
void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
  const auto anchor = op_v.front();
  // Pull every receive except the anchor out of the schedule; the anchor keeps marking the target slot.
  for (const auto &recv : op_v) {
    if (recv != anchor) {
      (void)execution_order->erase(std::find(execution_order->begin(), execution_order->end(), recv));
    }
  }
  std::sort(op_v.begin(), op_v.end(), comp);
  // Replace the anchor by the whole sorted run, so all receives sit where the first receive used to be.
  const auto anchor_pos = std::find(execution_order->begin(), execution_order->end(), anchor);
  const auto insert_pos = execution_order->erase(anchor_pos);
  (void)execution_order->insert(insert_pos, op_v.begin(), op_v.end());
}
void ReorderSendRecv(std::vector<CNodePtr> *execution_order) {
std::vector<CNodePtr> forward_send, forward_recv, backward_send, backward_recv;
for (auto &cnode : *execution_order) {
if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) {
backward_send.push_back(cnode);
continue;
} else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
forward_send.push_back(cnode);
continue;
}
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) {
backward_recv.push_back(cnode);
} else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
forward_recv.push_back(cnode);
}
}
if (!forward_send.empty()) {
ReorderSend(execution_order, forward_send);
}
if (!backward_send.empty()) {
ReorderSend(execution_order, backward_send);
}
if (!forward_recv.empty()) {
ReorderRecv(execution_order, forward_recv);
}
if (!backward_recv.empty()) {
ReorderRecv(execution_order, backward_recv);
}
}
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "Start";
std::vector<KernelGraphPtr> all_graphs;
@ -520,6 +606,11 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
// adjust kernel
AdjustKernel(root_graph);
// reorder send/recv
auto execution_order = root_graph->execution_order();
ReorderSendRecv(&execution_order);
root_graph->set_execution_order(execution_order);
#if ENABLE_CPU && ENABLE_D
InitPsWorker(root_graph);
#endif

View File

@ -28,6 +28,7 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "ir/anf.h"
#include "base/core_ops.h"
#include "utils/comm_manager.h"
@ -51,12 +52,37 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
return false;
}
// Recursive worker for SetGradTag: marks every not-yet-marked user of `node` (and, transitively,
// the users of those users) with set_grad(true). The map is shared through a pointer so that the
// recursion does not copy it at every level.
static void SetGradTagRecursive(const AnfNodePtr &node, NodeUsersMap *node_users_map) {
  // Copy only this node's user set; operator[] may add entries to the map during recursion.
  auto node_users = (*node_users_map)[node];
  for (auto &user_pair : node_users) {
    auto user_node = user_pair.first;
    // An already marked user has been (or is being) visited; skipping it also ends cycles.
    if (!user_node->grad()) {
      user_node->set_grad(true);
      SetGradTagRecursive(user_node, node_users_map);
    }
  }
}

// Marks all direct and transitive users of `node` as requiring gradient.
// node_users_map is still received by value (signature unchanged), but that single copy is now
// reused for the whole traversal instead of being duplicated once per visited node.
static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) {
  SetGradTagRecursive(node, &node_users_map);
}
// For every parameter of the root graph for which ParameterRequireGrad() holds, tags all of its
// (transitive) users via SetGradTag.
void PipelineTransformer::LabelRequiredGradCNode() {
  // Bind by const reference: nothing here modifies the parameter list or the users map,
  // so copying them up front is unnecessary work.
  const auto &parameters = root_->parameters();
  const auto &node_users_map = manager_->node_users();
  for (const auto &parameter : parameters) {
    if (!ParameterRequireGrad(parameter)) {
      continue;
    }
    SetGradTag(parameter, node_users_map);
  }
}
void PipelineTransformer::Coloring() {
auto need_coloring = true;
std::set<int64_t> stage_set;
while (need_coloring) {
need_coloring = false;
for (auto &fg : manager_->func_graphs()) {
if (fg == root_) {
continue;
}
auto value_nodes = fg->value_nodes();
for (auto &value_pair : value_nodes) {
auto node = value_pair.first;
@ -64,10 +90,12 @@ void PipelineTransformer::Coloring() {
continue;
}
auto graph = GetValueNode<FuncGraphPtr>(node);
auto need_grad = graph->get_return()->grad();
auto node_users = manager_->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first->cast<CNodePtr>();
user_node->set_stage(graph->stage());
user_node->set_grad(need_grad);
auto user_node_graph = user_node->func_graph();
if (graph->stage() != -1) {
stage_set.insert(graph->stage());
@ -90,7 +118,11 @@ void PipelineTransformer::Coloring() {
// Runs DoBroadCast followed by SetNoStageNode on every managed func graph, except the root graph
// and graphs whose stage is still -1.
void PipelineTransformer::BroadCastColoring() {
  for (auto &fg : manager_->func_graphs()) {
    const bool has_own_stage = (fg != root_) && (fg->stage() != -1);
    if (has_own_stage) {
      DoBroadCast(fg);
      SetNoStageNode(fg);
    }
  }
}
@ -190,32 +222,17 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
while (need_coloring) {
need_coloring = false;
auto all_nodes = func->nodes();
auto node_users = manager_->node_users();
for (auto &node : all_nodes) {
// only cnode can broadcast color.
if (!node->isa<CNode>()) {
if (node->isa<CNode>() || node->stage() == -1) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->stage() == -1) {
// broadcast from inputs to outputs
for (auto &input : cnode->inputs()) {
if (input->isa<CNode>() && input->stage() == stage_) {
cnode->set_stage(input->stage());
need_coloring = true;
}
}
} else if (cnode->stage() == stage_) {
// broadcast from outputs to inputs
for (auto &input : cnode->inputs()) {
if (input->stage() != -1 || !input->isa<CNode>()) {
continue;
}
auto input_cnode = input->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
if (prim != nullptr && prim->name() == VIRTUAL_DATA_SET) {
continue;
}
input->set_stage(cnode->stage());
auto stage = node->stage();
for (auto &user_pair : node_users[node]) {
auto user_node = user_pair.first->cast<CNodePtr>();
auto user_node_stage = user_node->stage();
if (IsValueNode<FuncGraph>(user_node->input(0)) && stage > user_node_stage) {
user_node->set_stage(stage);
need_coloring = true;
}
}
@ -223,6 +240,16 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
}
}
// Assigns stage 0 to every CNode of `func` whose stage is still -1; all other nodes are left as is.
void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) {
  auto graph_nodes = func->nodes();
  for (auto &candidate : graph_nodes) {
    const bool is_unstaged_cnode = candidate->isa<CNode>() && candidate->stage() == -1;
    if (is_unstaged_cnode) {
      candidate->set_stage(0);
    }
  }
}
void PipelineTransformer::HandleSharedParameter() {
auto parameters = root_->parameters();
for (auto &parameter : parameters) {
@ -412,7 +439,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
if (node->isa<Parameter>()) {
recv_input = {NewValueNode(recv_op), node};
} else {
recv_input = {NewValueNode(recv_op), virtual_param_};
if (node->grad()) {
recv_input = {NewValueNode(recv_op), virtual_param_};
} else {
auto param = root_->parameters()[0];
recv_input = {NewValueNode(recv_op), param};
}
}
auto recv = graph->NewCNode(recv_input);
auto node_abstract = node->abstract();
@ -505,7 +537,11 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
manager_->Replace(graph->output(), out_input[1]);
}
if (out_input.size() > 2) {
auto out_node = graph->NewCNode(out_input);
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end());
auto make_tuple = graph->NewCNode(make_tuple_inputs);
std::vector<AnfNodePtr> out_depend_inputs = {out_input[0], out_input[1], make_tuple};
auto out_node = graph->NewCNode(out_depend_inputs);
manager_->Replace(graph->output(), out_node);
}
}

View File

@ -47,6 +47,7 @@ class PipelineTransformer {
global_rank_(global_rank),
per_stage_rank_num_(per_stage_rank_num) {}
virtual ~PipelineTransformer() = default;
void LabelRequiredGradCNode();
void Coloring();
void BroadCastColoring();
void HandleSharedParameter();
@ -63,6 +64,7 @@ class PipelineTransformer {
int64_t node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage);
void SetNoStageNode(const FuncGraphPtr &func);
void CutBorder(const FuncGraphPtr &graph);
bool IsStageNode(const CNodePtr &node);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);

View File

@ -1291,7 +1291,7 @@ std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int3
MS_EXCEPTION_IF_NULL(prim_node_anf);
PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(node_prim);
if (node_prim->name() == DEPEND && node_pair.second != 1) {
if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {

View File

@ -90,6 +90,7 @@ bool PipelineSplit(const ResourcePtr &res) {
auto transformer =
std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
// step1: Do color graph
transformer->LabelRequiredGradCNode();
transformer->Coloring();
// step2: Do color broadcast
transformer->BroadCastColoring();

View File

@ -100,7 +100,8 @@ class AnfNode : public Base {
fullname_with_scope_(""),
hash_(std::hash<const AnfNode *>()),
kernel_info_(nullptr),
stage_(-1) {
stage_(-1),
need_grad_(false) {
scope_ = ScopeManager::GetInstance().GetCurrentScope();
}
@ -190,6 +191,9 @@ class AnfNode : public Base {
int64_t stage() { return stage_; }
// Stores the stage value on this node. The parameter is int64_t to match stage_ and stage(),
// so a 64-bit stage passed by a caller is not narrowed to int on the way in.
void set_stage(int64_t stage) { stage_ = stage; }
bool grad() { return need_grad_; }
// Stores the need-grad flag on this node. A bool is passed by value; a const reference adds
// an indirection without any benefit.
void set_grad(bool need_grad) { need_grad_ = need_grad; }
protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -205,6 +209,7 @@ class AnfNode : public Base {
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
int64_t stage_;
bool need_grad_;
};
// CNode represents the complex node with a set of arguments.

View File

@ -638,7 +638,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
MS_EXCEPTION_IF_NULL(param);
TraceGuard trace_guard(std::make_shared<TraceCopy>(param->debug_info()));
(void)new_func_graph->add_parameter();
(void)new_func_graph->add_parameter()->set_abstract(param->abstract());
});
Cloner cloner = Cloner();

View File

@ -85,6 +85,7 @@ def get_bprop_send(self):
shape = self.get_attr_dict()["shape"]
dtype = self.get_attr_dict()["dtype"]
send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
send_grad.add_prim_attr("backward", True)
def bprop(x, out, dout):
dx = send_grad()
@ -96,6 +97,7 @@ def get_bprop_send(self):
def get_bprop_receive(self):
"""Generate bprop for Receive."""
receive_grad = Send(self.tag, self.rank, self.group)
receive_grad.add_prim_attr("backward", True)
depend = P.Depend()
cast = P.Cast()

View File

@ -21,7 +21,7 @@ from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import get_rank, GlobalComm, _get_group
from ...communication.management import GlobalComm
class ExtractImagePatches(PrimitiveWithInfer):
@ -409,7 +409,7 @@ class Send(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.rank = dest_rank
self.sr_tag = sr_tag
self.group = group
@ -465,7 +465,7 @@ class Receive(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.rank = src_rank
self.tag = sr_tag
self.shape = shape
self.dtype = dtype