From 2822eade04454f910c1f44a79f6f37af35c07abc Mon Sep 17 00:00:00 2001
From: lichen
Date: Thu, 12 Jan 2023 15:10:33 +0800
Subject: [PATCH] parallel_support_while

---
 .../frontend/parallel/parameter_manager.cc    |   8 +-
 .../pipeline_transformer.cc                   |   6 +-
 .../ccsrc/frontend/parallel/step_parallel.cc  | 157 +++++++-----------
 .../frontend/parallel/step_parallel_utils.cc  |  96 ++++++++---
 .../frontend/parallel/step_parallel_utils.h   |   5 +-
 .../ut/python/parallel/test_parallel_while.py | 121 ++++++++++++++
 6 files changed, 269 insertions(+), 124 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_parallel_while.py

diff --git a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc
index d773a3b98e3..38dab196c05 100644
--- a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc
@@ -174,8 +174,12 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
     // the node is a ref key node
     return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
   } else if (node->isa<Parameter>()) {
+    auto param_ptr = node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
     // the node is a parameter node
-    return FindParameterNodeUsers(node);
+    if (param_ptr->has_default()) {
+      return FindParameterNodeUsers(node);
+    }
   }

   return parameter_users_info;
@@ -745,7 +749,7 @@ static std::pair FindParameterByFuncGraph(const AnfNodePtr &no
   MS_EXCEPTION_IF_NULL(fg);
   auto fg_parameters = fg->parameters();

-  auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
+  auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
   auto pre_cnode = pre_node->cast<CNodePtr>();
   for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
     auto res = FindParameter(pre_cnode->input(index), pre_cnode->func_graph());
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index e08a7b4154d..79f8ba8d1ef 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -590,7 +590,7 @@ static std::pair GetShapeType(const AnfNodePtr &node, con

 AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(node);
-  auto real_node = GetRealKernelNode(node, -1);
+  auto real_node = GetRealKernelNode(node, -1).first;
   if (!real_node->isa<CNode>()) {
     return real_node;
   }
@@ -795,7 +795,7 @@ bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) const {
   // ParameterGraph: graph which return a parameter
   MS_EXCEPTION_IF_NULL(node);
   CNodePtr call_node = nullptr;
-  auto real_kernel = GetRealKernelNode(node, -1, &call_node);
+  auto real_kernel = GetRealKernelNode(node, -1, &call_node).first;
   if (call_node != nullptr && real_kernel->isa<Parameter>()) {
     return true;
   }
@@ -806,7 +806,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
                                                      int64_t user_stage, const ValuePtr &micro, size_t pos,
                                                      const std::vector<AnfNodePtr> &ops) {
   CNodePtr call_node = nullptr;
-  auto argument = GetRealKernelNode(node, -1, &call_node);
+  auto argument = GetRealKernelNode(node, -1, &call_node).first;
   auto use_cnode = use_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(use_cnode);
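Note: the recurring `.first` suffixes above all stem from one signature change in this patch. `GetRealKernelNode` now returns a `std::pair` of the real node and the `tuple_getitem` index it was reached through (see the `step_parallel_utils` hunks below), so callers that only want the node select `.first`. A minimal standalone sketch of the pattern, with illustrative toy types rather than MindSpore APIs:

    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    // Toy node: a transparent wrapper (Depend/Cast), a tuple_getitem, or a real kernel.
    struct Node {
      enum class Kind { kDepend, kCast, kTupleGetItem, kKernel } kind = Kind::kKernel;
      std::vector<std::shared_ptr<Node>> inputs;   // assumed non-empty for wrappers
      std::int64_t item_index = -1;                // meaningful only for kTupleGetItem
    };
    using NodePtr = std::shared_ptr<Node>;

    // Unwrap pass-through nodes, remembering the last tuple index seen on the way.
    std::pair<NodePtr, std::int64_t> GetRealNode(const NodePtr &node, std::int64_t get_item_index) {
      if (node->kind == Node::Kind::kDepend || node->kind == Node::Kind::kCast) {
        return GetRealNode(node->inputs[0], get_item_index);
      }
      if (node->kind == Node::Kind::kTupleGetItem) {
        return GetRealNode(node->inputs[0], node->item_index);
      }
      return {node, get_item_index};  // callers that only need the node take .first
    }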
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index ce00095b7b0..4f8931155f2 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -476,6 +476,13 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
       IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
     return;
   }
+  // Find Redistribution next_nodes
+  std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
+  RedistributionNextNode(cnode, manager, node_users_map, -1, -1, &next_nodes);
+  if (next_nodes.empty()) {
+    return;
+  }
+
   // Find Redistribution pre_nodes
   std::vector<AnfNodePtr> pre_nodes;
   RedistributionPreNode(cnode, manager, &pre_nodes);
@@ -483,10 +490,6 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
     MS_LOG(EXCEPTION) << " Don't support Redistribution has multiple pre_node.";
   }

-  // Find Redistribution next_nodes
-  std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
-  RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
-
   // Insert Redistribution nodes between pre_nodes and next_nodes
   for (auto &pre_node : pre_nodes) {
     for (auto &next_node : next_nodes) {
@@ -875,65 +878,46 @@ static bool FindPreNodes(const AnfNodePtr &node, std::vector<std::string> *uniqu
   return find;
 }

-static void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
-                                  std::vector<size_t> *indexes) {
-  MS_EXCEPTION_IF_NULL(unique_ids);
-  CNodePtr cnode = root->get_return();
-  if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
-    MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
-  }
-}
-
 void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
-  std::vector<std::string> last_forward_node_ids;
-  std::vector<size_t> last_indexs;
   auto real_graph = PynativeParallelGraph(root, all_nodes);
-  FindLastNodesUniqueId(real_graph, &last_forward_node_ids, &last_indexs);
-  MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
-  for (auto &node : all_nodes) {
-    // here insert virtualoutput node
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
-      continue;
-    }
-    auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
-    if (last_node_iter == last_forward_node_ids.end()) {
-      continue;
-    }
-    for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
-      if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
-        continue;
-      }
-      MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
-                   << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
-      if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
-        FuncGraphManagerPtr manager = cnode->func_graph()->manager();
-        MS_EXCEPTION_IF_NULL(manager);
-        auto node_pair = manager->node_users()[cnode].front();
-        if (!node_pair.first->isa<CNode>()) {
-          MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
-        }
-        cnode = node_pair.first->cast<CNodePtr>();
-        last_indexs[last_node_index] = IntToSize(node_pair.second);
-      }
-      auto pre_node = cnode->input(last_indexs[last_node_index]);
-      Shapes shape_outputs = GetNodeShape(pre_node);
+  auto out_pair = GetRealKernelNode(real_graph->output(), -1, nullptr, false);
+  auto out_node = out_pair.first;
+  MS_EXCEPTION_IF_NULL(out_node);
+  OperatorParams params;
+  OperatorAttrs attrs;
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
+  if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
+    auto tuple = out_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple);
+    for (size_t i = 1; i < tuple->inputs().size(); ++i) {
+      auto cur_input = tuple->input(i);
+      Shapes shape_outputs = GetNodeShape(cur_input);
       if (shape_outputs[0].empty()) {
         continue;
       }
-      FuncGraphPtr func_graph = node->func_graph();
-      MS_EXCEPTION_IF_NULL(func_graph);
-      OperatorParams params;
-      OperatorAttrs attrs;
-      OperatorArgs args = std::make_pair(attrs, params);
-      Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
-      InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
-      auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
-      AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
+      InsertNode(op, tuple, i, cur_input, tuple->func_graph(), VIRTUAL_OUTPUT);
+      auto virtual_output_abstract = cur_input->abstract()->Clone();
       std::shared_ptr<abstract::Shape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
       virtual_output_abstract->set_shape(virtual_output_shape);
+      auto virtual_output_node = tuple->input(i);
       virtual_output_node->set_abstract(virtual_output_abstract);
     }
+  } else {
+    Shapes shape_outputs = GetNodeShape(out_node);
+    if (shape_outputs[0].empty()) {
+      return;
+    }
+    auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
+    auto cur_graph = out_node->cast<CNodePtr>()->func_graph();
+    MS_EXCEPTION_IF_NULL(cur_graph);
+    auto new_node = cur_graph->NewCNode(node_input);
+    auto manager = cur_graph->manager();
+    (void)manager->Replace(out_node, new_node);
+    auto virtual_output_abstract = out_node->abstract()->Clone();
+    std::shared_ptr<abstract::Shape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
+    virtual_output_abstract->set_shape(virtual_output_shape);
+    new_node->set_abstract(virtual_output_abstract);
   }
 }
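Note: the InsertVirtualOutput rewrite above drops the UniqueId/index bookkeeping entirely. It walks straight to the real graph output via GetRealKernelNode (passing ignore_get_item = false so a trailing tuple_getitem is not folded away) and then branches on the output kind: a MakeTuple output gets one VirtualOutput per element, anything else is wrapped once and replaced in place. A compilable toy sketch of that branch structure, with made-up types:

    #include <cstddef>
    #include <memory>
    #include <vector>

    struct ToyNode;
    using ToyNodePtr = std::shared_ptr<ToyNode>;
    struct ToyNode {
      bool is_make_tuple = false;
      std::vector<ToyNodePtr> inputs;  // inputs[0] is the primitive slot, as in a CNode
    };

    // Stubs standing in for InsertNode / manager->Replace in the real pass.
    void WrapElementWithVirtualOutput(const ToyNodePtr & /*tuple*/, std::size_t /*i*/) {}
    void ReplaceWithVirtualOutput(const ToyNodePtr & /*node*/) {}

    void InsertVirtualOutputSketch(const ToyNodePtr &out_node) {
      if (out_node->is_make_tuple) {
        // Multi-output graph: one VirtualOutput per tuple element (slot 0 is the primitive).
        for (std::size_t i = 1; i < out_node->inputs.size(); ++i) {
          WrapElementWithVirtualOutput(out_node, i);
        }
      } else {
        // Single output: wrap the node itself and fix up its abstract/shape afterwards.
        ReplaceWithVirtualOutput(out_node);
      }
    }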
@@ -1606,14 +1590,15 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
 }

 // if reshape's output connect to several primitive, return the first layout found
-static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
+static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape,
+                                                    int make_tuple_index) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(cnode->func_graph());
   FuncGraphManagerPtr manager = cnode->func_graph()->manager();
   MS_EXCEPTION_IF_NULL(manager);
   AnfNodeIndexSet node_set = manager->node_users()[cnode];
   for (auto &node_pair : node_set) {
-    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    auto use_apply = node_pair.first->cast<CNodePtr>();
     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
       continue;
     }
@@ -1621,24 +1606,26 @@ static std::shared_ptr FindNextLayout(const CNodePtr &cnode, bool
       *next_is_reshape = true;
       continue;
     }
-    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(prim_anf_node);
-    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(node_prim);
-    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
-    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+    if (IsPrimitiveCNode(use_apply, prim::kPrimDepend) && node_pair.second != 1) {
       continue;
     }
+    if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
+      make_tuple_index = node_pair.second;
+      return FindNextLayout(use_apply, next_is_reshape, make_tuple_index);
+    }
     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
-      MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
+      if (make_tuple_index != -1) {
+        node_pair.second = make_tuple_index;
+      }
+      MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
       *next_is_reshape = false;
       auto layout = GetInputLayoutFromCNode(node_pair);
       return std::make_shared<TensorLayout>(layout);
     }
-    MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+    MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply)
                   << " " << use_apply->has_user_data<OperatorInfo>();
-    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
+    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1);
     if (layout_ptr) {
       return layout_ptr;
     }
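Note: the new make_tuple_index parameter lets FindNextLayout look through a MakeTuple user. It records which element position the search entered through, and once a parallel-care consumer is found, the layout is read from that recorded position instead of the consumer's own input slot. A toy, compilable sketch of carrying that position (illustrative names only):

    #include <memory>
    #include <utility>
    #include <vector>

    struct UserNode;
    using UserNodePtr = std::shared_ptr<UserNode>;
    struct UserNode {
      bool is_make_tuple = false;
      bool is_parallel_care = false;
      std::vector<std::pair<UserNodePtr, int>> users;  // (user, input position in that user)
    };

    // Returns the input position the next layout should be taken from, or -1 if none found.
    int FindNextLayoutPos(const UserNodePtr &node, int make_tuple_index) {
      for (const auto &[user, pos] : node->users) {
        if (user->is_make_tuple) {
          return FindNextLayoutPos(user, pos);  // remember the element we came through
        }
        if (user->is_parallel_care) {
          return make_tuple_index != -1 ? make_tuple_index : pos;
        }
        int deeper = FindNextLayoutPos(user, -1);
        if (deeper != -1) {
          return deeper;
        }
      }
      return -1;
    }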
"FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) + MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply) << " " << use_apply->has_user_data(); - auto layout_ptr = FindNextLayout(use_apply, next_is_reshape); + auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1); if (layout_ptr) { return layout_ptr; } @@ -1791,7 +1778,7 @@ static void ReshapeInit(const std::vector &all_nodes) { reshape_info_ptr->SetInputLayout(*prev_layout_ptr); } bool is_next_reshape = false; - auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape); + auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape, -1); if (next_layout_ptr) { auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); reshape_info_ptr->SetOutputLayout(*next_layout_ptr); @@ -1821,10 +1808,7 @@ static CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) { return cnode; } -static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) { - if (max_depth > MAX_RECURSIVE_DEPTH) { - MS_LOG(EXCEPTION) << "Recursive call is larger than 100000."; - } +static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { LossNodeInfo loss_node_info; MS_EXCEPTION_IF_NULL(func_graph); CNodePtr return_node = func_graph->get_return(); @@ -1832,18 +1816,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep if (return_node->size() < 2) { MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2"; } - AnfNodePtr pre_node = return_node->input(1); + auto pre_node_pair = GetRealKernelNode(return_node->input(1), -1, nullptr); + auto pre_node = pre_node_pair.first; MS_EXCEPTION_IF_NULL(pre_node); auto pre_cnode = pre_node->cast(); - pre_cnode = HandleDependLoss(pre_cnode, 0); - if (pre_cnode->input(0)->isa()) { - auto switch_cnode = pre_cnode->input(0)->cast(); - if (IsPrimitiveCNode(switch_cnode, prim::kPrimSwitch)) { - MS_EXCEPTION_IF_NULL(switch_cnode); - auto switch_graph = GetValueNode(switch_cnode->input(2)); - return FindLossCNode(switch_graph, max_depth + 1); - } - } + if (pre_cnode == nullptr || !IsValueNode(pre_cnode->input(0))) { return loss_node_info; } @@ -1859,21 +1836,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep return loss_node_info; } - // size of common cnode is larger than 1 - if (pre_cnode->size() < 2) { - MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; - } - // return -> tuple_getitem -> loss - if (current_prim->name() == prim::kTupleGetItem) { - auto tuple_index = GetTupleGetItemIndex(pre_cnode); - AnfNodePtr pre_pre_node = pre_cnode->input(1); - MS_EXCEPTION_IF_NULL(pre_pre_node); - - auto pre_pre_cnode = pre_pre_node->cast(); + if (pre_node_pair.second != -1) { loss_node_info.has_tuple_getitem = true; - loss_node_info.dout_index = tuple_index; - loss_node_info.loss_node = pre_pre_cnode; + loss_node_info.dout_index = pre_node_pair.second; + loss_node_info.loss_node = pre_cnode; return loss_node_info; } @@ -2121,7 +2088,7 @@ static std::vector> GetSensLossPairs(const Fun MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; } auto func_graph = GetValueNode(expect_j_cnode->input(1)); - auto loss_node_info = FindLossCNode(func_graph, 0); + auto loss_node_info = FindLossCNode(func_graph); if (loss_node_info.loss_node == nullptr) { MS_LOG(WARNING) << "Can not find the loss cnode"; continue; @@ -2309,7 
@@ -2309,7 +2276,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
 static std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> root_forward_nodes;
-  auto loss_cnode = FindLossCNode(graph, 0).loss_node;
+  auto loss_cnode = FindLossCNode(graph).loss_node;
   if (loss_cnode == nullptr) {
     return root_forward_nodes;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
index 1b1d0b7148e..a2cc34f9df3 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
@@ -112,22 +112,27 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
   return tensor_info;
 }

-AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node) {
+std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node,
+                                                 bool ignore_get_item) {
   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
-      IsPrimitiveCNode(node, prim::kPrimCast)) {
-    return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node);
+      IsPrimitiveCNode(node, prim::kPrimCast) || IsPrimitiveCNode(node, prim::kPrimVirtualDiv)) {
+    return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node, ignore_get_item);
   }
-  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
+  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && ignore_get_item) {
     auto cnode = node->cast<CNodePtr>();
     auto cur_get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
     auto tuple_getitem_input = cnode->input(1);
-    auto pass_through_node = GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node);
-    return GetRealKernelNode(pass_through_node, get_item_index, call_node);
+    return GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node, ignore_get_item);
   }
   if (get_item_index != -1 && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
     auto make_tuple_cnode = node->cast<CNodePtr>();
     auto make_tuple_input = make_tuple_cnode->input(LongToSize(get_item_index + 1));
-    return GetRealKernelNode(make_tuple_input, -1, call_node);
+    return GetRealKernelNode(make_tuple_input, -1, call_node, ignore_get_item);
+  }
+  if (IsControlFlowNode(node)) {
+    auto switch_cnode = node->cast<CNodePtr>()->input(0)->cast<CNodePtr>();
+    auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(3));
+    return GetRealKernelNode(fg->output(), get_item_index, call_node, ignore_get_item);
   }
   if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0))) {
     if (call_node != nullptr && *call_node == nullptr) {
@@ -135,21 +140,33 @@ AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNo
     }
     auto cnode = node->cast<CNodePtr>();
     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
-    auto output = GetRealKernelNode(graph->output(), get_item_index, call_node);
+    auto output = GetRealKernelNode(graph->output(), get_item_index, call_node, ignore_get_item).first;
     MS_EXCEPTION_IF_NULL(output);
     if (output->isa<Parameter>()) {
       auto parameters = graph->parameters();
       auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
       // If can't find in parameters, the parameter is a fv.
       if (pos_iter == parameters.end()) {
-        return output;
+        return std::make_pair(output, get_item_index);
       }
       auto pos = std::distance(parameters.begin(), pos_iter);
-      return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node);
+      return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node, ignore_get_item);
     }
-    return output;
+    return std::make_pair(output, get_item_index);
   }
-  return node;
+  return std::make_pair(node, get_item_index);
+}
+
+static bool IsWhileGraph(const FuncGraphPtr &cur_fg, const FuncGraphPtr &fg) {
+  auto cur_fg_map = cur_fg->func_graph_cnodes_index();
+  for (auto &cur_fg_use : cur_fg_map) {
+    auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(temp_node);
+    if (temp_node->func_graph() == fg) {
+      return true;
+    }
+  }
+  return false;
 }

 AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
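Note: IsWhileGraph above encodes the loop test: a call from graph cur_fg into graph fg is treated as a while call when cur_fg itself has a call site inside fg, i.e. the two graphs call each other the way a lowered while loop does. Redistribution must not recurse through such an edge or it would cycle forever. A standalone sketch of the check, with toy types in place of FuncGraph and func_graph_cnodes_index:

    #include <vector>

    struct GraphSketch;
    struct CallSite {
      const GraphSketch *enclosing_graph;  // graph that contains the calling node
    };
    struct GraphSketch {
      std::vector<CallSite> call_sites;  // where this graph is called from
    };

    // True when `callee` is called from inside `caller`: the mutual-call shape of a while loop.
    bool IsWhileGraphSketch(const GraphSketch &callee, const GraphSketch *caller) {
      for (const auto &site : callee.call_sites) {
        if (site.enclosing_graph == caller) {
          return true;
        }
      }
      return false;
    }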
@@ -280,8 +297,13 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
   return tuple_index_value->cast<Int64ImmPtr>()->value();
 }

+static bool IsNoNeedRedistribution(const CNodePtr &use_cnode, int use_index) {
+  return (IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && use_index != 1) || use_cnode->input(0)->isa<CNode>() ||
+         IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch);
+}
+
 void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
-                            const NodeUsersMap &node_users_map, int64_t get_item_index,
+                            const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
                             std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
   MS_EXCEPTION_IF_NULL(node);
   if (node_users_map.count(node) == 0) {
@@ -292,30 +314,60 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
     auto use_cnode = node_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(use_cnode);
     if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
+      auto cur_fg = use_cnode->func_graph();
       auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
       MS_EXCEPTION_IF_NULL(fg);
+      if (IsWhileGraph(cur_fg, fg)) {
+        continue;
+      }
       auto fg_parameters = fg->parameters();
       auto param = fg_parameters[IntToSize(node_pair.second - 1)];
       MS_EXCEPTION_IF_NULL(param);
-      RedistributionNextNode(param, manager, node_users_map, get_item_index, next_nodes);
+      RedistributionNextNode(param, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
       continue;
     }
+    if (IsPrimitiveCNode(use_cnode, prim::kPrimMakeTuple)) {
+      make_tuple_index = node_pair.second - 1;
+      RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
+      continue;
+    }
     if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
-      get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
+      auto temp = LongToInt(GetTupleGetItemIndex(use_cnode));
+      if (temp != make_tuple_index && make_tuple_index != -1) {
+        continue;
+      }
+      RedistributionNextNode(use_cnode, manager, node_users_map, temp, -1, next_nodes);
+      continue;
+    }
+    if (IsPrimitiveCNode(use_cnode, prim::kPrimReturn)) {
+      auto fg = use_cnode->func_graph();
+      auto fg_map = fg->func_graph_cnodes_index();
+      for (auto &fg_use : fg_map) {
+        auto fg_node = fg_use.first->first->cast<CNodePtr>();
+        constexpr int SWITCH_LAST_INPUT_INDEX = 3;
+        if (IsWhileGraph(fg, fg_node->func_graph()) && fg_use.first->second == SWITCH_LAST_INPUT_INDEX) {
+          RedistributionNextNode(fg_node, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
+        }
+      }
     }
     // depend, auto monad and control flow op don't need to jump over
-    if ((IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && node_pair.second != 1) ||
-        IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch)) {
+    if (IsNoNeedRedistribution(use_cnode, node_pair.second)) {
       continue;
     }
     if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
+      if (make_tuple_index != -1) {
+        auto real_node = GetRealKernelNode(use_cnode->input(1), -1, nullptr);
+        if (IsPrimitiveCNode(real_node.first, prim::kPrimMakeTuple)) {
+          next_nodes->push_back(std::make_pair(std::make_pair(real_node.first, make_tuple_index + 1), get_item_index));
+          make_tuple_index = -1;
+          continue;
+        }
+      }
       next_nodes->push_back(std::make_pair(node_pair, get_item_index));
-    } else if (use_cnode->input(0)->isa<CNode>()) {
       continue;
-    } else {
-      // search recursively
-      RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);
     }
+    // search recursively
+    RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
   }
 }
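Note: two details in the hunk above are worth spelling out. First, entering a MakeTuple through element k (make_tuple_index = node_pair.second - 1) means only the matching tuple_getitem(k) users may be followed further; a mismatched index is a different tuple element and is skipped. Second, a Return user inside a while body is traced back to the loop's Switch call sites (the graph used as the Switch's last input) so the search can continue past the loop. The index-matching rule in isolation:

    // Follow a tuple_getitem(getitem_index) user only if it matches the MakeTuple
    // element the search came through; -1 means "did not come through a make_tuple".
    bool ShouldFollowGetItem(int getitem_index, int make_tuple_index) {
      return make_tuple_index == -1 || getitem_index == make_tuple_index;
    }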
@@ -323,7 +375,7 @@ void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &man
                            std::vector<AnfNodePtr> *pre_nodes) {
   if (IsValueNode<FuncGraph>(cnode->input(0))) {
     auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
-    auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
+    auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
     if (!pre_node) {
       return;
     }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
index b86be3c6dea..a42631dbac5 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
@@ -66,11 +66,12 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
 AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
 bool IsControlFlowNode(const AnfNodePtr &node);
 int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
-AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
+std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index,
+                                                 CNodePtr *call_node = nullptr, bool ignore_get_item = true);
 void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                            std::vector<AnfNodePtr> *pre_nodes);
 void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
-                            const NodeUsersMap &node_users_map, int64_t get_item_index,
+                            const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
                             std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);

 // for specific scenarios
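Note: because the two new parameters in the header are defaulted (ignore_get_item = true keeps the old fold-through-tuple_getitem behavior), every pre-existing call site compiles unchanged; only InsertVirtualOutput opts out by passing false. A stub illustrating why, with placeholder types:

    #include <cstdint>
    #include <utility>

    struct Stub {};

    std::pair<Stub *, std::int64_t> GetRealKernelNodeStub(Stub *node, std::int64_t get_item_index,
                                                          Stub **call_node = nullptr, bool ignore_get_item = true) {
      (void)call_node;
      (void)ignore_get_item;
      return {node, get_item_index};
    }

    void Caller(Stub *n) {
      auto a = GetRealKernelNodeStub(n, -1);                  // old two-argument style still compiles
      auto b = GetRealKernelNodeStub(n, -1, nullptr, false);  // new: stop folding at tuple_getitem
      (void)a;
      (void)b;
    }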
diff --git a/tests/ut/python/parallel/test_parallel_while.py b/tests/ut/python/parallel/test_parallel_while.py
new file mode 100644
index 00000000000..2bd07bfd011
--- /dev/null
+++ b/tests/ut/python/parallel/test_parallel_while.py
@@ -0,0 +1,121 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train import Model
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+
+
+class DatasetLenet():
+    def __init__(self, data, label, length=3):
+        self.data = data
+        self.label = label
+        self.index = 1
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.data, self.label
+
+    @staticmethod
+    def get_dataset_size():
+        return 32
+
+    @staticmethod
+    def get_repeat_count():
+        return 1
+
+    @staticmethod
+    def get_batch_size():
+        return 32
+
+    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
+        return self
+
+    def reset(self):
+        self.index = 0
+
+
+class MatMulCell(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.matmul = P.MatMul()
+        self.relu = P.ReLU().shard(((2, 1),))
+        self.weight = Parameter(initializer("ones", [64, 64]), name="param1")
+
+    def construct(self, x):
+        out = self.matmul(x, self.weight)
+        out = self.relu(out)
+        return out
+
+
+class ConcatCell(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.concat = P.Concat().shard(((1, 8), (1, 8)))
+        self.relu = P.ReLU()
+
+    def construct(self, x, y):
+        out = self.concat((y, x))
+        out = self.relu(out)
+        return out
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.matmul = P.MatMul().shard(((2, 4), (4, 1)))
+        self.weight = Parameter(initializer("ones", [64, 64]), name="param")
+        self.index = Parameter(Tensor(0, mstype.int32), requires_grad=False)
+        self.cell1 = MatMulCell()
+        self.cell2 = ConcatCell()
+        self.relu = P.ReLU().shard(((8, 1),))
+
+    def construct(self, x, y):
+        out = self.matmul(x, self.weight)
+        while self.index < 3:
+            out = self.cell1(out)
+            self.index += 1
+        out = self.cell2(out, x)
+        out = self.relu(out)
+        return out
+
+
+def test_parallel_while():
+    """
+    Feature: test parallel while.
+    Description: while + concat.
+    Expectation: Successful graph compilation.
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    net = Net()
+    data = Tensor(np.ones([128, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([8, 8]), dtype=ms.float32)
+    dataset = DatasetLenet(data, label, 3)
+    opt = nn.Lamb(net.trainable_params(), learning_rate=0.01)
+    model = Model(net, optimizer=opt)
+    model.train(2, dataset, dataset_sink_mode=False)
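Note: the new unit test builds a semi_auto_parallel net whose body loops a sharded MatMul cell three times before a sharded Concat, so compiling it exercises exactly the while-graph paths added above (IsWhileGraph, the Return/Switch hop in RedistributionNextNode). Assuming the usual MindSpore UT workflow, it would be run with pytest, e.g. `pytest tests/ut/python/parallel/test_parallel_while.py`; the expectation is successful graph compilation rather than any numerical check.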