Add general strategy

ZPaC 2022-08-11 19:16:58 +08:00
parent 2695f6c868
commit 0efd0c4b25
9 changed files with 202 additions and 26 deletions

View File

@ -49,7 +49,19 @@ const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpd
constexpr char kFinalizeMuxRecvActor[] = "FINALIZE_MUX_RECV_ACTOR";
// The distributed execution mode enum.
enum class DistExecutionMode { kPSMode = 0, kEmbeddingCacheMode, kInvalidMode };
// For each execution mode, different graph optimizations, splitting strategies, device locations, etc. are applied.
// For details please refer to class DistributedExecutionMode and its subclasses.
// kGeneralMode: Simply split a training graph across multiple devices without any extra features.
// kParallelMode: Combines MindSpore's existing auto-parallel feature with the distributed graph splitting feature.
// This is much more complicated than the other modes. It is always applied in MoE scenarios.
// kPSMode: Applied when running Parameter Server training.
// kEmbeddingCacheMode: Applied when the embedding cache is enabled. Normally used for training models with a large
// embedding layer.
enum class DistExecutionMode { kGeneralMode = 0, kParallelMode, kPSMode, kEmbeddingCacheMode, kInvalidMode };
// The operator's label in distributed execution.
constexpr char kOpLabelRankId[] = "rank_id";

View File

@ -27,6 +27,7 @@
#include "mindspore/core/utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/debug/draw.h"
#include "include/common/utils/parallel_context.h"
#ifdef WITH_BACKEND
#include "ps/ps_context.h"
#endif
@ -39,18 +40,17 @@ bool OperatorLabel::operator==(const OperatorLabel &label) const { return to_str
bool OperatorLabel::operator!=(const OperatorLabel &label) const { return !(*this == label); }
bool OperatorLabel::LooseEqual(const OperatorLabel &label) const {
auto mode = distributed::DistExecutionMode::kPSMode;
bool OperatorLabel::LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const {
if (kLabelMatchingFuncMap.count(mode) == 0) {
MS_LOG(ERROR) << "The mode " << mode << " is invalid.";
return false;
MS_LOG(DEBUG) << "The mode " << mode << " does not need LooseEqual.";
return to_string() == label.to_string();
}
return kLabelMatchingFuncMap.at(mode)(label, *this);
}
std::string OperatorLabel::to_string() const { return std::to_string(rank_id) + "_" + ms_role; }
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node) {
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node, bool use_fake_shape) {
tensor::TensorPtr fake_tensor = nullptr;
if (use_origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
@ -63,15 +63,26 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
}
MS_EXCEPTION_IF_NULL(origin_abstract);
fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
origin_abstract->shape()->shape());
MS_EXCEPTION_IF_NULL(fake_tensor);
fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
auto element = origin_abstract->element();
MS_EXCEPTION_IF_NULL(element);
auto build_type = element->BuildType();
MS_EXCEPTION_IF_NULL(build_type);
auto type_id = build_type->type_id();
if (use_fake_shape) {
// Assign the Send node's output shape as {1}.
ShapeVector fake_shape = {kSizeOne};
fake_tensor = std::make_shared<tensor::Tensor>(type_id, fake_shape);
} else {
auto shape = origin_abstract->shape();
MS_EXCEPTION_IF_NULL(shape);
fake_tensor = std::make_shared<tensor::Tensor>(type_id, shape->shape());
fake_tensor->set_base_shape(shape->Clone());
}
} else {
fake_tensor = std::make_shared<tensor::Tensor>(1.0);
MS_EXCEPTION_IF_NULL(fake_tensor);
}
MS_EXCEPTION_IF_NULL(fake_tensor);
auto fake_value = NewValueNode(fake_tensor);
MS_EXCEPTION_IF_NULL(fake_value);
fake_value->set_abstract(fake_tensor->ToAbstract());
@ -249,8 +260,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
if (src_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, src_node->cast<CNodePtr>()) &&
common::AnfAlgo::HasNodeAttr(kAttrParameterInputIndex, src_node->cast<CNodePtr>())) {
int64_t parameter_index = common::AnfAlgo::GetNodeAttr<int64_t>(src_node, kAttrParameterInputIndex);
auto kernel_with_index =
common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), 0);
auto kernel_with_index = common::AnfAlgo::VisitKernel(
common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), kIndex0);
auto param_node = kernel_with_index.first;
recv_inputs.push_back(param_node);
@ -264,7 +275,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
recv_node_abs = param_node->abstract();
} else {
auto mock_value = CreateFakeValueNode(true, src_node);
// Use the same shape as the origin node's.
auto mock_value = CreateFakeValueNode(true, src_node, false);
MS_EXCEPTION_IF_NULL(mock_value);
recv_inputs.push_back(mock_value);
recv_node_abs = src_node->abstract();
@ -320,6 +332,86 @@ bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &inp
return std::count(all_inputs.begin(), all_inputs.end(), input) != 0;
}
distributed::DistExecutionMode GenerateStrategy() {
distributed::DistExecutionMode strategy;
bool enable_ps = false;
bool enable_embedding_cache = false;
#ifdef WITH_BACKEND
enable_ps = ps::PSContext::instance()->is_ps_mode();
enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
bool using_parallel = (parallel_mode != parallel::kStandalone);
// The conditions' priority is: EmbeddingCache > Parameter Server > Parallel > General.
if (enable_embedding_cache) {
strategy = distributed::DistExecutionMode::kEmbeddingCacheMode;
} else if (enable_ps) {
strategy = distributed::DistExecutionMode::kPSMode;
} else if (using_parallel) {
strategy = distributed::DistExecutionMode::kParallelMode;
} else {
strategy = distributed::DistExecutionMode::kGeneralMode;
}
return strategy;
}
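As a rough illustration of the priority above, here is a minimal Python sketch of the same decision order; the boolean flags are illustrative stand-ins for the ps::PSContext and parallel::ParallelContext queries, not a real API.
def generate_strategy(enable_embedding_cache, enable_ps, parallel_mode):
    # Mirrors GenerateStrategy(): EmbeddingCache > Parameter Server > Parallel > General.
    if enable_embedding_cache:
        return "kEmbeddingCacheMode"
    if enable_ps:
        return "kPSMode"
    if parallel_mode != "stand_alone":
        return "kParallelMode"
    return "kGeneralMode"

assert generate_strategy(False, False, "stand_alone") == "kGeneralMode"
assert generate_strategy(False, False, "data_parallel") == "kParallelMode"
assert generate_strategy(False, True, "data_parallel") == "kPSMode"
assert generate_strategy(True, True, "data_parallel") == "kEmbeddingCacheMode"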
void TransformPrimAttrToAttr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
MS_EXCEPTION_IF_NULL(prim);
if (cnode->HasPrimalAttr(distributed::kOpLabelRankId)) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'rank_id'.";
prim->set_attr(distributed::kOpLabelRankId, cnode->GetPrimalAttr(distributed::kOpLabelRankId));
}
if (cnode->HasPrimalAttr(distributed::kOpLabelRole)) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'ms_role'.";
prim->set_attr(distributed::kOpLabelRole, cnode->GetPrimalAttr(distributed::kOpLabelRole));
}
}
bool NodeHasLabel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
bool has_label = false;
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim_node = cnode->input(0);
MS_EXCEPTION_IF_NULL(prim_node);
// As long as the node has both 'ms_role' and 'rank_id' attributes, we consider it labeled, regardless of the values
// of these two attributes.
if (IsValueNode<Primitive>(prim_node)) {
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
has_label = true;
}
} else {
// Get the label for a call node. A 'call' node has no primitive to store attrs, so read them from the cnode itself.
if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
has_label = true;
}
}
return has_label;
}
bool GraphHasLabel(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph->get_return());
// If any node has a label, this graph has a label and thus needs to be split.
for (const auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (NodeHasLabel(node)) {
return true;
}
}
return false;
}
void ParameterServerMode::PreBuildDistributedGraph() {
MS_LOG(INFO) << "Start pre-building distributed graph in Parameter Server mode.";
MS_EXCEPTION_IF_NULL(node_labels_);
@ -773,6 +865,8 @@ FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
std::get<2>(node_pair), std::get<3>(node_pair));
std::vector<FusedInterProcessOpPair> pair_list = {fused_op_pair};
results.insert(std::make_pair(edge_with_index, pair_list));
}
}
return results;
@ -896,12 +990,9 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
this_process_label_({rank_id, role}),
node_labels_{},
need_fuse_rpc_nodes_(true) {
bool enable_embedding_cache = false;
#ifdef WITH_BACKEND
enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
mode_ = enable_embedding_cache ? distributed::DistExecutionMode::kEmbeddingCacheMode
: distributed::DistExecutionMode::kPSMode;
// The distributed strategy is not explicitly defined by the user. The distributed module generates the strategy and
// the default label according to flags set by other modules.
mode_ = GenerateStrategy();
default_label_ = {0, distributed::kEnvRoleOfWorker};
}
@ -1044,7 +1135,7 @@ void GraphSplitter::DyeGraph() {
}
// If the node's label is the same as this process's, set its label to this_process_label_.
if (this_process_label_.LooseEqual(node_labels_[node])) {
if (this_process_label_.LooseEqual(node_labels_[node], mode_)) {
node_labels_[node] = this_process_label_;
}
});
@ -1059,6 +1150,8 @@ void GraphSplitter::CreateExecutionMode() {
exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
}
MS_EXCEPTION_IF_NULL(exec_mode_);
}
@ -1170,8 +1263,10 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
}
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim_node = cnode->input(0);
if (IsValueNode<Primitive>(prim_node)) {
TransformPrimAttrToAttr(cnode);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {

View File

@ -53,7 +53,7 @@ struct OperatorLabel {
// Judge whether the labels are equal under looser, mode-dependent conditions. For example, this method returns
// true when comparing two workers in PS mode.
bool LooseEqual(const OperatorLabel &label) const;
bool LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const;
std::string to_string() const;
};
@ -79,8 +79,16 @@ inline bool MatchLabelForPSMode(const OperatorLabel &label1, const OperatorLabel
}
return false;
}
inline bool MatchLabelForParallelMode(const OperatorLabel &label1, const OperatorLabel &label2) {
// When parallel mode is enabled through the MindSpore cluster, processes with the same role have the same label
// regardless of their rank id.
return (label1.ms_role == label2.ms_role);
}
const std::map<distributed::DistExecutionMode, LabelMatchingFunc> kLabelMatchingFuncMap = {
{distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode}};
{distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode},
{distributed::DistExecutionMode::kEmbeddingCacheMode, MatchLabelForPSMode},
{distributed::DistExecutionMode::kParallelMode, MatchLabelForParallelMode}};
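A minimal Python sketch of the loose-matching semantics registered above, with labels modeled as (rank_id, ms_role) tuples; the role strings are assumptions, only the parallel-mode rule is shown, and strict comparison is the fallback when a mode has no matching function.
def match_label_for_parallel_mode(label1, label2):
    # Same role means the same label, regardless of rank id.
    return label1[1] == label2[1]

def loose_equal(label1, label2, matching_func=None):
    # No registered matching function for the mode: fall back to strict equality.
    if matching_func is None:
        return label1 == label2
    return matching_func(label1, label2)

worker0, worker1 = (0, "MS_WORKER"), (1, "MS_WORKER")
print(loose_equal(worker0, worker1))                                 # False, e.g. kGeneralMode
print(loose_equal(worker0, worker1, match_label_for_parallel_mode))  # True, kParallelMode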
// Split graph segment which is generated according to the topo sort of the graph.
struct SplitGraphSegment {
@ -181,7 +189,8 @@ constexpr char kVirtualNode[] = "VirtualNode";
// This method creates a fake tensor. Its type is the same as the origin_node's output if use_origin_node is set to
// true. If use_fake_shape is true, a fake shape {1} is used; otherwise the origin node's shape is kept.
// Normally it is used to connect the edges for send/recv nodes.
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr,
bool use_fake_shape = true);
// Create a TupleGetItem node from a node with tuple output.
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
@ -212,6 +221,33 @@ std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segm
bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input);
/**
* @description: Generate the distributed strategy according to user configuration.
* @return {distributed::DistExecutionMode}: The distributed strategy enum.
*/
distributed::DistExecutionMode GenerateStrategy();
/**
* @description: Transform primal attributes of cnode to normal attributes.
* @param {CNodePtr} &cnode: The cnode which has the primal attributes.
* @return {void}
*/
void TransformPrimAttrToAttr(const CNodePtr &cnode);
/**
* @description: Judge whether this node has a label.
* @param {AnfNodePtr} &node: AnfNode in a func_graph.
* @return {bool}: Whether this node has a label.
*/
bool NodeHasLabel(const AnfNodePtr &node);
/**
* @description: Judge whether this graph has any label.
* @param {FuncGraphPtr} &func_graph: The func_graph.
* @return {bool}: Whether this graph has a label.
*/
bool GraphHasLabel(const FuncGraphPtr &func_graph);
// Base class for different execution modes. It builds distributed graphs, optimizes execution performance, etc.
class DistributedExecutionMode {
public:
@ -334,6 +370,16 @@ class EmbeddingCacheMode : public DistributedExecutionMode {
OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
};
// Users may want to simply split a training graph across multiple devices without any extra features. GeneralMode is
// for this scenario.
class GeneralMode : public DistributedExecutionMode {
public:
explicit GeneralMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
const std::string &role)
: DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
~GeneralMode() = default;
};
// The class is used as an action in the pipeline. It processes the graph and assigns its nodes to the processes in
// the cluster.
class GraphSplitter {

View File

@ -237,6 +237,9 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
// Whether this process is in a MindSpore cluster.
static bool is_cluster_initialized = false;
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_abs, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start";
@ -1379,7 +1382,7 @@ static std::vector<ActionItem> CommonPipeline() {
(void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
if (!multi_graphs && pipeline::GetJitLevel() != "O0") {
if (!is_cluster_initialized && !multi_graphs && pipeline::GetJitLevel() != "O0") {
(void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
@ -1420,6 +1423,7 @@ std::vector<ActionItem> GePipeline() {
}
std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
is_cluster_initialized = distributed::cluster::ClusterContext::instance()->initialized();
std::vector<ActionItem> actions;
// If compilation cache is enabled and the cache is read successfully, only do the backend actions.
if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {

View File

@ -55,7 +55,7 @@ void RpcSendKernelMod::Init(const CNodePtr &kernel_node) {
}
std::vector<KernelAttr> RpcSendKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true).AddAllOutInRef(true)};
std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
return support_list;
}

View File

@ -2159,6 +2159,15 @@ class Cell(Cell_):
params.append(param)
return params
def place(self, role, rank_id):
"""
Set the label for all operators in this cell.
This label tells the MindSpore compiler on which process this cell should be launched.
"""
all_ops = self._get_prims_recursively()
for op in all_ops:
op.place(role, rank_id)
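A hedged usage sketch for the new Cell.place API; the network and the "MS_WORKER" role string are assumptions for illustration, not part of this commit.
import mindspore.nn as nn

class Backbone(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 16)

    def construct(self, x):
        return self.dense(x)

net = Backbone()
# Tag every primitive inside this cell so the compiler launches it on worker 0.
net.place("MS_WORKER", 0)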
def _check_compile_dynamic_shape(self, *inputs):
"""
Check if graph has been compiled with dynamic shape.

View File

@ -384,6 +384,14 @@ class Primitive(Primitive_):
self.add_prim_attr("recompute", mode)
return self
def place(self, role, rank_id):
"""
Set the label for this primitive.
This label tells the MindSpore compiler on which process this operator should be launched.
"""
self.add_prim_attr("ms_role", role)
self.add_prim_attr("rank_id", rank_id)
class PrimitiveWithCheck(Primitive):
"""

View File

@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_ps_context(enable_ssl=False)
init()
context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
class Net(nn.Cell):

View File

@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_ps_context(enable_ssl=False)
init()
context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
class Net(nn.Cell):