From db5d508356a771eac4e8848cbb1e853d0d9a2400 Mon Sep 17 00:00:00 2001 From: lichenever Date: Sun, 16 May 2021 14:55:47 +0800 Subject: [PATCH] pipeline_split_adapt_master --- .../insert_tensor_move_for_hccl_op.cc | 4 + .../ccsrc/backend/session/ascend_session.cc | 88 --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 8 +- mindspore/ccsrc/frontend/optimizer/irpass.h | 6 +- .../optimizer/irpass/special_op_eliminate.h | 57 ++ .../graph_util/pipeline_split_utils.cc | 634 ++++++++++++++++++ .../graph_util/pipeline_split_utils.h | 68 ++ .../parallel/ops_info/operator_info.cc | 26 +- .../parallel/ops_info/operator_info.h | 1 + .../frontend/parallel/ops_info/ops_utils.h | 19 +- .../pipeline_transformer.cc | 569 +++++++++------- .../pipeline_transformer.h | 29 +- .../ccsrc/frontend/parallel/step_parallel.cc | 108 +-- .../ccsrc/frontend/parallel/step_parallel.h | 6 +- mindspore/ccsrc/pipeline/jit/pass.cc | 4 +- .../ccsrc/pipeline/jit/pipeline_split.cc | 8 +- mindspore/core/base/core_ops.h | 3 + mindspore/nn/wrap/cell_wrapper.py | 90 +++ mindspore/ops/_grad/grad_comm_ops.py | 90 ++- mindspore/ops/operations/__init__.py | 4 +- mindspore/ops/operations/_inner_ops.py | 6 +- mindspore/ops/operations/comm_ops.py | 66 ++ mindspore/train/amp.py | 9 +- .../ut/python/parallel/test_pipeline_split.py | 19 +- 24 files changed, 1505 insertions(+), 417 deletions(-) create mode 100755 mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc create mode 100755 mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc index e615359097b..d1dd855942b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc @@ -63,6 +63,10 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(cur_node); + if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) { + return false; + } + // when input is a parameter or is a value node if (IsParameterOrValueNode(input)) { return true; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 8422d133b4a..0f30f2fed77 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -192,90 +192,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr } } -bool IsBackward(const CNodePtr &cnode) { - auto prim = GetValueNode(cnode->input(0)); - return prim->HasAttr(BACKWARD); -} - -// compare the value of send/recv sr_tag -bool comp(const CNodePtr &node1, const CNodePtr &node2) { - auto prim1 = GetValueNode(node1->input(0)); - MS_EXCEPTION_IF_NULL(prim1); - auto prim2 = GetValueNode(node1->input(0)); - MS_EXCEPTION_IF_NULL(prim2); - auto sr_tag_value1 = prim1->GetAttr(SR_TAG); - MS_EXCEPTION_IF_NULL(sr_tag_value1); - auto sr_tag_value2 = prim2->GetAttr(SR_TAG); - MS_EXCEPTION_IF_NULL(sr_tag_value2); - auto sr_tag1 = GetValue(sr_tag_value1); - auto sr_tag2 = GetValue(sr_tag_value2); - return sr_tag1 < sr_tag2; -} - -// Reorder the execution order of send -void ReorderSend(std::vector *execution_order, std::vector op_v) { - auto last_node = op_v.back(); - for (auto &node : op_v) { - if (node == last_node) { - continue; - } - auto iter = std::find(execution_order->begin(), execution_order->end(), node); - (void)execution_order->erase(iter); - } - std::sort(op_v.begin(), op_v.end(), comp); - auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node); - auto node_iter = execution_order->erase(last_node_iter); - // all send will insert the end of the last node - execution_order->insert(node_iter, op_v.begin(), op_v.end()); -} - -// Reorder the execution order of receive -void ReorderRecv(std::vector *execution_order, std::vector op_v) { - auto begin_node = op_v.front(); - for (auto &node : op_v) { - if (node == begin_node) { - continue; - } - auto iter = std::find(execution_order->begin(), execution_order->end(), node); - (void)execution_order->erase(iter); - } - std::sort(op_v.begin(), op_v.end(), comp); - auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node); - auto node_iter = execution_order->erase(begin_node_iter); - // all receive will insert before the begin node - execution_order->insert(node_iter, op_v.begin(), op_v.end()); -} - -void ReorderSendRecv(std::vector *execution_order) { - std::vector forward_send, forward_recv, backward_send, backward_recv; - for (auto &cnode : *execution_order) { - if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) { - backward_send.push_back(cnode); - continue; - } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) { - forward_send.push_back(cnode); - continue; - } - if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) { - backward_recv.push_back(cnode); - } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) { - forward_recv.push_back(cnode); - } - } - if (!forward_send.empty()) { - ReorderSend(execution_order, forward_send); - } - if (!backward_send.empty()) { - ReorderSend(execution_order, backward_send); - } - if (!forward_recv.empty()) { - ReorderRecv(execution_order, forward_recv); - } - if (!backward_recv.empty()) { - ReorderRecv(execution_order, backward_recv); - } -} - size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Load kInputCtrlTensors"; @@ -510,10 +426,6 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { // adjust kernel AdjustKernel(root_graph); - // reorder send/recv - auto execution_order = root_graph->execution_order(); - ReorderSendRecv(&execution_order); - root_graph->set_execution_order(execution_order); #if ENABLE_CPU && ENABLE_D InitPsWorker(root_graph); #endif diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 7f43ce8f048..37403d28045 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -206,8 +206,14 @@ OptimizeIRPassLib::OptimizeIRPassLib() { virtual_output_eliminate_ = MakeSubstitution(std::make_shared(), "virtual_output_eliminate", prim::kPrimVirtualOutput); - // Receive + // PipelineSplit receive_eliminate_ = MakeSubstitution(std::make_shared(), "receive_eliminate", prim::kPrimReceive); + virtual_accu_grad_ = + MakeSubstitution(std::make_shared(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad); + virtual_assign_add_ = + MakeSubstitution(std::make_shared(), "virtual_assign_add", prim::kPrimVirtualAssignAdd); + mirror_micro_step_ = + MakeSubstitution(std::make_shared(), "mirror_micro_step", prim::kPrimMirrorMicroStep); // Convert print_tuple_wrapper_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 7e5b6384585..7bf375fb50b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -120,8 +120,12 @@ class OptimizeIRPassLib { // virtual output SubstitutionPtr virtual_output_eliminate_; - // Receive + + // PipelineSplit SubstitutionPtr receive_eliminate_; + SubstitutionPtr virtual_accu_grad_; + SubstitutionPtr virtual_assign_add_; + SubstitutionPtr mirror_micro_step_; // Convert SubstitutionPtr print_tuple_wrapper_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index c1e7212508e..37d9a64bb4c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -136,6 +136,63 @@ class ReceiveEliminater : public AnfVisitor { void Visit(const AnfNodePtr &) override {} }; +class VirtualAssignAddEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + + return inputs[1]; + } + + private: + AnfNodePtr x_{nullptr}; +}; + +class VirtualAccuGradEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + + return inputs[1]; + } + + private: + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMirrorMicroStep, X, Z} -> X +class MirrorMicroStepEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + + return inputs[1]; + } + + void Visit(const AnfNodePtr &) override {} +}; + // {prim::kPrimSameTypeShape, X, Y} -> X class SameEliminater : public AnfVisitor { public: diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc new file mode 100755 index 00000000000..5b6afa97231 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -0,0 +1,634 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "frontend/parallel/graph_util/pipeline_split_utils.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "base/core_ops.h" +#include "ir/value.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/step_parallel.h" + +namespace mindspore { +namespace parallel { +AnfNodePtr FindAccuGrad(const CNodePtr &cnode) { + auto pre_node = cnode->input(1); + while (true) { + if (pre_node->isa()) { + return pre_node; + } else { + if (pre_node->isa()) { + auto pre_cnode = pre_node->cast(); + pre_node = pre_cnode->input(1); + } else { + return nullptr; + } + } + } + return nullptr; +} + +bool IsLastStage() { + MS_EXCEPTION_IF_NULL(g_device_manager); + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + return ((stage_num - 1) == stage_id); +} + +void SetStridedSliceStrategy(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) { + return; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + PrimitivePtr prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + int64_t dev_num = 1; + auto attrs_temp = prim->attrs(); + std::vector shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape"; + } + std::vector elements; + for (size_t i = 0; i < shape_list[0].size(); i++) { + if (shape_list[0][i].empty()) { + MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; + } + Dimensions input_strategy = {dev_num}; + for (size_t j = 1; j < shape_list[0][i].size(); j++) { + input_strategy.push_back(1); + } + elements.push_back(MakeValue(input_strategy)); + } + ValueTuplePtr strategy = std::make_shared(elements); + attrs_temp[STRATEGY] = strategy; + (void)prim->SetAttrs(attrs_temp); +} + +void InsertVirtualAssignAdd(const std::pair &node_user, const FuncGraphManagerPtr &manager, + const AnfNodePtr &accu_parameter) { + auto cnode = node_user.first->cast(); + if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() || + ((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) && + ParallelContext::GetInstance()->enable_parallel_optimizer())) { + return; + } + auto prim = GetCNodePrimitive(cnode); + if (prim == nullptr) { + MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd"; + return; + } + OperatorAttrs attrs; + auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD); + auto value_node = NewValueNode(py_instance); + std::vector virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter}; + auto graph = cnode->func_graph(); + auto virtual_node = graph->NewCNode(virtual_node_input); + manager->SetEdge(cnode, node_user.second, virtual_node); +} + +void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr ¶m) { + auto cnode = recv->cast(); + MS_EXCEPTION_IF_NULL(cnode); + OperatorAttrs attrs; + auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD); + auto value_node = NewValueNode(py_instance); + std::vector virtual_node_input = {value_node, recv, param}; + auto graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + auto virtual_node = graph->NewCNode(virtual_node_input); + manager->Replace(recv, virtual_node); +} + +AnfNodePtr FindGradAccuParameter(const std::vector ¶meters, const std::string &name) { + for (auto ¶meter : parameters) { + auto param_ptr = parameter->cast(); + MS_EXCEPTION_IF_NULL(param_ptr); + if (param_ptr->name() == name) { + continue; + } + auto expect_name = "accu_grads." + name; + if (param_ptr->name() == expect_name) { + return parameter; + } + } + return nullptr; +} + +void HandleReceiveParam(const FuncGraphPtr &root, const std::vector &all_nodes) { + auto parameters = root->parameters(); + auto node_users_map = root->manager()->node_users(); + for (auto &node : all_nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimReceive)) { + continue; + } + auto cnode = node->cast(); + if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) { + continue; + } + auto parameter_ptr = cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(parameter_ptr); + auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name()); + if (!accu_parameter) { + continue; + } + auto node_users = node_users_map[node]; + for (auto &temp_user : node_users) { + auto temp_node = temp_user.first; + if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { + temp_node = node_users_map[temp_node].begin()->first; + } + if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { + auto node_set = node_users_map[temp_node]; + for (auto &node_user : node_set) { + InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); + } + } else { + InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter); + } + } + InsertVirtualAccuGrad(node, root->manager(), accu_parameter); + } +} + +void AddVirtualAssignAdd(const FuncGraphPtr &root) { + auto parameters = root->parameters(); + auto node_users_map = root->manager()->node_users(); + for (auto ¶meter : parameters) { + auto parameter_ptr = parameter->cast(); + auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name()); + if (!accu_parameter) { + continue; + } + auto node_users = node_users_map[parameter]; + for (auto &temp_user : node_users) { + auto temp_node = temp_user.first; + if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { + temp_node = node_users_map[temp_node].begin()->first; + } + if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { + auto node_set = node_users_map[temp_node]; + for (auto &node_user : node_set) { + InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); + } + } else { + InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter); + } + } + } +} + +bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + auto cnode1 = node1->cast(); + auto cnode2 = node2->cast(); + MS_EXCEPTION_IF_NULL(cnode1); + MS_EXCEPTION_IF_NULL(cnode2); + auto micro1 = cnode1->GetPrimalAttr(MICRO); + auto micro2 = cnode2->GetPrimalAttr(MICRO); + MS_EXCEPTION_IF_NULL(micro1); + MS_EXCEPTION_IF_NULL(micro2); + auto micro1_value = GetValue(micro1); + auto micro2_value = GetValue(micro2); + if (micro1_value == micro2_value) { + auto prim1 = GetCNodePrimitive(cnode1); + auto prim2 = GetCNodePrimitive(cnode2); + MS_EXCEPTION_IF_NULL(prim1); + MS_EXCEPTION_IF_NULL(prim2); + auto rank_tag1 = prim1->GetAttr(SRC_RANK); + auto rank_tag2 = prim2->GetAttr(SRC_RANK); + if (rank_tag1 == nullptr) { + rank_tag1 = prim1->GetAttr(DEST_RANK); + } + if (rank_tag2 == nullptr) { + rank_tag2 = prim2->GetAttr(DEST_RANK); + } + MS_EXCEPTION_IF_NULL(rank_tag1); + MS_EXCEPTION_IF_NULL(rank_tag2); + auto rank1_value = GetValue(rank_tag1); + auto rank2_value = GetValue(rank_tag2); + if (rank1_value == rank2_value) { + auto sr_tag1 = prim1->GetAttr(SR_TAG); + auto sr_tag2 = prim2->GetAttr(SR_TAG); + MS_EXCEPTION_IF_NULL(sr_tag1); + MS_EXCEPTION_IF_NULL(sr_tag2); + auto sr1_value = GetValue(sr_tag1); + auto sr2_value = GetValue(sr_tag2); + return sr1_value < sr2_value; + } + return rank1_value < rank2_value; + } + return micro1_value < micro2_value; +} + +void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager, + const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(post_node); + auto post_cnode = post_node->cast(); + MS_EXCEPTION_IF_NULL(post_cnode); + std::vector depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node}; + auto depend_node = root->NewCNode(depend_input); + manager->SetEdge(post_node, 1, depend_node); +} + +void ReorderForForward(const std::vector &forward_start, const std::vector &forward_end, + const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(g_device_manager); + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) { + auto prior_node = forward_end[i - 1]; + auto post_node = forward_start[i]; + InsertDepend(prior_node, post_node, manager, root); + } +} + +void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, + const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, + const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(g_device_manager); + MS_EXCEPTION_IF_NULL(root); + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) { + auto prior_node1 = forward_end_before_pair.second[i]; + auto post_node1 = backward_start_pair.first[i - stage_num + stage_id + 1]; + InsertDepend(prior_node1, post_node1, manager, root); + auto prior_node2 = backward_end_pair.second[i - stage_num + stage_id]; + auto post_node2 = forward_start_pair.first[i]; + InsertDepend(prior_node2, post_node2, manager, root); + } + for (size_t i = (stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) { + if (!IsLastStage()) { + auto prior_node3 = backward_start_pair.second[i - stage_num + stage_id]; + auto post_node3 = forward_end_pair.first[i - 1]; + InsertDepend(prior_node3, post_node3, manager, root); + auto prior_node4 = forward_end_pair.second[i - 1]; + auto post_node4 = backward_end_pair.first[i - stage_num + stage_id]; + InsertDepend(prior_node4, post_node4, manager, root); + } + } + for (size_t j = (backward_start_pair.first.size() - stage_num + stage_id + 1); j < backward_start_pair.first.size(); + ++j) { + auto prior_node5 = backward_end_pair.second[j - 1]; + auto post_node5 = backward_start_pair.first[j]; + InsertDepend(prior_node5, post_node5, manager, root); + } + if (!IsLastStage()) { + auto prior_node6 = forward_end_before_pair.second[stage_num - 1 - stage_id]; + auto post_node6 = backward_start_pair.first[0]; + InsertDepend(prior_node6, post_node6, manager, root); + } +} + +void ReorderForParams(const std::vector &backward_params, const std::vector &forward_params, + const std::vector &allreduce_params, const PipelinePair &forward_params_pair, + const PipelinePair &backward_params_pair, const std::vector &backward_end, + const PipelinePair &forward_start_pair, const FuncGraphPtr &root) { + auto manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (!forward_params.empty()) { + auto prior_node = forward_params_pair.second[0]; + auto post_node = forward_start_pair.first[0]; + InsertDepend(prior_node, post_node, manager, root); + } + if (!backward_params.empty()) { + if (!allreduce_params.empty()) { + for (auto &node : allreduce_params) { + auto post_node1 = backward_params_pair.first[0]; + InsertDepend(node, post_node1, manager, root); + } + } + auto prior_node2 = backward_end.back(); + auto post_node2 = backward_params[0]; + InsertDepend(prior_node2, post_node2, manager, root); + } +} + +int64_t GetMicroBatch(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto micro_value = cnode->GetPrimalAttr(MICRO); + MS_EXCEPTION_IF_NULL(micro_value); + return GetValue(micro_value); +} + +PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max) { + std::vector temp_vec; + std::vector out_vec_begin; + std::vector out_vec_end; + auto manager = root->manager(); + for (int64_t i = 0; i <= micro_max; ++i) { + temp_vec.clear(); + for (auto &node : node_vector) { + auto node_micro = GetMicroBatch(node); + if (node_micro == i) { + temp_vec.push_back(node); + } + } + if (temp_vec.size() <= 1) { + MS_LOG(INFO) << "No Duplicate MicroBatch."; + continue; + } + std::sort(temp_vec.begin(), temp_vec.end(), CompFunc); + for (size_t j = 0; j < temp_vec.size() - 1; ++j) { + auto prior_node = temp_vec[j]; + auto post_node = temp_vec[j + 1]; + InsertDepend(prior_node, post_node, manager, root); + } + if (!temp_vec.empty()) { + out_vec_begin.push_back(temp_vec.front()); + out_vec_end.push_back(temp_vec.back()); + } + } + if (out_vec_begin.empty()) { + return std::make_pair(node_vector, node_vector); + } + return std::make_pair(out_vec_begin, out_vec_end); +} + +void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value) { + auto node_users = (*node_users_map)[node]; + for (auto &node_pair : node_users) { + auto user_node = node_pair.first->cast(); + if (user_node->HasPrimalAttr(MICRO)) { + continue; + } + user_node->AddPrimalAttr(MICRO, value); + BroadCastMicroBatch(user_node, node_users_map, value); + } +} + +AnfNodePtr GetPreNode(const AnfNodePtr &node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + return GetPreNode(cnode->input(1)); + } + return cnode; +} + +void LastStageEndNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { + if (!IsLastStage()) { + return; + } + auto node_users_map = manager->node_users(); + for (auto &node : all_nodes) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!cnode->HasPrimalAttr(MICRO)) { + continue; + } + auto prim = GetCNodePrimitive(node); + if (prim && prim->HasAttr(PIPELINE_END)) { + for (auto &temp_node : cnode->inputs()) { + if (!temp_node->isa()) { + continue; + } + auto temp_cnode = temp_node->cast(); + auto temp_prim = GetCNodePrimitive(temp_node); + if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) { + continue; + } + auto end_node = GetPreNode(temp_node); + auto end_cnode = end_node->cast(); + MS_EXCEPTION_IF_NULL(end_cnode); + auto end_prim = GetCNodePrimitive(end_node); + OperatorAttrs attrs_; + auto op = CreatOpInstance(attrs_, end_prim->name(), ""); + auto value_node = NewValueNode(op); + auto new_prim = GetValueNode(value_node)->cast(); + new_prim->SetAttrs(end_prim->attrs()); + manager->SetEdge(end_node, 0, value_node); + end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO)); + } + } + } +} + +void ParameterStartNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { + auto node_users_map = manager->node_users(); + for (auto &node : all_nodes) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!cnode->HasPrimalAttr(MICRO)) { + continue; + } + auto micro = cnode->GetPrimalAttr(MICRO); + auto prim = GetCNodePrimitive(node); + if (prim && prim->HasAttr(PARAMETER_START)) { + OperatorAttrs attrs_; + auto op = CreatOpInstance(attrs_, prim->name(), ""); + auto value_node = NewValueNode(op); + auto new_prim = GetValueNode(value_node)->cast(); + new_prim->SetAttrs(prim->attrs()); + manager->SetEdge(cnode, 0, value_node); + cnode->AddPrimalAttr(PARAMETER_START, micro); + } + } +} + +void HandleMicroBatch(const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { + auto node_users_map = manager->node_users(); + for (auto &node : all_nodes) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!cnode->HasPrimalAttr(MICRO)) { + continue; + } + auto micro = cnode->GetPrimalAttr(MICRO); + MS_EXCEPTION_IF_NULL(micro); + BroadCastMicroBatch(cnode, &node_users_map, micro); + } +} + +void GetBorderNode(std::vector *forward_start, std::vector *forward_end, + std::vector *backward_start, std::vector *backward_end, + std::vector *forward_params, std::vector *backward_params, + std::vector *allreduce_params, const FuncGraphPtr &root) { + std::list name_list = {}; + auto stage_id = g_device_manager->stage_id(); + for (auto &node : root->nodes()) { + if (!node->isa()) { + continue; + } + if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimZerosLike)) { + continue; + } + auto prim = GetCNodePrimitive(node); + auto cnode = node->cast(); + if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) { + auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName); + if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) { + continue; + } + name_list.push_back(forward_node_name); + if (cnode->HasPrimalAttr(PIPELINE_END)) { + backward_start->push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) { + backward_end->push_back(node); + } + if (cnode->HasPrimalAttr(PARAMETER_START)) { + backward_end->push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_PARAM)) { + backward_params->push_back(node); + } + if (prim->HasAttr(PARAMETER_MICRO)) { + allreduce_params->push_back(node); + } + } else { + if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) { + if (stage_id != 0 && IsPrimitiveCNode(node, prim::kPrimStridedSlice)) { + continue; + } + forward_start->push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_END)) { + forward_end->push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_PARAM)) { + forward_params->push_back(node); + } + } + } + std::sort((*backward_start).begin(), (*backward_start).end(), CompFunc); + std::sort((*backward_end).begin(), (*backward_end).end(), CompFunc); + std::sort((*forward_start).begin(), (*forward_start).end(), CompFunc); + std::sort((*forward_end).begin(), (*forward_end).end(), CompFunc); + std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc); + std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc); +} + +void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, + const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, + size_t micro_size) { + micro_size = micro_size + 1; + if (forward_start_pair.first.size() != micro_size) { + MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size() + << "is not equal to micro size:" << micro_size; + } + if (forward_end_pair.first.size() != micro_size) { + MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size() + << "is not equal to micro size:" << micro_size; + } + if (backward_start_pair.first.size() != micro_size) { + MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size() + << "is not equal to micro size:" << micro_size; + } + if (backward_end_pair.first.size() != micro_size) { + MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size() + << "is not equal to micro size:" << micro_size; + } +} + +void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + std::vector forward_start; + std::vector forward_end; + std::vector forward_params; + std::vector backward_start; + std::vector backward_end; + std::vector backward_params; + std::vector allreduce_params; + GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params, + &allreduce_params, root); + auto forward_end_cnode = forward_end.back()->cast(); + auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO); + MS_EXCEPTION_IF_NULL(micro_size); + auto micro_max = GetValue(micro_size); + auto backward_start_pair = Deduplicate(backward_start, root, micro_max); + auto backward_end_pair = Deduplicate(backward_end, root, micro_max); + auto forward_start_pair = Deduplicate(forward_start, root, micro_max); + auto forward_end_pair = Deduplicate(forward_end, root, micro_max); + auto forward_params_pair = Deduplicate(forward_params, root, micro_max); + auto backward_params_pair = Deduplicate(backward_params, root, micro_max); + CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, LongToSize(micro_max)); + PipelinePair forward_end_before_pair; + if (!IsLastStage()) { + for (auto &node : forward_end_pair.first) { + auto cnode = node->cast(); + forward_end_before_pair.first.push_back(cnode->input(1)); + } + for (auto &node : forward_end_pair.second) { + auto cnode = node->cast(); + forward_end_before_pair.second.push_back(cnode->input(1)); + } + } else { + forward_end_before_pair = forward_end_pair; + } + ReorderForForward(forward_start_pair.first, forward_end_pair.second, root); + ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, + forward_end_before_pair, root); + ReorderForParams(backward_params, forward_params, allreduce_params, forward_params_pair, backward_params_pair, + backward_end, forward_start_pair, root); +} + +void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + std::vector forward_end; + std::vector forward_start; + std::vector forward_params; + for (auto &node : root->nodes()) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) { + forward_start.push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_END)) { + forward_end.push_back(node); + } + if (cnode->HasPrimalAttr(PIPELINE_PARAM)) { + forward_params.push_back(node); + } + } + std::sort(forward_start.begin(), forward_start.end(), CompFunc); + std::sort(forward_end.begin(), forward_end.end(), CompFunc); + std::sort(forward_params.begin(), forward_params.end(), CompFunc); + auto forward_start_pair = Deduplicate(forward_start, root, 0); + auto forward_end_pair = Deduplicate(forward_end, root, 0); + auto forward_params_pair = Deduplicate(forward_params, root, 0); + if (!forward_end.empty() && !forward_params.empty()) { + InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root); + } + if (!forward_start.empty() && !forward_params.empty()) { + InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root); + } +} + +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h new file mode 100755 index 00000000000..e7cad3f9323 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" + +namespace mindspore { +namespace parallel { +using PipelinePair = std::pair, std::vector>; +AnfNodePtr FindAccuGrad(const CNodePtr &cnode); +bool IsLastStage(); +void InsertVirtualAssignAdd(const std::pair &node_user, const FuncGraphManagerPtr &manager, + const AnfNodePtr &accu_parameter); +void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr ¶m); +AnfNodePtr FindGradAccuParameter(const std::vector ¶meters, const std::string &name); +void HandleReceiveParam(const FuncGraphPtr &root, const std::vector &all_nodes); +void AddVirtualAssignAdd(const FuncGraphPtr &root); +bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2); +void ReorderForForward(const std::vector &forward_start, const std::vector &forward_end, + const FuncGraphPtr &root); +void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, + const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, + const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root); +void ReorderForParams(const std::vector &backward_params, const std::vector &forward_params, + const std::vector &allreduce_params, const PipelinePair &forward_params_pair, + const PipelinePair &backward_params_pair, const std::vector &backward_end, + const PipelinePair &forward_start_pair, const FuncGraphPtr &root); +int64_t GetMicroBatch(const AnfNodePtr &node); +void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager, + const FuncGraphPtr &root); +PipelinePair Deduplicate(const std::vector &node_vector, const FuncGraphPtr &root, int64_t micro_max); +void GetBorderNode(std::vector *forward_start, std::vector *forward_end, + std::vector *backward_start, std::vector *backward_end, + std::vector *forward_params, std::vector *backward_params, + std::vector *allreduce_params, const FuncGraphPtr &root); +void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); +void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); +void HandleMicroBatch(const std::vector &all_nodes, const FuncGraphManagerPtr &manager); +void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value); +AnfNodePtr GetPreNode(const AnfNodePtr &node); +void LastStageEndNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager); +void SetStridedSliceStrategy(const AnfNodePtr &node); +void ParameterStartNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager); +void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, + const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, size_t micro_size); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index b1138de97f4..65bad6d7490 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -342,11 +342,12 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { MS_EXCEPTION_IF_NULL(comm_node); MS_EXCEPTION_IF_NULL(param_node); + ParameterPtr param; if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { - MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; - return; + param = param_node->user_data(PIPELINE_PARAM)->cast(); + } else { + param = param_node->cast(); } - auto param = param_node->cast(); MS_EXCEPTION_IF_NULL(param); auto prim = GetValueNode(comm_node->input(0)); MS_EXCEPTION_IF_NULL(prim); @@ -372,6 +373,22 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) { prim->SetAttrs(attrs); } +void AddCommOpParamFlag(const CNodePtr &comm_node) { + MS_EXCEPTION_IF_NULL(comm_node); + auto graph = comm_node->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_users = manager->node_users()[comm_node->input(1)]; + for (auto &node_user : node_users) { + if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) { + auto prim = GetCNodePrimitive(comm_node); + prim->AddAttr(PARAMETER_MICRO, MakeValue(0)); + return; + } + } +} + Operator CreateAllGatherOp(const std::string &group) { OperatorName operator_name = ALL_GATHER; ValuePtr attr0_value = MakeValue(group); // group @@ -438,6 +455,7 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { OperatorVector op_for_weight; bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); ValuePtr attr0_value = MakeValue(group_name); ValuePtr attr1_value = MakeValue(SizeToLong(dev_num)); @@ -459,6 +477,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value); operator_attrs.push_back(attr3); MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror"; + } else if (split_stage_num > 1) { + operator_name = MIRROR_MICRO_STEP_OPERATOR; } else { operator_name = MIRROR_OPERATOR; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index fc4ef7e201c..cd05a9670cb 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -294,6 +294,7 @@ Operator CreateAllGatherOp(const std::string &group); Operator CreateMiniStepAllGatherOp(const std::string &group); void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); void AddCommOpMeanFlag(const CNodePtr &comm_node); +void AddCommOpParamFlag(const CNodePtr &comm_node); Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 0e9b8518524..a343af6889a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -168,7 +168,7 @@ constexpr char CLONED_INDEX[] = "cloned_index"; constexpr char BE_CLONED_INDEX[] = "be_cloned_index"; constexpr char GROUP_RANKS[] = "group_ranks"; constexpr char IS_IN_FORWARD[] = "is_in_forward"; -constexpr char DTYPE[] = "DType"; +constexpr char DTYPE[] = "dtype"; constexpr char DEV_NUM[] = "dev_num"; constexpr char MEAN_FLAG[] = "mean_flag"; constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step"; @@ -348,6 +348,23 @@ constexpr char UNIQUE[] = "Unique"; constexpr char GATHERND[] = "GatherNd"; constexpr char SCATTER_UPDATE[] = "ScatterUpdate"; +// pipeline +constexpr char MICRO[] = "micro"; +constexpr char DEST_RANK[] = "dest_rank"; +constexpr char SRC_RANK[] = "src_rank"; +constexpr char PIPELINE_PARAM[] = "pipeline_param"; +constexpr char PIPELINE_END[] = "pipeline_end"; +constexpr char PIPELINE_BEGIN[] = "pipeline_begin"; +constexpr char MAIN_GRAPH[] = "main_graph"; +constexpr char SR_TAG[] = "sr_tag"; +constexpr char GROUP_BACK[] = "group_back"; +constexpr char MIRROR_MICRO_STEP_OPERATOR[] = "_MirrorMicroStepOperator"; +constexpr char PARAMETER_MICRO[] = "parameter_micro"; +constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd"; +constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad"; +constexpr char ACCU_GRAD[] = "accu_grad"; +constexpr char PARAMETER_START[] = "parameter_start"; + // Parallel don't care constexpr char STRING_EQUAL[] = "string_equal"; constexpr char MAKE_TUPLE[] = "MakeTuple"; diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 72a53e51ab8..e7d5cd75279 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -29,6 +29,7 @@ #include "frontend/parallel/step_parallel.h" #include "frontend/parallel/node_check.h" #include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" #include "ir/anf.h" #include "base/core_ops.h" #include "utils/comm_manager.h" @@ -52,30 +53,74 @@ static bool IsInWhiteList(const CNodePtr &cnode) { return false; } -static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) { - if (curr_depth > MAX_RECURSIVE_DEPTH) { - MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: " - << MAX_RECURSIVE_DEPTH; +void PipelineTransformer::MainGraph() { + if (!root_->has_flag(TRAINING)) { + main_graph_ = root_; return; } - const auto &node_users = manager->node_users()[node]; - for (auto &user_pair : node_users) { - auto user_node = user_pair.first; - if (!user_node->grad()) { - user_node->set_grad(true); - SetGradTag(user_node, manager, ++curr_depth); + for (auto &fg : manager_->func_graphs()) { + for (auto &node : fg->nodes()) { + if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) { + main_graph_ = fg; + main_graph_->set_flag(MAIN_GRAPH, true); + virtual_dataset_ = node; + return; + } + } + } + MS_LOG(EXCEPTION) << "Can't find main graph, possible reason is can't find virtual dataset."; +} + +ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size) { + if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) { + MS_LOG(EXCEPTION) << "Can't find MicroBatch information."; + } + auto cnode = node->cast(); + auto value = GetValueNode(cnode->input(2)); + MS_EXCEPTION_IF_NULL(value); + auto tuple = GetValue>(value); + auto input_shape = GetNodeShape(cnode->input(1)).at(0); + int64_t micro = tuple.at(0) * micro_size / input_shape.at(0); + cnode->AddPrimalAttr(MICRO, MakeValue(micro)); + cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro)); + return MakeValue(micro); +} + +void PipelineTransformer::LabelMicroBatch() { + MS_EXCEPTION_IF_NULL(main_graph_); + MS_EXCEPTION_IF_NULL(virtual_dataset_); + auto node_user_map = manager_->node_users(); + auto node_users = node_user_map[virtual_dataset_]; + for (auto &node_user : node_users) { + if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { + auto data_users = manager_->node_users()[node_user.first]; + auto micro_size = int64_t(data_users.size()); + micro_size_ = micro_size; + MS_LOG(INFO) << "Micro Size is: " << micro_size; + for (auto &data_user : data_users) { + auto micro = SetMicroBatch(data_user.first, micro_size); + SetStridedSliceStrategy(data_user.first); + auto cnode = data_user.first->cast(); + BroadCastMicroBatch(cnode, &node_user_map, micro); + } } } } -void PipelineTransformer::LabelRequiredGradCNode() { - auto parameters = root_->parameters(); - for (auto parameter : parameters) { - if (!ParameterRequireGrad(parameter)) { - continue; - } - SetGradTag(parameter, manager_, 0); +void PipelineTransformer::CreateForwardGroup() { + std::vector rank_list; + auto rank_id = g_device_manager->global_rank(); + auto stage_id = g_device_manager->stage_id(); + auto stage_num = g_device_manager->stage_num(); + for (int64_t i = 0; i < stage_num; ++i) { + rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id)); } + auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list); + auto g = g_device_manager->CreateGroup(rank_list); + auto g_back_name = g.name() + BACKWARD; + auto g_back = g_device_manager->CreateGroup(g_back_name, dev_list); + group_.push_back(g.name()); + group_.push_back(g_back.name()); } void PipelineTransformer::Coloring() { @@ -84,7 +129,7 @@ void PipelineTransformer::Coloring() { while (need_coloring) { need_coloring = false; for (auto &fg : manager_->func_graphs()) { - if (fg == root_) { + if (fg == root_ && root_->has_flag(TRAINING)) { continue; } auto value_nodes = fg->value_nodes(); @@ -94,16 +139,15 @@ void PipelineTransformer::Coloring() { continue; } auto graph = GetValueNode(node); - auto need_grad = graph->get_return()->grad(); + if (graph->stage() == -1) { + continue; + } + stage_set.insert(graph->stage()); auto node_users = manager_->node_users()[node]; for (auto &user_pair : node_users) { auto user_node = user_pair.first->cast(); user_node->set_stage(graph->stage()); - user_node->set_grad(need_grad); auto user_node_graph = user_node->func_graph(); - if (graph->stage() != -1) { - stage_set.insert(graph->stage()); - } if (graph->stage() == stage_ && user_node_graph->stage() == -1) { user_node_graph->set_stage(graph->stage()); need_coloring = true; @@ -117,22 +161,37 @@ void PipelineTransformer::Coloring() { if (SizeToLong(stage_set.size()) != stage_num) { MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size(); } - return; } void PipelineTransformer::BroadCastColoring() { - for (auto &fg : manager_->func_graphs()) { - if (fg == root_ || fg->stage() == -1) { - continue; + auto need_coloring = true; + while (need_coloring) { + need_coloring = false; + auto all_nodes = main_graph_->nodes(); + auto node_users = manager_->node_users(); + for (auto &node : all_nodes) { + if (!node->isa() || node->stage() == -1) { + continue; + } + auto stage = node->stage(); + for (auto &user_pair : node_users[node]) { + auto user_node = user_pair.first->cast(); + auto user_node_stage = user_node->stage(); + if (stage > user_node_stage) { + user_node->set_stage(stage); + need_coloring = true; + } + } } - DoBroadCast(fg); - SetNoStageNode(fg); } } bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); auto prim = GetValueNode(cnode->input(0)); + if (!prim) { + return false; + } if (IsInWhiteList(cnode)) { return false; } @@ -148,6 +207,9 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { if (!IsPipelineCareNode(cnode)) { MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node."; } + if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) { + SetVirtualDatasetStrategy(cnode); + } auto shape_list = ExtractShape(cnode); if (shape_list.empty()) { MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape."; @@ -155,7 +217,7 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { auto prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); if (prim->name() == RESHAPE) { - MS_LOG(EXCEPTION) << "Reshape op can't be a border."; + MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << cnode->DebugString(); } auto attrs = prim->attrs(); auto op_info = OperatorInstance(prim, attrs, shape_list); @@ -190,93 +252,87 @@ std::pair PipelineTransformer::GetOpInfo(const A MS_EXCEPTION_IF_NULL(cnode); // Handle Cast and TupleGetitem situation size_t tensor_info_index = 0; - if (IsPrimitiveCNode(cnode, prim::kPrimCast)) { - cnode = cnode->input(1)->cast(); - } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { - tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode)); - cnode = cnode->input(1)->cast(); + OperatorInfoPtr op_info; + if (IsPrimitiveCNode(node, prim::kPrimReceive)) { + op_info = node->user_data(); + } else { + if (IsPrimitiveCNode(node, prim::kPrimCast)) { + cnode = cnode->input(1)->cast(); + } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { + tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode)); + cnode = cnode->input(1)->cast(); + } + // Create OperatorInfo to get slice_shape for send/recv + MS_EXCEPTION_IF_NULL(cnode); + op_info = CreateOpInfo(cnode); } - // Create OperatorInfo to get slice_shape for send/recv - MS_EXCEPTION_IF_NULL(cnode); - auto op_info = CreateOpInfo(cnode); MS_EXCEPTION_IF_NULL(op_info); auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index]; return std::make_pair(op_info, std::make_shared(tensor_info)); } -CNodePtr PipelineTransformer::HandleMonadLoad(const AnfNodePtr &node) { +std::pair PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto &node_users = manager_->node_users()[node]; + auto node_users_map = manager_->node_users(); + auto node_users = node_users_map[node]; + for (auto &node_user : node_users) { + auto load = node_user.first->cast(); + if (IsPrimitiveCNode(load, prim::kPrimLoad)) { + node_users = node_users_map[load]; + break; + } + } for (auto &user_pair : node_users) { auto user_node = user_pair.first->cast(); MS_EXCEPTION_IF_NULL(user_node); - if (IsPipelineCareNode(user_node)) { - return user_node; + auto user_node_graph = user_node->func_graph(); + MS_EXCEPTION_IF_NULL(user_node_graph); + if (user_node_graph->stage() == -1) { + continue; } - } - return nullptr; -} - -std::pair PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto &node_users = manager_->node_users()[node]; - for (auto &user_pair : node_users) { - auto care_node = user_pair.first; - auto care_cnode = care_node->cast(); - if (IsPrimitiveCNode(care_node, prim::kPrimLoad)) { - care_cnode = HandleMonadLoad(care_node); - if (!care_cnode) { - continue; + auto care_node = user_node; + auto index = user_pair.second; + if (IsValueNode(user_node->input(0))) { + auto graph = GetValueNode(user_node->input(0)); + auto temp_params = graph->parameters(); + if (temp_params.size() < IntToSize(user_pair.second)) { + MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString() + << "'s range."; } - } else { - if (!IsPipelineCareNode(care_cnode)) { - continue; + auto temp_param = temp_params[user_pair.second - 1]; + auto temp_users = node_users_map[temp_param]; + for (auto &temp_user : temp_users) { + auto load_temp = temp_user.first->cast(); + if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) { + temp_users = node_users_map[load_temp]; + break; + } + } + for (auto &temp_pair : temp_users) { + auto temp_cnode = temp_pair.first->cast(); + if (!IsPipelineCareNode(temp_cnode)) { + continue; + } + care_node = temp_cnode; + index = temp_pair.second; + break; } } - MS_EXCEPTION_IF_NULL(care_cnode); - auto op_info = CreateOpInfo(care_cnode); + if (!IsPipelineCareNode(care_node)) { + continue; + } + auto op_info = CreateOpInfo(care_node); MS_EXCEPTION_IF_NULL(op_info); - auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1]; - return std::make_pair(nullptr, std::make_shared(tensor_info)); + auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1]; + return std::make_pair(op_info, std::make_shared(tensor_info)); } return std::make_pair(nullptr, nullptr); } -void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { - auto need_coloring = true; - while (need_coloring) { - need_coloring = false; - auto all_nodes = func->nodes(); - auto &node_users = manager_->node_users(); - for (auto &node : all_nodes) { - if (node->isa() || node->stage() == -1) { - continue; - } - auto stage = node->stage(); - for (auto &user_pair : node_users[node]) { - auto user_node = user_pair.first->cast(); - auto user_node_stage = user_node->stage(); - if (IsValueNode(user_node->input(0)) && stage > user_node_stage) { - user_node->set_stage(stage); - need_coloring = true; - } - } - } - } -} - -void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) { - auto all_nodes = func->nodes(); - for (auto &node : all_nodes) { - if (!node->isa() || node->stage() != -1) { - continue; - } - node->set_stage(0); - } -} - -void PipelineTransformer::HandleSharedParameter() { +std::vector PipelineTransformer::HandleSharedParameter() { auto parameters = root_->parameters(); + std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; + std::vector recvs = {}; for (auto ¶meter : parameters) { auto parameter_stage = parameter_color_map[parameter]; if (parameter_stage.size() <= 1) { @@ -285,37 +341,41 @@ void PipelineTransformer::HandleSharedParameter() { auto users = manager_->node_users()[parameter]; for (auto &user : users) { auto node = user.first; + auto cnode = node->cast(); auto graph = node->func_graph(); - if (graph != root_ && graph->stage() == -1) { - MS_LOG(EXCEPTION) << "Don't support this situation."; + if (IsValueNode(cnode->input(0))) { + graph = GetValueNode(cnode->input(0)); } - if (graph == root_ || graph->stage() != stage_) { + if (graph == root_ || graph->stage() == -1 || !parameter_stage.count(stage_)) { continue; } + auto micro = cnode->GetPrimalAttr(MICRO); + if (!micro) { + MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch"; + micro = MakeValue(int64_t(0)); + } + auto user_stage = node->stage(); if (stage_ == *parameter_stage.begin()) { - std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; - for (auto &stage : parameter_stage) { - if (stage == stage_) { - continue; - } else { - auto send_out = InsertSend(graph, parameter, stage, stage_); - make_tuple_input.push_back(send_out.depend); - } + if (graph->stage() == stage_) { + continue; } - auto make_tuple = graph->NewCNode(make_tuple_input); - OperatorAttrs depend_attrs; - auto depend_op = CreatOpInstance(depend_attrs, DEPEND, ""); - std::vector depend_input = {NewValueNode(depend_op), parameter, make_tuple}; - auto depend = graph->NewCNode(depend_input); - depend->set_abstract(parameter->abstract()); - manager_->SetEdge(node, user.second, depend); - break; + if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) { + continue; + } + auto send_out = InsertSend(main_graph_, parameter, user_stage, stage_, micro); + make_tuple_input.push_back(send_out.depend); } else { - (void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin()); - break; + auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK); + if (receive) { + manager_->SetEdge(node, user.second, receive); + } else { + auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro); + recvs.push_back(recv); + } } } } + return make_tuple_input; } void PipelineTransformer::ParameterColoring() { @@ -324,14 +384,24 @@ void PipelineTransformer::ParameterColoring() { auto users = manager_->node_users()[parameter]; std::set parameter_stage; for (auto &user : users) { - auto node = user.first; + auto node = user.first->cast(); auto graph = node->func_graph(); + if (IsValueNode(node->input(0))) { + graph = GetValueNode(node->input(0)); + } if (graph != root_ && graph->stage() != -1) { parameter_stage.insert(graph->stage()); parameter->set_stage(graph->stage()); } } - if (*parameter_stage.begin() == stage_ && !virtual_param_) { + auto param_info = parameter->cast()->param_info(); + if (!param_info) { + parameter_color_map[parameter] = parameter_stage; + continue; + } + MS_EXCEPTION_IF_NULL(param_info); + auto requires_grad = param_info->requires_grad(); + if (*parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) { virtual_param_ = parameter; } parameter_color_map[parameter] = parameter_stage; @@ -343,8 +413,8 @@ static std::pair GetShapeType(const AnfNodePtr &node, con auto cnode = node->cast(); if (cnode != nullptr && IsValueNode(cnode->input(0))) { auto graph = GetValueNode(cnode->input(0)); - auto graph_return = graph->get_return(); - type = graph_return->Type(); + auto graph_output = graph->output(); + type = graph_output->Type(); } else { type = node->Type(); } @@ -359,40 +429,38 @@ static std::pair GetShapeType(const AnfNodePtr &node, con return std::make_pair(shape_list, dtype); } -AnfNodePtr PipelineTransformer::HandleMonadDepend(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - auto cnode = node->cast(); - return HandleMonadDepend(cnode->input(1)); - } - return node; -} - AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (IsValueNode(cnode->input(0))) { auto graph = GetValueNode(cnode->input(0)); - auto output = HandleMonadDepend(graph->output()); + auto output = graph->output(); MS_EXCEPTION_IF_NULL(output); if (output->isa()) { - return output; + auto parameters = graph->parameters(); + auto pos_iter = std::find(parameters.begin(), parameters.end(), output); + auto pos = std::distance(parameters.begin(), pos_iter); + return FindPipelineCareNode(cnode->input(pos + 1)); } cnode = output->cast(); MS_EXCEPTION_IF_NULL(cnode); } + if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) { + return FindPipelineCareNode(cnode->input(1)); + } if (IsInWhiteList(cnode)) { return cnode->cast(); } if (!IsPipelineCareNode(cnode)) { - MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."; + MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border." + << " border node: " << cnode->DebugString(); } return cnode->cast(); } SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, - int64_t user_node_stage, int64_t node_stage) { + int64_t user_node_stage, int64_t node_stage, const ValuePtr &value) { auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; int64_t send_tag; if (send_tag_map.find(dest_rank) != send_tag_map.end()) { @@ -402,17 +470,25 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod send_tag = 0; send_tag_map[dest_rank] = 0; } - Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); - Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); - OperatorAttrs attrs = {attr_tag, attr_rank}; - auto send_op = CreatOpInstance(attrs, SEND, "send"); + Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag)); + Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back}; + auto send_op = CreatOpInstance(attrs, SEND, SEND); auto send_node = NewValueNode(send_op); auto prim = GetValueNode(send_node); std::pair op_info_pair; + AnfNodePtr care_node; if (parameter->isa()) { op_info_pair = GetParameterPair(parameter); } else { - auto care_node = FindPipelineCareNode(parameter); + if (IsPrimitiveCNode(parameter, prim::kPrimCast)) { + auto parameter_cnode = parameter->cast(); + care_node = FindPipelineCareNode(parameter_cnode->input(1)); + } else { + care_node = FindPipelineCareNode(parameter); + } if (care_node->isa()) { op_info_pair = GetParameterPair(care_node); } else { @@ -423,14 +499,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod MS_EXCEPTION_IF_NULL(tensor_info); auto slice_shape = tensor_info->slice_shape(); auto shape_type_pair = GetShapeType(parameter, slice_shape); - prim->set_attr("shape", shape_type_pair.first); - prim->set_attr("dtype", shape_type_pair.second); + prim->set_attr(SHAPE, shape_type_pair.first); + prim->set_attr(DTYPE, shape_type_pair.second); std::vector send_input = {send_node, parameter}; - auto send = graph->NewCNode(send_input); + auto send = main_graph_->NewCNode(send_input); + if (!parameter->isa() && care_node != nullptr && !care_node->isa()) { + send->AddPrimalAttr(PIPELINE_END, value); + } else { + send->AddPrimalAttr(PIPELINE_PARAM, value); + send->AddPrimalAttr(MICRO, value); + } OperatorAttrs depend_attrs; - auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend"); + auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND); std::vector depend_input = {NewValueNode(depend_op), parameter, send}; - auto depend = graph->NewCNode(depend_input); + auto depend = main_graph_->NewCNode(depend_input); auto abstract = parameter->abstract(); depend->set_abstract(abstract); SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; @@ -439,7 +521,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, int64_t user_node_stage, - int64_t node_stage) { + int64_t node_stage, const ValuePtr &value) { auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; int64_t recv_tag; if (recv_tag_map.find(src_rank) != recv_tag_map.end()) { @@ -449,9 +531,10 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A recv_tag = 0; recv_tag_map[src_rank] = 0; } - Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); - Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank)); + Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag)); + Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage)); std::pair op_info_pair; + bool is_param = true; if (node->isa()) { op_info_pair = GetParameterPair(node); } else { @@ -460,28 +543,34 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A op_info_pair = GetParameterPair(care_node); } else { op_info_pair = GetOpInfo(care_node); + is_param = false; } } auto tensor_info = op_info_pair.second; MS_EXCEPTION_IF_NULL(tensor_info); - auto slice_shape = tensor_info->slice_shape(); + auto tensor_layout = tensor_info->tensor_layout(); + Shape slice_shape = tensor_info->slice_shape(); auto shape_type_pair = GetShapeType(node, slice_shape); - Attr attr_shape = std::make_pair("shape", shape_type_pair.first); - Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); - OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; - auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv"); + Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first); + Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0])); + Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1])); + OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back}; + auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE); std::vector recv_input; if (node->isa()) { recv_input = {NewValueNode(recv_op), node}; } else { - if (node->grad()) { - recv_input = {NewValueNode(recv_op), virtual_param_}; - } else { - auto param = root_->parameters()[0]; - recv_input = {NewValueNode(recv_op), param}; - } + recv_input = {NewValueNode(recv_op), virtual_param_}; } auto recv = graph->NewCNode(recv_input); + if (is_param) { + recv->set_user_data(PIPELINE_PARAM, node); + recv->AddPrimalAttr(PIPELINE_PARAM, value); + } else { + recv->AddPrimalAttr(PIPELINE_BEGIN, value); + } + recv->AddPrimalAttr(MICRO, value); auto node_abstract = node->abstract(); if (node->isa()) { auto cnode = node->cast(); @@ -494,65 +583,53 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A } MS_EXCEPTION_IF_NULL(node_abstract); recv->set_abstract(node_abstract); - if (op_info_pair.first != nullptr) { - recv->set_user_data(std::make_shared(tensor_info->tensor_layout())); - recv->set_user_data(op_info_pair.first); + if (node->isa()) { + BaseShapePtr parallel_shape = std::make_shared(slice_shape); + auto abstract_clone = node->abstract()->Clone(); + MS_EXCEPTION_IF_NULL(abstract_clone); + abstract_clone->set_shape(parallel_shape); + node->set_abstract(abstract_clone); + node->set_user_data(std::make_shared(tensor_layout)); } + recv->set_user_data(std::make_shared(tensor_layout)); + recv->set_user_data(op_info_pair.first); + manager_->SetEdge(use_node, index, recv); return recv; } -bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage, - const std::vector &out_input) { - auto node_users = manager_->node_users()[node]; - auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_; - for (auto &depend : out_input) { - if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) { +AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector &out_input, + const std::string &tag) { + for (auto &input : out_input) { + auto cnode = input->cast(); + if (!cnode) { continue; } - auto cnode = depend->cast(); + if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) { + cnode = cnode->input(2)->cast(); + } if (cnode->input(1) == node) { - auto send_cnode = cnode->input(2)->cast(); - auto prim = GetValueNode(send_cnode->input(0)); - auto dest_rank_send = GetValue(prim->GetAttr("dest_rank")); - if (dest_rank_send == dest_rank) { - return true; + auto prim = GetValueNode(cnode->input(0)); + auto dest_rank_send = GetValue(prim->GetAttr(tag)); + if (dest_rank_send == stage) { + return input; } } } - return false; + return nullptr; } -std::pair PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { - std::set tag_set; - auto node_stage = node->stage(); - int64_t min_tag = node_stage; - for (auto &user_pair : node_users) { - auto user_node = user_pair.first; - auto user_node_stage = user_node->stage(); - tag_set.insert(user_node_stage); - if (user_node_stage == -1) { - continue; - } - min_tag = min_tag > user_node_stage ? user_node_stage : min_tag; - } - bool is_shared = tag_set.size() > 1; - return std::make_pair(is_shared, min_tag); -} - -void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { +std::pair, std::vector> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { OperatorAttrs depend_attrs; - auto depend_op = CreatOpInstance(depend_attrs, "Depend", ""); - std::vector out_input = {NewValueNode(depend_op)}; + auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND); + std::vector receive_ops; + std::vector send_ops; auto all_nodes = graph->nodes(); for (auto &node : all_nodes) { if (!node->isa() || node->stage() == -1) { continue; } auto node_users = manager_->node_users()[node]; - auto shared_min_tag_pair = IsSharedNode(node, node_users); - auto is_shared = shared_min_tag_pair.first; - auto min_tag = shared_min_tag_pair.second; AnfNodePtr receive = nullptr; for (auto &user_pair : node_users) { auto user_node = user_pair.first; @@ -561,21 +638,25 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { if (node_stage != stage_ && user_node_stage != stage_) { continue; } + auto micro = user_node->cast()->GetPrimalAttr(MICRO); + if (!micro) { + MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)"; + micro = MakeValue(int64_t(0)); + } if (node_stage < user_node_stage) { - if (is_shared && (min_tag != node_stage)) { - continue; - } if (node_stage == stage_) { - if (Reuse(node, user_node_stage, node_stage, out_input)) { + if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) { continue; } - auto send_out = InsertSend(graph, node, user_node_stage, node_stage); - out_input.insert(out_input.begin() + 1, send_out.depend); - type_ptr_ = send_out.type; - shape_ = send_out.shape; + auto send_out = InsertSend(graph, node, user_node_stage, node_stage, micro); + MS_EXCEPTION_IF_NULL(send_out.depend); + send_ops.push_back(send_out.depend); + send_out.depend->set_user_data(DTYPE, send_out.type); + send_out.depend->set_user_data(SHAPE, send_out.shape); } else { if (!receive) { - receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage); + receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro); + receive_ops.push_back(receive); } else { manager_->SetEdge(user_node, user_pair.second, receive); } @@ -583,46 +664,40 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { continue; } if (node_stage > user_node_stage) { - auto cnode = node->cast(); - auto user_cnode = user_node->cast(); - if (IsValueNode(cnode->input(0)) && IsValueNode(user_cnode->input(0))) { - MS_LOG(EXCEPTION) << "Don't support this situation"; - } - continue; + MS_LOG(EXCEPTION) << "node_stage: " << node_stage + << " must be smaller than user_node_stage: " << user_node_stage; } } } - if (out_input.size() == 2) { - manager_->Replace(graph->output(), out_input[1]); - } - if (out_input.size() > 2) { - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end()); - auto make_tuple = graph->NewCNode(make_tuple_inputs); - std::vector out_depend_inputs = {out_input[0], out_input[1], make_tuple}; - auto out_node = graph->NewCNode(out_depend_inputs); - manager_->Replace(graph->output(), out_node); - } + return std::make_pair(send_ops, receive_ops); } void PipelineTransformer::CutGraph() { - for (auto &fg : manager_->func_graphs()) { - CutBorder(fg); + std::vector make_tuple_inputs; + CreateForwardGroup(); + MS_EXCEPTION_IF_NULL(main_graph_); + if (make_tuple_inputs.empty()) { + make_tuple_inputs = HandleSharedParameter(); } -} - -bool PipelineTransformer::IsStageNode(const CNodePtr &node) { - for (auto &input : node->inputs()) { - if (input->isa()) { - return (*parameter_color_map[input].begin() == stage_ || input->stage() == -1); - } else if (input->isa()) { - auto pre_node = input->cast(); - return IsStageNode(pre_node); - } else { - continue; - } + auto send_recv_ops = CutBorder(main_graph_); + auto send_ops = send_recv_ops.first; + if (IsLastStage()) { + return; } - return true; + if (send_ops.empty() && !root_->has_flag(TRAINING)) { + return; + } + make_tuple_inputs.insert(make_tuple_inputs.end(), send_ops.begin(), send_ops.end()); + if (!send_ops.empty()) { + type_ptr_ = send_ops.back()->user_data(DTYPE); + shape_ = send_ops.back()->user_data(SHAPE); + } + auto make_tuple = main_graph_->NewCNode(make_tuple_inputs); + std::vector out = {NewValueNode(prim::kPrimDepend)}; + out.push_back(send_ops.back()); + out.push_back(make_tuple); + auto out_node = main_graph_->NewCNode(out); + manager_->Replace(main_graph_->output(), out_node); } void PipelineTransformer::ElimGraphStage() { @@ -694,7 +769,21 @@ void PipelineTransformer::ElimParameter() { std::vector parameter_list; for (auto ¶meter : parameters) { if (!manager_->node_users()[parameter].empty()) { - parameter_list.push_back(parameter); + if (!root_->has_flag(TRAINING)) { + for (auto &node_pair : manager_->node_users()[parameter]) { + auto user_node = node_pair.first; + if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) { + parameter_list.push_back(parameter); + break; + } + // remove_receive_inputs + auto cnode = user_node->cast(); + std::vector new_inputs = {cnode->input(0)}; + cnode->set_inputs(new_inputs); + } + } else { + parameter_list.push_back(parameter); + } } } auto del_num = parameters.size() - parameter_list.size(); diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h index 1d441a67cab..69bedb43410 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h @@ -45,13 +45,15 @@ class PipelineTransformer { : manager_(manager), stage_(stage), root_(root), + main_graph_(nullptr), + virtual_dataset_(nullptr), global_rank_(global_rank), per_stage_rank_num_(per_stage_rank_num) {} virtual ~PipelineTransformer() = default; - void LabelRequiredGradCNode(); void Coloring(); + void MainGraph(); + void LabelMicroBatch(); void BroadCastColoring(); - void HandleSharedParameter(); void CutGraph(); void ParameterColoring(); void CoverSensShape(); @@ -59,21 +61,18 @@ class PipelineTransformer { void ElimParameter(); private: - std::pair IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); - void DoBroadCast(const FuncGraphPtr &func); + void CreateForwardGroup(); + ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size); + std::vector HandleSharedParameter(); SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int64_t user_node_stage, - int64_t node_stage); + int64_t node_stage, const ValuePtr &value); AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, - int64_t user_node_stage, int64_t node_stage); - void SetNoStageNode(const FuncGraphPtr &func); - void CutBorder(const FuncGraphPtr &graph); - bool IsStageNode(const CNodePtr &node); - bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage, - const std::vector &out_input); + int64_t user_node_stage, int64_t node_stage, const ValuePtr &value); + std::pair, std::vector> CutBorder(const FuncGraphPtr &graph); + AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector &out_input, + const std::string &tag); AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); std::pair GetOpInfo(const AnfNodePtr &node); - AnfNodePtr HandleMonadDepend(const AnfNodePtr &node); - CNodePtr HandleMonadLoad(const AnfNodePtr &node); std::pair GetParameterPair(const AnfNodePtr &node); OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode); bool IsPipelineCareNode(const CNodePtr &cnode); @@ -81,11 +80,15 @@ class PipelineTransformer { FuncGraphManagerPtr manager_; int64_t stage_; FuncGraphPtr root_; + FuncGraphPtr main_graph_; + AnfNodePtr virtual_dataset_; int64_t global_rank_; int64_t per_stage_rank_num_; TypePtr type_ptr_; ValueListPtr shape_; AnfNodePtr virtual_param_; + int64_t micro_size_ = 0; + std::vector group_ = {}; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 5438a206c71..ca57f2f52af 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -37,6 +37,7 @@ #include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/graph_util/graph_info.h" #include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/graph_util/pipeline_split_utils.h" #include "frontend/parallel/node_check.h" #include "frontend/parallel/ops_info/matmul_info.h" #include "ir/param_info.h" @@ -172,8 +173,9 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat OperatorArgs arg_forward = op.second; int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); - if (grad_accumulation_step > 1) { + if (grad_accumulation_step > 1 || split_stage_num > 1) { auto parameters = root->parameters(); bool find_grad_accu_node = false; for (auto ¶m : parameters) { @@ -196,8 +198,8 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat if (op_name == MIRROR_MINI_STEP_OPERATOR) { op_name = MIRROR_OPERATOR; arg_forward.first.pop_back(); - } else if (op_name == MINI_STEP_ALL_GATHER) { - MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation."; + } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) { + MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name; } } } @@ -207,7 +209,8 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat OperatorParams params = arg_forward.second; std::vector new_node_input; - if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) { + if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER || + op_name == MIRROR_MICRO_STEP_OPERATOR) { new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; } else { @@ -496,6 +499,9 @@ void Redistribution(const std::pair &node_pair, const Opera TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)]; TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); + if (IsPrimitiveCNode(middle_node, prim::kPrimReceive)) { + tensorlayout_in = *(middle_node->user_data()); + } if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " @@ -866,11 +872,13 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { SetUserAttrs(origin_prim->attrs(), prim); if (index == replace_op.size() - 1) { replace_node->set_user_data(node->user_data()); + replace_node->set_primal_attrs(node->primal_attrs()); } replace_node->set_in_forward_flag(true); replace_input[0]->set_scope(scope); if (replace_op_info_flag && replace_op_info[index].first) { auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); + new_cnode->set_primal_attrs(node->primal_attrs()); (void)manager->Replace(node, new_cnode); // using Replace function to insert node } else { (void)manager->Replace(node, replace_node); // using Replace function to insert node @@ -920,8 +928,9 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node manager->SetEdge(replace_input.first, appear_count, pre_node); } // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called - auto replace_output = replace_graph->second; + auto replace_output = replace_graph->second->cast(); MS_EXCEPTION_IF_NULL(replace_output); + replace_output->set_primal_attrs(node->primal_attrs()); (void)manager->Replace(node, replace_output); } @@ -1075,7 +1084,7 @@ std::pair FindParameter(const AnfNodePtr &node, const FuncGrap } } - if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data()) { + if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) { return std::make_pair(node, false); } // When not fully use opt shard, allgather and mirror would be both inserted. @@ -1193,6 +1202,20 @@ CNodePtr SkipTrivialNodes(CNodePtr node) { return node; } +std::string MirrorOpName() { + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); + std::string mirror_op_name; + if (grad_accumulation_step > 1) { + mirror_op_name = MIRROR_MINI_STEP_OPERATOR; + } else if (split_stage_num > 1) { + mirror_op_name = MIRROR_MICRO_STEP_OPERATOR; + } else { + mirror_op_name = MIRROR_OPERATOR; + } + return mirror_op_name; +} + void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); @@ -1240,12 +1263,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons } } // not a RefKey - int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); - std::string mirror_op_name; - if (grad_accumulation_step > 1) { - mirror_op_name = MIRROR_MINI_STEP_OPERATOR; - } else { - mirror_op_name = MIRROR_OPERATOR; + std::string mirror_op_name = MirrorOpName(); + if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) { + param_ptr = param_node_pair.first->cast()->user_data(PIPELINE_PARAM)->cast(); + param_name = param_ptr->name(); } AnfNodePtr pre_node = node->input(index); if (!param_node_pair.second) { @@ -1282,6 +1303,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons AddCommOpFusionType(comm_op, param_node_pair.first); MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast()) << " and insert mirror before Load"; + AddCommOpParamFlag(comm_op); continue; } InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root); @@ -1291,6 +1313,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons // add fusion flag // pipeline mirror would not be set, which should be supported later AddCommOpFusionType(comm_op, param_node_pair.first); + AddCommOpParamFlag(comm_op); } } @@ -2333,6 +2356,9 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (!IsValueNode(cnode->input(0))) { return nullptr; } + if (IsPrimitiveCNode(node, prim::kPrimReceive)) { + return cnode->user_data(); + } if (IsParallelCareNode(cnode) && cnode->has_user_data() && !IsPrimitiveCNode(node, prim::kPrimReshape)) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); @@ -2764,13 +2790,6 @@ std::vector> GetSensLossPairs(const FuncGraphP return sens_loss_pairs; } -bool IsLastStage() { - MS_EXCEPTION_IF_NULL(g_device_manager); - auto stage_num = g_device_manager->stage_num(); - auto stage_id = g_device_manager->stage_id(); - return ((stage_num - 1) == stage_id); -} - void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); @@ -2793,7 +2812,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectorisa()) { auto cnode = node->cast(); // the make_tuple is parallel care node, but it may have not operator info - if (!IsParallelCareNode(cnode) || !cnode->has_user_data()) { + if (!IsParallelCareNode(cnode) || !cnode->has_user_data() || cnode->HasPrimalAttr(PIPELINE_PARAM)) { continue; } @@ -3545,20 +3564,6 @@ static bool IsFullySplitParameter(const ParameterPtr ¶m_ptr) { return false; } -static AnfNodePtr FindGradAccuParameter(const std::vector ¶meters, const std::string &name) { - for (auto ¶meter : parameters) { - auto param_ptr = parameter->cast(); - MS_EXCEPTION_IF_NULL(param_ptr); - if (param_ptr->name() == name) { - continue; - } - if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) { - return parameter; - } - } - return nullptr; -} - static void InsertFullySplitParamGradAccu(const std::pair &node_user, const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) { auto cnode = node_user.first->cast(); @@ -3612,6 +3617,17 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) { } } +void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) { + if (!root->has_flag(BACKWARD) && pipeline_stages > 1) { + root->set_flag(BACKWARD, true); + if (root->has_flag(TRAINING)) { + Reorder(root, manager); + } else { + ReorderForPredict(root, manager); + } + } +} + bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { #if (ENABLE_CPU && !_WIN32) if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { @@ -3622,6 +3638,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + pipeline::ResourceBasePtr res = optimizer->resource(); + MS_EXCEPTION_IF_NULL(res); + FuncGraphManagerPtr manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num(); // assume no change to graph bool changes = false; // control whether use model_parallel mode @@ -3634,6 +3655,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) } root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); } + ReorderForPipelineSplit(root, manager, pipeline_stages); return changes; } @@ -3643,23 +3665,22 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) MS_LOG(INFO) << "Now entering step parallel"; DumpGraph(root, std::string(STEP_PARALLEL_BEGIN)); - - pipeline::ResourceBasePtr res = optimizer->resource(); - MS_EXCEPTION_IF_NULL(res); - - FuncGraphManagerPtr manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); AnfNodePtr ret = root->get_return(); MS_EXCEPTION_IF_NULL(ret); std::vector all_nodes = DeepScopedGraphSearch(ret); std::reverse(all_nodes.begin(), all_nodes.end()); if (parallel_mode != AUTO_PARALLEL) { TOTAL_OPS = 0; - auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num(); if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) { MS_LOG(EXCEPTION) << "Parallel init failed"; } + if (pipeline_stages > 1) { + HandleMicroBatch(all_nodes, manager); + ParameterStartNode(all_nodes, manager); + LastStageEndNode(all_nodes, manager); + } + // mark the forward cnodes, parallel only care these nodes MarkForwardCNode(root); @@ -3705,6 +3726,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) // ForwardCommunication BackwardCommunication TensorRedistribution ParallelCommunication(root, all_nodes, manager); + if (pipeline_stages > 1) { + AddVirtualAssignAdd(root); + HandleReceiveParam(root, all_nodes); + } + auto group_info = g_device_manager->group_info(); if (StrategyCheckpoint::GetInstance().group_info_save_on() && StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 5e15d6b562f..2e9d745b812 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -134,8 +134,6 @@ void ReshapeInit(const std::vector &all_nodes); StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim); -bool IsLastStage(); - // Add node for whole graph void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); @@ -177,6 +175,10 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector *u std::vector *indexes); void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector &all_nodes); + +std::string MirrorOpName(); + +void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 02bc1650641..1fdb2d15d82 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -407,7 +407,9 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.env_get_item_depend_swap_, irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, - irpass.receive_eliminate_}, + irpass.virtual_accu_grad_, + irpass.virtual_assign_add_, + irpass.mirror_micro_step_}, false, true); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc index f65d584a86f..9290b5b78fc 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc @@ -90,17 +90,19 @@ bool PipelineSplit(const ResourcePtr &res) { auto transformer = std::make_shared(manager, stage, root, global_rank, per_stage_rank_num); // step1: Do color graph - transformer->LabelRequiredGradCNode(); transformer->Coloring(); + transformer->MainGraph(); + transformer->LabelMicroBatch(); // step2: Do color broadcast transformer->BroadCastColoring(); // step3: Handle shared parameters transformer->ParameterColoring(); - transformer->HandleSharedParameter(); // step4: Cut Graph transformer->CutGraph(); // step5: Handle Sens - transformer->CoverSensShape(); + if (root->has_flag(parallel::TRAINING)) { + transformer->CoverSensShape(); + } // step6: Elim Graph stages and no used parameter transformer->ElimGraphStage(); transformer->ElimParameter(); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 8e496928107..8a6d441e3e3 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -373,6 +373,9 @@ inline const PrimitivePtr kPrimFill = std::make_shared("Fill"); inline const PrimitivePtr kPrimFusedPushWeight = std::make_shared("FusedPushWeight"); inline const PrimitivePtr kPrimFusedPullWeight = std::make_shared("FusedPullWeight"); inline const PrimitivePtr kPrimInitDataSetQueue = std::make_shared("InitDataSetQueue"); +inline const PrimitivePtr kPrimVirtualAssignAdd = std::make_shared("_VirtualAssignAdd"); +inline const PrimitivePtr kPrimVirtualAccuGrad = std::make_shared("_VirtualAccuGrad"); +inline const PrimitivePtr kPrimMirrorMicroStep = std::make_shared("_MirrorMicroStepOperator"); // Quant ops inline const PrimitivePtr kPrimBatchNormFold = std::make_shared("BatchNormFold"); diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index f0001702730..3ffdd2c0fd2 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -19,6 +19,7 @@ from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, _get_parallel_mode) from mindspore.context import ParallelMode, get_auto_parallel_context from mindspore._checkparam import Validator as validator +from mindspore import ops, nn from ...common import dtype as mstype from ...common.parameter import Parameter, ParameterTuple from ...common.tensor import Tensor @@ -503,6 +504,95 @@ class _VirtualDatasetCell(Cell): return self._backbone(*output) +class _MicroBatch(Cell): + """ + transform mini-batch to micro-batch in pipeline parallel. + + Args: + params (micro_size): The number of micro-batch. + """ + def __init__(self, micro_size): + super(_MicroBatch, self).__init__() + self.shape = P.Shape() + self.micro_size = micro_size + + def construct(self, i, *inputs): + micro_inputs = () + for each_input in inputs: + input_shape = self.shape(each_input) + micro_batch_begin = i * input_shape[0] // self.micro_size + micro_batch_end = (i + 1) * input_shape[0] // self.micro_size + micro_input = each_input[micro_batch_begin:micro_batch_end, :] + micro_inputs += (micro_input,) + return micro_inputs + + +class PipelineCell(Cell): + """ + Wrap the network with Micro Batch. + + Note: + micro_size must be greater or equal to pipeline stages. + + Args: + network (Cell): The target network to wrap. + micro_size (Int): MicroBatch size. + + Examples: + >>> net = Net() + >>> net = PipelineCell(net, 4) + """ + def __init__(self, network, micro_size): + super(PipelineCell, self).__init__() + self.network = network + self.micro_inputs = nn.CellList() + self.micro_size = micro_size + self.add_list = [] + for i in range(micro_size): + micro_input = _MicroBatch(micro_size) + self.micro_inputs.append(micro_input) + self.add = P.Add().add_prim_attr("pipeline_end", i) + self.add_list.append(self.add) + + def construct(self, *inputs): + ret = None + for i in range(self.micro_size): + micro_input = self.micro_inputs[i](i, *inputs) + output = self.network(*micro_input) + if ret is not None: + ret = self.add_list[i](ret, output) + else: + ret = output + return ret + + +def _pipeline_clear_grad(accu_grad, grad): + accu_grad = F.depend(accu_grad, grad) + zeros = F.tensor_mul(accu_grad, 0.0) + return F.assign(accu_grad, zeros) + + +class _TrainPipelineAccuStepCell(TrainOneStepCell): + """ + Wraps the network with an optimizer in pipeline mode. + """ + def __init__(self, network, optimizer, sens=1.0): + super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens) + self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") + self.hyper_map = ops.HyperMap() + + def construct(self, *inputs): + weights = self.weights + loss = self.network(*inputs) + sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*inputs, sens) + accu_grads = ops.depend(self.accu_grads, grads) + succ = self.optimizer(accu_grads) + clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads) + loss = ops.depend(loss, succ, clear) + return loss + + class VirtualDatasetCellTriple(Cell): """ Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index d41cc584201..7e59f92e53d 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -14,6 +14,7 @@ # ============================================================================ """Generate bprop for comm ops""" +from mindspore import Tensor import mindspore.common.dtype as mstype from mindspore.ops import functional as F from mindspore.communication import get_rank, get_group_size @@ -22,7 +23,8 @@ from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, - ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap) + ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap, + _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator) from .grad_base import bprop_getters from ..operations._inner_ops import Send, Receive @@ -84,11 +86,11 @@ def get_bprop_send(self): """Generate bprop for Send.""" shape = self.get_attr_dict()["shape"] dtype = self.get_attr_dict()["dtype"] - send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group) - send_grad.add_prim_attr("backward", True) + send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back) + virtual_input = Tensor(0.0, dtype) def bprop(x, out, dout): - dx = send_grad() + dx = send_grad(virtual_input) return (dx,) return bprop @@ -96,14 +98,14 @@ def get_bprop_send(self): @bprop_getters.register(Receive) def get_bprop_receive(self): """Generate bprop for Receive.""" - receive_grad = Send(self.tag, self.rank, self.group) - receive_grad.add_prim_attr("backward", True) + receive_grad = Send(self.tag, self.rank, self.group_back) depend = P.Depend() cast = P.Cast() + out_tensor = Tensor(0.0, mstype.float16) def bprop(x, out, dout): send_out = receive_grad(dout) - dx = depend(cast(zeros_like(x), F.dtype(x)), send_out) + dx = depend(cast(out_tensor, F.dtype(x)), send_out) return (dx,) return bprop @@ -116,6 +118,80 @@ def get_bprop_virtual_add(self): return bprop +@bprop_getters.register(_VirtualAssignAdd) +def get_bprop_virtual_assign_add(self): + """Generate bprop for VirtualAssignAdd.""" + assign_add = P.AssignAdd() + cast = P.Cast() + dtype = P.DType() + out_tensor = Tensor(0.0, mstype.float16) + + def bprop(x, y, out, dout): + temp = assign_add(y, dout) + return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp) + + return bprop + + +@bprop_getters.register(_VirtualAccuGrad) +def get_bprop_virtual_accu_grad(self): + """Generate bprop for VirtualAccuGrad.""" + cast = P.Cast() + dtype = P.DType() + out_tensor = Tensor(0.0, mstype.float16) + + def bprop(x, y, out, dout): + return (F.depend(y, dout), cast(out_tensor, dtype(y))) + + return bprop + + +@bprop_getters.register(_MirrorMicroStepOperator) +def get_bprop_mirror_micro_step_operator(self): + """ + Backpropagator for _MirrorMicroStepOperator, do allreduce or allgather for the devices in the group, + allgather for sparse feature. + """ + group = self.group + dev_num = self.dev_num + mean_flag = self.mean_flag + scale = 1 / dev_num + + all_reduce = AllReduce(group=group) + + fusion = self.get_attr_dict()["fusion"] + all_reduce.add_prim_attr("fusion", fusion) + if hasattr(self, 'parameter'): + parameter = self.parameter + all_reduce.add_prim_attr("parameter", parameter) + + if self.instance_name: + instance_name = "grad_mirror" + self.instance_name + all_reduce.set_prim_instance_name(instance_name) + cast = P.Cast() + dtype = P.DType() + assign = P.Assign() + if "parameter_micro" in self.get_attr_dict(): + assign.add_prim_attr("parameter_micro", 0) + out_tensor = Tensor(1.0, mstype.float16) + + def bprop(x, z, out, dout): + real_grad = z + if mean_flag: + if F.issubclass_(F.typeof(dout), mstype.tensor): + z = F.depend(z, dout) + real_grad = all_reduce(z) + real_grad = F.tensor_mul(real_grad, scale) + assign(z, real_grad) + else: + if F.issubclass_(F.typeof(dout), mstype.tensor): + z = F.depend(z, dout) + real_grad = all_reduce(z) + assign(z, real_grad) + return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))) + return bprop + + @bprop_getters.register(Broadcast) def get_bprop_broad_cast(self): """Generate bprop for Broadcast.""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index bb0cef6ad9e..62b8c49c5fe 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -36,8 +36,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect, SearchSorted) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, - _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, - _HostAllGather, _HostReduceScatter) + _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, + _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary, HistogramSummary, Print, Assert) from .control_ops import GeSwitch, Merge diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 2c346a819a5..cfdefe99023 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -417,7 +417,7 @@ class Send(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): + def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP): self.rank = dest_rank self.sr_tag = sr_tag self.group = group @@ -427,7 +427,6 @@ class Send(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - self.add_prim_attr("dtype", x_dtype) return x_dtype @@ -474,7 +473,8 @@ class Receive(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): + def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP, + group_back=GlobalComm.WORLD_COMM_GROUP): self.rank = src_rank self.tag = sr_tag self.shape = shape diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 0a60c5d13ad..bf4ca9bca84 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -690,6 +690,72 @@ class _VirtualDataset(PrimitiveWithInfer): virtual_dataset = _VirtualDataset() + +class _VirtualAssignAdd(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for + internal use of parallel modules and cannot be called by users. + + Args: + micro (int): MicroBatch. Default: 0. + """ + @prim_attr_register + def __init__(self): + """init""" + + def infer_shape(self, x_shape, y_shape): + return x_shape + + def infer_dtype(self, x_dtype, y_dtype): + return x_dtype + + +virtual_assign_add = _VirtualAssignAdd() + + +class _VirtualAccuGrad(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for + internal use of parallel modules and cannot be called by users. + """ + @prim_attr_register + def __init__(self): + """init""" + + def infer_shape(self, x_shape, y_shape): + return x_shape + + def infer_dtype(self, x_dtype, y_dtype): + return x_dtype + + +virtual_accu_grad = _VirtualAccuGrad() + + +class _MirrorMicroStepOperator(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for + internal use of parallel modules and cannot be called by users. + + Args: + group (str): The communication group to work on. Default: None. + dev_num (int): The device number of the group. Default: None. + mean_flag (bool): Whether use mean in backward. Default: None. + """ + + @prim_attr_register + def __init__(self, group=None, dev_num=None, mean_flag=None): + self.group = group + self.dev_num = dev_num + self.mean_flag = mean_flag + + def infer_shape(self, x_shape, z_shape): + return x_shape + + def infer_dtype(self, x_dtype, z_shape): + return x_dtype + + class _VirtualOutput(PrimitiveWithInfer): """ Auto parallel virtual out operator. diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index 070cad4fa1f..33b557482f2 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -20,9 +20,9 @@ from .._checkparam import Validator as validator from .._checkparam import Rel from ..common import dtype as mstype from ..nn import acc -from ..nn.wrap.cell_wrapper import _VirtualDatasetCell +from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell from ..ops import functional as F -from ..parallel._utils import _get_parallel_mode +from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager from ..context import ParallelMode from .. import context @@ -190,5 +190,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): network = nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense=update_cell).set_train() return network - network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train() + if _get_pipeline_stages() > 1: + network = _TrainPipelineAccuStepCell(network, optimizer).set_train() + else: + network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train() return network diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index 957586ddcf8..6d4d1bf1bab 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -21,6 +21,7 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.train.model import Model +from mindspore.nn.wrap.cell_wrapper import PipelineCell class DatasetLenet(): @@ -90,6 +91,7 @@ class PipelineSplit(nn.Cell): def __init__(self, strategy1, strategy2): super().__init__() self.cell = Net(strategy1, strategy2) + self.cell.block[0].matmul.add_prim_attr("parameter_start", 0) def construct(self, x, label): x = self.cell(x) @@ -101,6 +103,7 @@ class PipelineSplit2(nn.Cell): super().__init__() self.param = Parameter(initializer("zeros", [64, 64]), name="param") self.cell = Net(strategy1, strategy2, self.param) + self.cell.block[0].matmul.add_prim_attr("parameter_start", 0) def construct(self, x, label): x = self.cell(x) @@ -114,8 +117,8 @@ def test_pipeline_split_stage0(): label = Tensor(np.ones([64, 64]), dtype=ms.float32) strategy1 = ((4, 1), (1, 1)) strategy2 = ((2, 1), (1, 1)) - net = PipelineSplit(strategy1, strategy2) - params = net.cell.block[0].trainable_params() + net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) + params = net.network.cell.block[0].trainable_params() dataset = DatasetLenet(data, label, 3) optimizer = nn.Lamb(params, learning_rate=0.01) model = Model(net, optimizer=optimizer) @@ -131,8 +134,8 @@ def test_pipeline_split_stage1(): label = Tensor(np.ones([64, 64]), dtype=ms.float32) strategy1 = ((4, 1), (1, 1)) strategy2 = ((2, 1), (1, 1)) - net = PipelineSplit(strategy1, strategy2) - params = net.cell.block[1].trainable_params() + net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) + params = net.network.cell.block[1].trainable_params() dataset = DatasetLenet(data, label, 3) optimizer = nn.Lamb(params, learning_rate=0.01) model = Model(net, optimizer=optimizer) @@ -149,8 +152,8 @@ def test_pipeline_split_shared_parameter_stage0(): label = Tensor(np.ones([64, 64]), dtype=ms.float32) strategy1 = ((4, 1), (1, 1)) strategy2 = ((2, 1), (1, 1)) - net = PipelineSplit2(strategy1, strategy2) - params = net.cell.block[0].trainable_params() + net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) + params = net.network.cell.block[0].trainable_params() dataset = DatasetLenet(data, label, 3) optimizer = nn.Lamb(params, learning_rate=0.01) model = Model(net, optimizer=optimizer) @@ -164,8 +167,8 @@ def test_pipeline_split_shared_parameter_stage1(): label = Tensor(np.ones([64, 64]), dtype=ms.float32) strategy1 = ((4, 1), (1, 1)) strategy2 = ((2, 1), (1, 1)) - net = PipelineSplit2(strategy1, strategy2) - params = net.cell.block[1].trainable_params() + net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) + params = net.network.cell.block[1].trainable_params() dataset = DatasetLenet(data, label, 3) optimizer = nn.Lamb(params, learning_rate=0.01) model = Model(net, optimizer=optimizer)