From 744e4cbab5f9ae4404c7646dd8cb056d56aa6269 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Mon, 21 Jun 2021 21:27:54 +0800
Subject: [PATCH] change_pipeline_shared_param

---
 .../frontend/parallel/ops_info/ops_utils.h    |  1 +
 .../pipeline_transformer.cc                   | 44 +++++++++--------
 .../pipeline_transformer.h                    |  4 +-
 .../ccsrc/frontend/parallel/step_parallel.cc  | 49 ++++++++++++-------
 mindspore/core/utils/parallel_node_check.cc   |  2 +-
 5 files changed, 59 insertions(+), 41 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 150438c4162..1e700ab2f53 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -380,6 +380,7 @@ constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd";
 constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
 constexpr char ACCU_GRAD[] = "accu_grad";
 constexpr char PARAMETER_START[] = "parameter_start";
+constexpr char PARAM_INDEX[] = "param_index";
 
 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index 0f6f115978e..fd287f3e272 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -246,12 +246,12 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
   return op_info;
 }
 
-std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
+std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // Handle Cast and TupleGetitem situation
-  size_t tensor_info_index = 0;
+  int tensor_info_index = 0;
   OperatorInfoPtr op_info;
   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
     op_info = node->user_data<OperatorInfo>();
@@ -259,19 +259,17 @@
     if (IsPrimitiveCNode(node, prim::kPrimCast)) {
       cnode = cnode->input(1)->cast<CNodePtr>();
     } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
-      tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
+      tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
       cnode = cnode->input(1)->cast<CNodePtr>();
     }
     // Create OperatorInfo to get slice_shape for send/recv
     MS_EXCEPTION_IF_NULL(cnode);
     op_info = CreateOpInfo(cnode);
   }
-  MS_EXCEPTION_IF_NULL(op_info);
-  auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
-  return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
+  return std::make_pair(op_info, tensor_info_index);
 }
 
-std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
+std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto node_users_map = manager_->node_users();
   auto node_users = node_users_map[node];
@@ -322,11 +320,9 @@
       continue;
     }
     auto op_info = CreateOpInfo(care_node);
-    MS_EXCEPTION_IF_NULL(op_info);
-    auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1];
-    return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
+    return std::make_pair(op_info, index - 1);
   }
-  return std::make_pair(nullptr, nullptr);
+  return std::make_pair(nullptr, 0);
 }
 
 std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
@@ -478,10 +474,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   auto send_op = CreatOpInstance(attrs, SEND, SEND);
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  std::pair<OperatorInfoPtr, int> op_info_pair;
   AnfNodePtr care_node;
+  TensorInfo tensor_info;
   if (parameter->isa<Parameter>()) {
     op_info_pair = GetParameterPair(parameter);
+    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     if (IsPrimitiveCNode(parameter, prim::kPrimCast)) {
       auto parameter_cnode = parameter->cast<CNodePtr>();
@@ -491,13 +489,15 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
     }
     if (care_node->isa<Parameter>()) {
       op_info_pair = GetParameterPair(care_node);
+      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
     } else {
       op_info_pair = GetOpInfo(care_node);
+      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
     }
   }
-  auto tensor_info = op_info_pair.second;
-  MS_EXCEPTION_IF_NULL(tensor_info);
-  auto slice_shape = tensor_info->slice_shape();
+  auto index = op_info_pair.second;
+  auto op_info = op_info_pair.first;
+  auto slice_shape = tensor_info.slice_shape();
   auto shape_type_pair = GetShapeType(parameter, slice_shape);
   prim->set_attr(SHAPE, shape_type_pair.first);
   prim->set_attr(DTYPE, shape_type_pair.second);
@@ -508,6 +508,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   } else {
     send->AddPrimalAttr(PIPELINE_PARAM, value);
     send->AddPrimalAttr(MICRO, value);
+    send->set_user_data<OperatorInfo>(op_info);
+    send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
   }
   OperatorAttrs depend_attrs;
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
@@ -533,23 +535,25 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   }
   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
   Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
-  std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  std::pair<OperatorInfoPtr, int> op_info_pair;
   bool is_param = true;
+  TensorInfo tensor_info;
   if (node->isa<Parameter>()) {
     op_info_pair = GetParameterPair(node);
+    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     auto care_node = FindPipelineCareNode(node);
     if (care_node->isa<Parameter>()) {
       op_info_pair = GetParameterPair(care_node);
+      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
     } else {
       op_info_pair = GetOpInfo(care_node);
+      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
       is_param = false;
     }
   }
-  auto tensor_info = op_info_pair.second;
-  MS_EXCEPTION_IF_NULL(tensor_info);
-  auto tensor_layout = tensor_info->tensor_layout();
-  Shape slice_shape = tensor_info->slice_shape();
+  auto tensor_layout = tensor_info.tensor_layout();
+  Shape slice_shape = tensor_info.slice_shape();
   auto shape_type_pair = GetShapeType(node, slice_shape);
   Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
   Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
index 69bedb43410..2606e2e08db 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
@@ -72,8 +72,8 @@ class PipelineTransformer {
   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage,
                    const std::vector<AnfNodePtr> &out_input, const std::string &tag);
   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
+  std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
+  std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
   OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
   bool IsPipelineCareNode(const CNodePtr &cnode);
   std::pair FindSensNode();
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 3ebfaf7728b..596406fa25a 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -611,6 +611,9 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
   AnfNodeIndexSet node_set = manager->node_users()[node];
   CNodePtr insert_node_new;
 
+  if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+    return;
+  }
   if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
     MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
     return;
@@ -1091,9 +1094,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
     }
   }
 
-  if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) {
-    return std::make_pair(node, false);
-  }
   // When not fully use opt shard, allgather and mirror would be both inserted.
   // Skip allgather here and find parameter recursively.
   if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
@@ -1180,6 +1180,9 @@ bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
 }
 
 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
+  if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+    return true;
+  }
   if ((node->inputs().size() == 2) && (IsValueNode(node->input(1)))) {
     MS_LOG(INFO) << "Input is ValueList, skip it.";
     return false;
@@ -1242,6 +1245,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
 
   for (size_t index = 1; index < node_size; ++index) {
     OperatorVector backward_op = mirror_ops[index - 1];
+    if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+      auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
+      backward_op = mirror_ops[IntToSize(param_index)];
+    }
     if (backward_op.empty()) {
       continue;
     }
@@ -1271,10 +1278,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
     }
     // not a RefKey
     std::string mirror_op_name = MirrorOpName();
-    if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) {
-      param_ptr = param_node_pair.first->cast<CNodePtr>()->user_data(PIPELINE_PARAM)->cast<ParameterPtr>();
-      param_name = param_ptr->name();
-    }
     AnfNodePtr pre_node = node->input(index);
     if (!param_node_pair.second) {
       auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
@@ -1329,6 +1332,9 @@ void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &dist
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
+  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
+    return;
+  }
 
   bool is_loss_cnode =
     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
                 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
@@ -2109,7 +2115,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
 
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!CheckExtractInfomation(cnode)) {
+    if (!CheckExtractInfomation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
      continue;
    }
 
@@ -2631,6 +2637,9 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(cnode);
+  if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
+    return;
+  }
   OperatorVector forward_op = distribute_operator->forward_op();
   if (!forward_op.empty()) {
     MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
@@ -2820,7 +2829,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
       // the make_tuple is parallel care node, but it may have not operator info
-      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || cnode->HasPrimalAttr(PIPELINE_PARAM)) {
+      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
         continue;
       }
 
@@ -2828,15 +2837,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
   OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
   MS_EXCEPTION_IF_NULL(op_info);
-  size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
-  if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
-    MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
-                      << ", but the index is " << user_input_index - 1;
+  TensorInfo tensor_info;
+  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
+    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
+    tensor_info = op_info->inputs_tensor_info()[param_index];
+  } else {
+    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
+    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
+      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
+                        << ", but the index is " << user_input_index - 1;
+    }
+    tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
   }
-  TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
   ParameterSliceInfo parameter_slice_info;
   parameter_slice_info.slice_shape = tensor_info.slice_shape();
diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc
index ee9f64912b6..7130cd396ae 100644
--- a/mindspore/core/utils/parallel_node_check.cc
+++ b/mindspore/core/utils/parallel_node_check.cc
@@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
   "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
   "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
   "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
-  "stop_gradient", "Send", "UpdateState", "Load"};
+  "stop_gradient", "UpdateState", "Load"};
 static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
 static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
 // clang-format on
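Note (illustration only, not part of the patch): the change above makes a Send that forwards a shared parameter to another pipeline stage carry, as the PARAM_INDEX primal attr, the input slot of the consuming operator that the parameter feeds, plus the consumer's OperatorInfo as user data. Later passes (mirror-op insertion, parameter-slice checks) then look up the matching inputs_tensor_info()/mirror_ops entry by that index instead of treating the Send itself as a parallel-care consumer. The following standalone sketch models only that index-based lookup; the types and names (TensorInfoLite, OpInfoLite, SendNodeLite, PickBackwardOp, PickSliceInfo) are hypothetical stand-ins, not MindSpore APIs.

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative stand-ins for the framework types.
    struct TensorInfoLite {
      std::vector<int64_t> slice_shape;  // shape of this rank's slice of the tensor
    };

    struct OpInfoLite {
      std::string name;
      std::vector<TensorInfoLite> inputs_tensor_info;  // one entry per operator input
      std::vector<std::string> mirror_ops;             // backward (mirror) op per input, "" if none
    };

    // A Send forwarding a shared parameter records which input slot of the
    // consuming operator it feeds (the PARAM_INDEX attr in the patch).
    struct SendNodeLite {
      const OpInfoLite *dest_op;
      int param_index;
    };

    // Select the mirror op for the parameter behind a Send by its recorded index,
    // analogous to how InsertMirrorOps indexes mirror_ops with PARAM_INDEX.
    const std::string &PickBackwardOp(const SendNodeLite &send) {
      auto idx = static_cast<size_t>(send.param_index);
      assert(send.dest_op != nullptr && idx < send.dest_op->mirror_ops.size());
      return send.dest_op->mirror_ops[idx];
    }

    // Select the slice info the Send must carry, analogous to reading
    // inputs_tensor_info()[param_index] in the parameter-slice check.
    const TensorInfoLite &PickSliceInfo(const SendNodeLite &send) {
      auto idx = static_cast<size_t>(send.param_index);
      assert(send.dest_op != nullptr && idx < send.dest_op->inputs_tensor_info.size());
      return send.dest_op->inputs_tensor_info[idx];
    }

    int main() {
      // Consuming operator with two inputs: slot 0 is an activation, slot 1 is the shared parameter.
      OpInfoLite matmul{"MatMulInfo", {{{32, 128}}, {{128, 16}}}, {"", "Mirror"}};
      SendNodeLite send{&matmul, 1};  // this Send feeds input slot 1

      std::cout << "backward op: " << PickBackwardOp(send) << "\n";
      std::cout << "slice rank:  " << PickSliceInfo(send).slice_shape.size() << "\n";
      return 0;
    }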