!18666 [AutoParallel] Change pipeline split shared parameter handling

Merge pull request !18666 from lichen/change_pipeline_shared_param
i-robot 2021-06-22 01:19:59 +00:00 committed by Gitee
commit ead6c37b3c
5 changed files with 59 additions and 41 deletions

@@ -380,6 +380,7 @@ constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd";
constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
constexpr char ACCU_GRAD[] = "accu_grad";
constexpr char PARAMETER_START[] = "parameter_start";
constexpr char PARAM_INDEX[] = "param_index";
// Parallel don't care
constexpr char STRING_EQUAL[] = "string_equal";

@@ -246,12 +246,12 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
return op_info;
}
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Handle Cast and TupleGetitem situation
size_t tensor_info_index = 0;
int tensor_info_index = 0;
OperatorInfoPtr op_info;
if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
op_info = node->user_data<OperatorInfo>();
@@ -259,19 +259,17 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
if (IsPrimitiveCNode(node, prim::kPrimCast)) {
cnode = cnode->input(1)->cast<CNodePtr>();
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
cnode = cnode->input(1)->cast<CNodePtr>();
}
// Create OperatorInfo to get slice_shape for send/recv
MS_EXCEPTION_IF_NULL(cnode);
op_info = CreateOpInfo(cnode);
}
MS_EXCEPTION_IF_NULL(op_info);
auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
return std::make_pair(op_info, tensor_info_index);
}
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto node_users_map = manager_->node_users();
auto node_users = node_users_map[node];
@@ -322,11 +320,9 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(
continue;
}
auto op_info = CreateOpInfo(care_node);
MS_EXCEPTION_IF_NULL(op_info);
auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1];
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
return std::make_pair(op_info, index - 1);
}
return std::make_pair(nullptr, nullptr);
return std::make_pair(nullptr, 0);
}
std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
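The two hunks above change GetOpInfo and GetParameterPair to return the tensor-info index instead of a materialized TensorInfoPtr, so callers such as InsertSend and InsertReceive now perform the inputs_tensor_info()/outputs_tensor_info() lookup themselves. A minimal sketch of that lookup pattern, using stand-in types and a hypothetical GetOpInfoSketch helper rather than the real MindSpore classes:

// Sketch only: simplified stand-ins for OperatorInfo/TensorInfo, not MindSpore's API.
#include <cstddef>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct TensorInfo {
  std::vector<int> slice_shape;
};

struct OperatorInfo {
  std::vector<TensorInfo> inputs;
  std::vector<TensorInfo> outputs;
  const std::vector<TensorInfo> &inputs_tensor_info() const { return inputs; }
  const std::vector<TensorInfo> &outputs_tensor_info() const { return outputs; }
};
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;

// Stand-in for GetOpInfo: returns the operator info plus the output slot index,
// leaving the tensor-info lookup to the call site.
std::pair<OperatorInfoPtr, int> GetOpInfoSketch(const OperatorInfoPtr &op_info, int output_index) {
  return std::make_pair(op_info, output_index);
}

int main() {
  auto op_info = std::make_shared<OperatorInfo>();
  op_info->outputs = {{{4, 8}}, {{2, 2}}};
  auto op_info_pair = GetOpInfoSketch(op_info, 1);
  // The caller indexes outputs_tensor_info() (or inputs_tensor_info() for parameters).
  const auto &tensor_info =
      op_info_pair.first->outputs_tensor_info().at(static_cast<std::size_t>(op_info_pair.second));
  std::cout << tensor_info.slice_shape[0] << "x" << tensor_info.slice_shape[1] << "\n";  // prints 2x2
  return 0;
}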
@@ -478,10 +474,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
auto send_op = CreatOpInstance(attrs, SEND, SEND);
auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node);
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
std::pair<OperatorInfoPtr, int> op_info_pair;
AnfNodePtr care_node;
TensorInfo tensor_info;
if (parameter->isa<Parameter>()) {
op_info_pair = GetParameterPair(parameter);
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
} else {
if (IsPrimitiveCNode(parameter, prim::kPrimCast)) {
auto parameter_cnode = parameter->cast<CNodePtr>();
@@ -491,13 +489,15 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
}
if (care_node->isa<Parameter>()) {
op_info_pair = GetParameterPair(care_node);
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
} else {
op_info_pair = GetOpInfo(care_node);
tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
}
}
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape();
auto index = op_info_pair.second;
auto op_info = op_info_pair.first;
auto slice_shape = tensor_info.slice_shape();
auto shape_type_pair = GetShapeType(parameter, slice_shape);
prim->set_attr(SHAPE, shape_type_pair.first);
prim->set_attr(DTYPE, shape_type_pair.second);
@@ -508,6 +508,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
} else {
send->AddPrimalAttr(PIPELINE_PARAM, value);
send->AddPrimalAttr(MICRO, value);
send->set_user_data<OperatorInfo>(op_info);
send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
}
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
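Besides the slice-shape attrs, InsertSend now stores the consuming operator's info on the Send node (set_user_data) and records the shared parameter's input slot in a PARAM_INDEX primal attribute, which the step_parallel.cc passes below read back. A small sketch of that round trip, with a simplified stand-in Node type instead of the real CNode/Value API:

// Sketch only: the real code stores a ValuePtr primal attr; an int map is used here for brevity.
#include <iostream>
#include <map>
#include <string>

constexpr char PARAM_INDEX[] = "param_index";

struct Node {
  std::map<std::string, int> primal_attrs;
  void AddPrimalAttr(const std::string &key, int value) { primal_attrs[key] = value; }
  int GetPrimalAttr(const std::string &key) const { return primal_attrs.at(key); }
};

int main() {
  Node send;
  send.AddPrimalAttr(PARAM_INDEX, 2);                  // written by InsertSend
  int param_index = send.GetPrimalAttr(PARAM_INDEX);   // read later by mirror/slice passes
  std::cout << "shared parameter feeds input #" << param_index << "\n";
  return 0;
}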
@@ -533,23 +535,25 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
}
Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
std::pair<OperatorInfoPtr, int> op_info_pair;
bool is_param = true;
TensorInfo tensor_info;
if (node->isa<Parameter>()) {
op_info_pair = GetParameterPair(node);
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
} else {
auto care_node = FindPipelineCareNode(node);
if (care_node->isa<Parameter>()) {
op_info_pair = GetParameterPair(care_node);
tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
} else {
op_info_pair = GetOpInfo(care_node);
tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
is_param = false;
}
}
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_layout = tensor_info->tensor_layout();
Shape slice_shape = tensor_info->slice_shape();
auto tensor_layout = tensor_info.tensor_layout();
Shape slice_shape = tensor_info.slice_shape();
auto shape_type_pair = GetShapeType(node, slice_shape);
Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);

@@ -72,8 +72,8 @@ class PipelineTransformer {
AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
const std::string &tag);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
bool IsPipelineCareNode(const CNodePtr &cnode);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();

@@ -611,6 +611,9 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
AnfNodeIndexSet node_set = manager->node_users()[node];
CNodePtr insert_node_new;
if (IsPrimitiveCNode(node, prim::kPrimSend)) {
return;
}
if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
return;
@@ -1091,9 +1094,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
}
}
if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) {
return std::make_pair(node, false);
}
// When not fully use opt shard, allgather and mirror would be both inserted.
// Skip allgather here and find parameter recursively.
if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
@@ -1180,6 +1180,9 @@ bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
}
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
if (IsPrimitiveCNode(node, prim::kPrimSend)) {
return true;
}
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
MS_LOG(INFO) << "Input is ValueList, skip it.";
return false;
@@ -1242,6 +1245,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
for (size_t index = 1; index < node_size; ++index) {
OperatorVector backward_op = mirror_ops[index - 1];
if (IsPrimitiveCNode(node, prim::kPrimSend)) {
auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
backward_op = mirror_ops[IntToSize(param_index)];
}
if (backward_op.empty()) {
continue;
}
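With Send no longer on the parallel black list, InsertMirrorOps handles it explicitly: the backward mirror op is selected by the PARAM_INDEX recorded on the Send node rather than by the input position being iterated. A rough sketch of that selection, under assumed simplified types (PickBackwardOp is a hypothetical helper):

// Sketch only: OperatorVector is a stand-in for the real mirror operator list.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using OperatorVector = std::vector<std::string>;

OperatorVector PickBackwardOp(const std::vector<OperatorVector> &mirror_ops, std::size_t loop_index,
                              bool is_send, std::size_t param_index) {
  // Default: the mirror op that matches the current input position.
  OperatorVector backward_op = mirror_ops[loop_index - 1];
  if (is_send) {
    // Send carries PARAM_INDEX, so use the shared parameter's slot instead.
    backward_op = mirror_ops[param_index];
  }
  return backward_op;
}

int main() {
  std::vector<OperatorVector> mirror_ops = {{"mirror_for_input0"}, {"mirror_for_input1"}};
  auto op = PickBackwardOp(mirror_ops, 1, /*is_send=*/true, /*param_index=*/1);
  std::cout << op[0] << "\n";  // mirror_for_input1
  return 0;
}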
@@ -1271,10 +1278,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
}
// not a RefKey
std::string mirror_op_name = MirrorOpName();
if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) {
param_ptr = param_node_pair.first->cast<CNodePtr>()->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
param_name = param_ptr->name();
}
AnfNodePtr pre_node = node->input(index);
if (!param_node_pair.second) {
auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
@@ -1329,6 +1332,9 @@ void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &dist
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(node);
if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
return;
}
bool is_loss_cnode =
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
[node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
@@ -2109,7 +2115,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (!CheckExtractInfomation(cnode)) {
if (!CheckExtractInfomation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
continue;
}
@@ -2631,6 +2637,9 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
return;
}
OperatorVector forward_op = distribute_operator->forward_op();
if (!forward_op.empty()) {
MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
@@ -2820,7 +2829,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
// the make_tuple is parallel care node, but it may have not operator info
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || cnode->HasPrimalAttr(PIPELINE_PARAM)) {
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
continue;
}
@@ -2828,15 +2837,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(distribute_operator);
// insert forward ops
if (!IsSomePrimitive(cnode, RECEIVE)) {
InsertForwardOps(distribute_operator, cnode);
}
InsertForwardOps(distribute_operator, cnode);
// insert redistribution ops
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
// insert backward ops
if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
if (has_backward) {
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
}
@@ -3414,12 +3421,18 @@ ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &p
OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
MS_EXCEPTION_IF_NULL(op_info);
size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
<< ", but the index is " << user_input_index - 1;
TensorInfo tensor_info;
if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
tensor_info = op_info->inputs_tensor_info()[param_index];
} else {
size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
<< ", but the index is " << user_input_index - 1;
}
tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
}
TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
ParameterSliceInfo parameter_slice_info;
parameter_slice_info.slice_shape = tensor_info.slice_shape();
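GetParameterSliceInfo gains the same special case: a Send user resolves its tensor info through PARAM_INDEX, while any other user keeps the bounds-checked user_input_index - 1 lookup. A compilable sketch of that branching with stand-in types (SelectTensorInfo is a hypothetical helper, not the actual function):

// Sketch only: simplified TensorInfo and an exception in place of MS_LOG(EXCEPTION).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct TensorInfo {
  std::vector<int64_t> slice_shape;
};

TensorInfo SelectTensorInfo(const std::vector<TensorInfo> &inputs_tensor_info, bool user_is_send,
                            std::size_t param_index, int64_t user_input_index) {
  if (user_is_send) {
    // Send records the parameter's slot directly in PARAM_INDEX, no offset needed.
    return inputs_tensor_info[param_index];
  }
  if (static_cast<int64_t>(inputs_tensor_info.size()) <= user_input_index - 1) {
    throw std::out_of_range("inputs tensor info size " + std::to_string(inputs_tensor_info.size()) +
                            " <= index " + std::to_string(user_input_index - 1));
  }
  return inputs_tensor_info[static_cast<std::size_t>(user_input_index - 1)];
}

int main() {
  std::vector<TensorInfo> inputs = {{{16, 32}}, {{32, 64}}};
  auto info = SelectTensorInfo(inputs, /*user_is_send=*/true, /*param_index=*/1, /*user_input_index=*/0);
  std::cout << info.slice_shape[0] << "x" << info.slice_shape[1] << "\n";  // 32x64
  return 0;
}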

@@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "Send", "UpdateState", "Load"};
"stop_gradient", "UpdateState", "Load"};
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
// clang-format on