From 0efd0c4b256a12ec9e55211e4680519a6c9948f1 Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Thu, 11 Aug 2022 19:16:58 +0800
Subject: [PATCH] Add general strategy

---
 mindspore/ccsrc/distributed/constants.h        |  14 +-
 .../parallel/graph_util/graph_splitter.cc      | 135 +++++++++++++++---
 .../parallel/graph_util/graph_splitter.h       |  52 ++++++-
 mindspore/ccsrc/pipeline/jit/action.cc         |   6 +-
 .../device/cpu/kernel/rpc/rpc_send_kernel.cc   |   2 +-
 mindspore/python/mindspore/nn/cell.py          |   9 ++
 mindspore/python/mindspore/ops/primitive.py    |   8 ++
 .../test_all_reduce/run_allreduce.py           |   1 +
 .../run_allreduce_small_scale_data.py          |   1 +
 9 files changed, 202 insertions(+), 26 deletions(-)

diff --git a/mindspore/ccsrc/distributed/constants.h b/mindspore/ccsrc/distributed/constants.h
index e86211017e6..9e3f3e03d35 100644
--- a/mindspore/ccsrc/distributed/constants.h
+++ b/mindspore/ccsrc/distributed/constants.h
@@ -49,7 +49,19 @@ const std::vector kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpd
 constexpr char kFinalizeMuxRecvActor[] = "FINALIZE_MUX_RECV_ACTOR";
 
 // The distributed execution mode enum.
-enum class DistExecutionMode { kPSMode = 0, kEmbeddingCacheMode, kInvalidMode };
+// For each execution mode, different graph optimizations, splitting strategies, device locations, etc. are applied.
+// For details please refer to class DistributedExecutionMode and its subclasses.
+
+// kGeneralMode: Simply splits a training graph into multiple devices without any other extra feature.
+
+// kParallelMode: Combines MindSpore's existing auto-parallel feature with the distributed graph splitting feature.
+// This is much more complicated than the other modes and is typically applied in MoE scenarios.
+
+// kPSMode: Applied when running Parameter Server training.
+
+// kEmbeddingCacheMode: Applied when the embedding cache is enabled. Normally used for training models with a large
+// embedding layer.
+enum class DistExecutionMode { kGeneralMode = 0, kParallelMode, kPSMode, kEmbeddingCacheMode, kInvalidMode };
 
 // The operator's label in distributed execution.
 constexpr char kOpLabelRankId[] = "rank_id";
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
index 29f3593d8f6..86e30293bb8 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
@@ -27,6 +27,7 @@
 #include "mindspore/core/utils/ms_context.h"
 #include "include/common/utils/anfalgo.h"
 #include "include/common/debug/draw.h"
+#include "include/common/utils/parallel_context.h"
 #ifdef WITH_BACKEND
 #include "ps/ps_context.h"
 #endif
@@ -39,18 +40,17 @@ bool OperatorLabel::operator==(const OperatorLabel &label) const { return to_str
 
 bool OperatorLabel::operator!=(const OperatorLabel &label) const { return !(*this == label); }
 
-bool OperatorLabel::LooseEqual(const OperatorLabel &label) const {
-  auto mode = distributed::DistExecutionMode::kPSMode;
+bool OperatorLabel::LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const {
   if (kLabelMatchingFuncMap.count(mode) == 0) {
-    MS_LOG(ERROR) << "The mode " << mode << " is invalid.";
-    return false;
+    MS_LOG(DEBUG) << "The mode " << mode << " does not need LooseEqual.";
+    return to_string() == label.to_string();
   }
   return kLabelMatchingFuncMap.at(mode)(label, *this);
 }
 
 std::string OperatorLabel::to_string() const { return std::to_string(rank_id) + "_" + ms_role; }
 
-ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node) {
+ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node, bool use_fake_shape) {
   tensor::TensorPtr fake_tensor = nullptr;
   if (use_origin_node) {
     MS_EXCEPTION_IF_NULL(origin_node);
@@ -63,15 +63,26 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
       origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
     }
     MS_EXCEPTION_IF_NULL(origin_abstract);
-    fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
-                                                   origin_abstract->shape()->shape());
-    MS_EXCEPTION_IF_NULL(fake_tensor);
-    fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
+    auto element = origin_abstract->element();
+    MS_EXCEPTION_IF_NULL(element);
+    auto build_type = element->BuildType();
+    MS_EXCEPTION_IF_NULL(build_type);
+    auto type_id = build_type->type_id();
+    if (use_fake_shape) {
+      // Assign send's output shape as {1}.
+      ShapeVector fake_shape = {kSizeOne};
+      fake_tensor = std::make_shared<tensor::Tensor>(type_id, fake_shape);
+    } else {
+      auto shape = origin_abstract->shape();
+      MS_EXCEPTION_IF_NULL(shape);
+      fake_tensor = std::make_shared<tensor::Tensor>(type_id, shape->shape());
+      fake_tensor->set_base_shape(shape->Clone());
+    }
   } else {
     fake_tensor = std::make_shared<tensor::Tensor>(1.0);
-    MS_EXCEPTION_IF_NULL(fake_tensor);
   }
+  MS_EXCEPTION_IF_NULL(fake_tensor);
   auto fake_value = NewValueNode(fake_tensor);
   MS_EXCEPTION_IF_NULL(fake_value);
   fake_value->set_abstract(fake_tensor->ToAbstract());
@@ -249,8 +260,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
   if (src_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, src_node->cast<CNodePtr>()) &&
       common::AnfAlgo::HasNodeAttr(kAttrParameterInputIndex, src_node->cast<CNodePtr>())) {
     int64_t parameter_index = common::AnfAlgo::GetNodeAttr<int64_t>(src_node, kAttrParameterInputIndex);
-    auto kernel_with_index =
-      common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(
+      common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), kIndex0);
     auto param_node = kernel_with_index.first;
     recv_inputs.push_back(param_node);
@@ -264,7 +275,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
     recv_node_abs = param_node->abstract();
   } else {
-    auto mock_value = CreateFakeValueNode(true, src_node);
+    // Use the same shape as the origin node's.
+    auto mock_value = CreateFakeValueNode(true, src_node, false);
     MS_EXCEPTION_IF_NULL(mock_value);
     recv_inputs.push_back(mock_value);
     recv_node_abs = src_node->abstract();
@@ -320,6 +332,86 @@ bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &inp
   return std::count(all_inputs.begin(), all_inputs.end(), input) != 0;
 }
 
+distributed::DistExecutionMode GenerateStrategy() {
+  distributed::DistExecutionMode strategy;
+  bool enable_ps = false;
+  bool enable_embedding_cache = false;
+#ifdef WITH_BACKEND
+  enable_ps = ps::PSContext::instance()->is_ps_mode();
+  enable_embedding_cache = ps::PSContext::instance()->cache_enable();
+#endif
+  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
+  bool using_parallel = (parallel_mode != parallel::kStandalone);
+  // The conditions' priority is: EmbeddingCache > Parameter Server > Parallel > General.
+  if (enable_embedding_cache) {
+    strategy = distributed::DistExecutionMode::kEmbeddingCacheMode;
+  } else if (enable_ps) {
+    strategy = distributed::DistExecutionMode::kPSMode;
+  } else if (using_parallel) {
+    strategy = distributed::DistExecutionMode::kParallelMode;
+  } else {
+    strategy = distributed::DistExecutionMode::kGeneralMode;
+  }
+  return strategy;
+}
+
+void TransformPrimAttrToAttr(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
+  MS_EXCEPTION_IF_NULL(prim);
+  if (cnode->HasPrimalAttr(distributed::kOpLabelRankId)) {
+    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'rank_id'.";
+    prim->set_attr(distributed::kOpLabelRankId, cnode->GetPrimalAttr(distributed::kOpLabelRankId));
+  }
+  if (cnode->HasPrimalAttr(distributed::kOpLabelRole)) {
+    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'ms_role'.";
+    prim->set_attr(distributed::kOpLabelRole, cnode->GetPrimalAttr(distributed::kOpLabelRole));
+  }
+}
+
+bool NodeHasLabel(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+
+  bool has_label = false;
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prim_node = cnode->input(0);
+  MS_EXCEPTION_IF_NULL(prim_node);
+
+  // As long as the node has both 'ms_role' and 'rank_id' attributes, we consider it to have a label regardless of
+  // the values of these two attributes.
+  if (IsValueNode<Primitive>(prim_node)) {
+    auto prim = GetValueNode<PrimitivePtr>(prim_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
+      has_label = true;
+    }
+  } else {
+    // Get the label for a 'call' node: it has no primitive to store attrs, so read the attrs from the cnode itself.
+    if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
+      has_label = true;
+    }
+  }
+  return has_label;
+}
+
+bool GraphHasLabel(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph->get_return());
+  // If any node has a label, this graph has a label and thus needs to be split.
+  for (const auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (NodeHasLabel(node)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 void ParameterServerMode::PreBuildDistributedGraph() {
   MS_LOG(INFO) << "Start pre-building distribtued graph in Parameter Server mode.";
   MS_EXCEPTION_IF_NULL(node_labels_);
@@ -773,6 +865,8 @@ FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
       InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
       FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
                                                               std::get<2>(node_pair), std::get<3>(node_pair));
+      std::vector<FusedInterProcessOpPair> pair_list = {fused_op_pair};
+      results.insert(std::make_pair(edge_with_index, pair_list));
     }
   }
   return results;
@@ -896,12 +990,9 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
       this_process_label_({rank_id, role}),
       node_labels_{},
       need_fuse_rpc_nodes_(true) {
-  bool enable_embedding_cache = false;
-#ifdef WITH_BACKEND
-  enable_embedding_cache = ps::PSContext::instance()->cache_enable();
-#endif
-  mode_ = enable_embedding_cache ? distributed::DistExecutionMode::kEmbeddingCacheMode
-                                 : distributed::DistExecutionMode::kPSMode;
+  // The distributed strategy is not explicitly defined by the user. The distributed module generates the strategy
+  // and the default label according to flags set by other modules.
+  mode_ = GenerateStrategy();
   default_label_ = {0, distributed::kEnvRoleOfWorker};
 }
 
@@ -1044,7 +1135,7 @@ void GraphSplitter::DyeGraph() {
     }
 
     // If the node's label is the same as this process's, set its label to this_process_label_.
-    if (this_process_label_.LooseEqual(node_labels_[node])) {
+    if (this_process_label_.LooseEqual(node_labels_[node], mode_)) {
       node_labels_[node] = this_process_label_;
     }
   });
@@ -1059,6 +1150,8 @@ void GraphSplitter::CreateExecutionMode() {
     exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
   } else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
     exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
+  } else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
+    exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
   }
   MS_EXCEPTION_IF_NULL(exec_mode_);
 }
@@ -1170,8 +1263,10 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
     MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
   }
   CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
   auto prim_node = cnode->input(0);
   if (IsValueNode<Primitive>(prim_node)) {
+    TransformPrimAttrToAttr(cnode);
     auto prim = GetValueNode<PrimitivePtr>(prim_node);
     MS_EXCEPTION_IF_NULL(prim);
     if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.h b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.h
index 81bf98ffd49..aca803830c5 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.h
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.h
@@ -53,7 +53,7 @@ struct OperatorLabel {
 
   // Judge whether the labels are equal but with looser conditions according to different modes. For example, this
   // method returns true when comparing the workers in PS mode.
-  bool LooseEqual(const OperatorLabel &label) const;
+  bool LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const;
 
   std::string to_string() const;
 };
@@ -79,8 +79,16 @@ inline bool MatchLabelForPSMode(const OperatorLabel &label1, const OperatorLabel
   }
   return false;
 }
+inline bool MatchLabelForParallelMode(const OperatorLabel &label1, const OperatorLabel &label2) {
+  // When parallel mode is enabled by using a MindSpore cluster, processes with the same role have the same label
+  // regardless of their rank id.
+  return (label1.ms_role == label2.ms_role);
+}
+
 const std::map kLabelMatchingFuncMap = {
-  {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode}};
+  {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode},
+  {distributed::DistExecutionMode::kEmbeddingCacheMode, MatchLabelForPSMode},
+  {distributed::DistExecutionMode::kParallelMode, MatchLabelForParallelMode}};
 
 // Split graph segment which is generated according to the topo sort of the graph.
 struct SplitGraphSegment {
@@ -181,7 +189,8 @@ constexpr char kVirtualNode[] = "VirtualNode";
 // This method creates a fake tensor. Its type is the same as the origin_node's output if use_origin_node is set
 // true.
 // Normally it is used to connect the edges for send/recv nodes.
-ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
+ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr,
+                                 bool use_fake_shape = true);
 
 // Create a TupleGetItem node from a node with tuple output.
 CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
@@ -212,6 +221,33 @@ std::map GetRealIndexToSeg(const std::vector &split_segm
 
 bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input);
 
+/**
+ * @description: Generate the distributed strategy according to user configuration.
+ * @return {distributed::DistExecutionMode}: The distributed strategy enum.
+ */
+distributed::DistExecutionMode GenerateStrategy();
+
+/**
+ * @description: Transform primal attributes of the cnode to normal attributes.
+ * @param {CNodePtr} &cnode: The cnode which has the primal attributes.
+ * @return {void}
+ */
+void TransformPrimAttrToAttr(const CNodePtr &cnode);
+
+/**
+ * @description: Judge whether this node has a label.
+ * @param {AnfNodePtr} &node: AnfNode in a func_graph.
+ * @return {bool}: Whether this node has a label.
+ */
+bool NodeHasLabel(const AnfNodePtr &node);
+
+/**
+ * @description: Judge whether this graph has any label.
+ * @param {FuncGraphPtr} &func_graph: The func_graph.
+ * @return {bool}: Whether this graph has a label.
+ */
+bool GraphHasLabel(const FuncGraphPtr &func_graph);
+
 // Base class for different execution modes. It builds distributed graphs, optimize execution performance, etc.
 class DistributedExecutionMode {
  public:
@@ -334,6 +370,16 @@ class EmbeddingCacheMode : public DistributedExecutionMode {
   OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
 };
 
+// Users may want to simply split a training graph into multiple devices without other extra features. GeneralMode is
+// for this scenario.
+class GeneralMode : public DistributedExecutionMode {
+ public:
+  explicit GeneralMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
+                       const std::string &role)
+      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
+  ~GeneralMode() = default;
+};
+
 // The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
 // cluster.
 class GraphSplitter {
diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc
index e13095aec6a..5b1af258549 100644
--- a/mindspore/ccsrc/pipeline/jit/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/action.cc
@@ -237,6 +237,9 @@ using CompileGraphs = compile::CompileGraphs;
 using abstract::AnalysisResult;
 using mindspore::abstract::AnalysisContextPtr;
 
+// Whether this process is in a MindSpore cluster.
+static bool is_cluster_initialized = false;
+
 abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
                                          const abstract::AbstractBasePtrList &args_abs, bool clear) {
   MS_LOG(DEBUG) << "AbstractAnalyze start";
@@ -1379,7 +1382,7 @@ static std::vector CommonPipeline() {
   (void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
 
   auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
-  if (!multi_graphs && pipeline::GetJitLevel() != "O0") {
+  if (!is_cluster_initialized && !multi_graphs && pipeline::GetJitLevel() != "O0") {
     (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
   }
 
@@ -1420,6 +1423,7 @@ std::vector GePipeline() {
 }
 
 std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
+  is_cluster_initialized = distributed::cluster::ClusterContext::instance()->initialized();
   std::vector<ActionItem> actions;
   // If enable compilation cache and the cache is read successfully, only do the backend actions.
   if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/rpc/rpc_send_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/rpc/rpc_send_kernel.cc
index 11054883e87..e04e7820d2c 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/rpc/rpc_send_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/rpc/rpc_send_kernel.cc
@@ -55,7 +55,7 @@ void RpcSendKernelMod::Init(const CNodePtr &kernel_node) {
 }
 
 std::vector<KernelAttr> RpcSendKernelMod::GetOpSupport() {
-  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true).AddAllOutInRef(true)};
+  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
   return support_list;
 }
 
diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py
index 5439e16836a..f3c71412f16 100755
--- a/mindspore/python/mindspore/nn/cell.py
+++ b/mindspore/python/mindspore/nn/cell.py
@@ -2159,6 +2159,15 @@ class Cell(Cell_):
                 params.append(param)
         return params
 
+    def place(self, role, rank_id):
+        """
+        Set the label for all operators in this cell.
+        This label tells the MindSpore compiler on which process this cell should be launched.
+        """
+        all_ops = self._get_prims_recursively()
+        for op in all_ops:
+            op.place(role, rank_id)
+
     def _check_compile_dynamic_shape(self, *inputs):
         """
        Check if graph has been compiled with dynamic shape.
diff --git a/mindspore/python/mindspore/ops/primitive.py b/mindspore/python/mindspore/ops/primitive.py
index 5811c6cc343..6a3e66aa4e4 100644
--- a/mindspore/python/mindspore/ops/primitive.py
+++ b/mindspore/python/mindspore/ops/primitive.py
@@ -384,6 +384,14 @@ class Primitive(Primitive_):
         self.add_prim_attr("recompute", mode)
         return self
 
+    def place(self, role, rank_id):
+        """
+        Set the label for this primitive.
+        This label tells the MindSpore compiler on which process this operator should be launched.
+        """
+        self.add_prim_attr("ms_role", role)
+        self.add_prim_attr("rank_id", rank_id)
+
 
 class PrimitiveWithCheck(Primitive):
     """
diff --git a/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce.py b/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce.py
index 3eb3d4feca9..b420a1ea371 100644
--- a/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce.py
+++ b/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce.py
@@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 context.set_ps_context(enable_ssl=False)
 init()
+context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
 
 
 class Net(nn.Cell):
diff --git a/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce_small_scale_data.py b/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce_small_scale_data.py
index 80ac622edb7..431c7d45efd 100644
--- a/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce_small_scale_data.py
+++ b/tests/st/cpu_data_parallel/test_all_reduce/run_allreduce_small_scale_data.py
@@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 context.set_ps_context(enable_ssl=False)
 init()
+context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
 
 
 class Net(nn.Cell):
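Below is a minimal usage sketch of the `place` API added by this patch (illustrative only, not part of the patch: the "MS_WORKER" role string, the two-worker layout, and the toy network are assumptions). Each process in the cluster runs the same script after joining via `init()`; operators placed on another rank are split out of the local graph by the general strategy and reconnected through send/recv nodes, while unplaced operators keep the splitter's default label (rank 0, worker role):

    import mindspore.nn as nn
    from mindspore import context
    from mindspore.communication.management import init

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    init()  # Join the MindSpore cluster; every process owns a role and a rank id.


    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense1 = nn.Dense(32, 16)
            self.relu = nn.ReLU()
            self.dense2 = nn.Dense(16, 1)
            # Pin every primitive inside dense2 onto worker 1. "MS_WORKER" is assumed
            # to be the worker role string; other operators keep the default label.
            self.dense2.place("MS_WORKER", 1)

        def construct(self, x):
            return self.dense2(self.relu(self.dense1(x)))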