forked from mindspore-Ecosystem/mindspore

Add general strategy

parent 2695f6c868
commit 0efd0c4b25
@@ -49,7 +49,19 @@ const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpd
constexpr char kFinalizeMuxRecvActor[] = "FINALIZE_MUX_RECV_ACTOR";

// The distributed execution mode enum.
enum class DistExecutionMode { kPSMode = 0, kEmbeddingCacheMode, kInvalidMode };
// For each execution mode, different graph optimizations, splitting strategies, device locations, etc. are applied.
// For details please refer to class DistributedExecutionMode and its subclasses.

// kGeneralMode: Simply split a training graph into multiple devices without any extra features.

// kParallelMode: MindSpore's existing auto-parallel feature is combined with the distributed graph splitting feature.
// This is much more complicated than the other modes. It is always applied in MoE scenarios.

// kPSMode: Applied when running Parameter Server training.

// kEmbeddingCacheMode: Applied when the embedding cache is enabled. Normally used for training models with a large
// embedding layer.
enum class DistExecutionMode { kGeneralMode = 0, kParallelMode, kPSMode, kEmbeddingCacheMode, kInvalidMode };

// The operator's label in distributed execution.
constexpr char kOpLabelRankId[] = "rank_id";
@@ -27,6 +27,7 @@
#include "mindspore/core/utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/debug/draw.h"
#include "include/common/utils/parallel_context.h"
#ifdef WITH_BACKEND
#include "ps/ps_context.h"
#endif
@@ -39,18 +40,17 @@ bool OperatorLabel::operator==(const OperatorLabel &label) const { return to_str

bool OperatorLabel::operator!=(const OperatorLabel &label) const { return !(*this == label); }

bool OperatorLabel::LooseEqual(const OperatorLabel &label) const {
  auto mode = distributed::DistExecutionMode::kPSMode;
bool OperatorLabel::LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const {
  if (kLabelMatchingFuncMap.count(mode) == 0) {
    MS_LOG(ERROR) << "The mode " << mode << " is invalid.";
    return false;
    MS_LOG(DEBUG) << "The mode " << mode << " does not need LooseEqual.";
    return to_string() == label.to_string();
  }
  return kLabelMatchingFuncMap.at(mode)(label, *this);
}

std::string OperatorLabel::to_string() const { return std::to_string(rank_id) + "_" + ms_role; }

ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node) {
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node, bool use_fake_shape) {
  tensor::TensorPtr fake_tensor = nullptr;
  if (use_origin_node) {
    MS_EXCEPTION_IF_NULL(origin_node);
@@ -63,15 +63,26 @@ ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_
      origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
    }
    MS_EXCEPTION_IF_NULL(origin_abstract);
    fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
                                                   origin_abstract->shape()->shape());
    MS_EXCEPTION_IF_NULL(fake_tensor);
    fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
    auto element = origin_abstract->element();
    MS_EXCEPTION_IF_NULL(element);
    auto build_type = element->BuildType();
    MS_EXCEPTION_IF_NULL(build_type);
    auto type_id = build_type->type_id();
    if (use_fake_shape) {
      // Assign send's output shape as {1};
      ShapeVector fake_shape = {kSizeOne};
      fake_tensor = std::make_shared<tensor::Tensor>(type_id, fake_shape);
    } else {
      auto shape = origin_abstract->shape();
      MS_EXCEPTION_IF_NULL(shape);
      fake_tensor = std::make_shared<tensor::Tensor>(type_id, shape->shape());
      fake_tensor->set_base_shape(shape->Clone());
    }
  } else {
    fake_tensor = std::make_shared<tensor::Tensor>(1.0);
    MS_EXCEPTION_IF_NULL(fake_tensor);
  }

  MS_EXCEPTION_IF_NULL(fake_tensor);
  auto fake_value = NewValueNode(fake_tensor);
  MS_EXCEPTION_IF_NULL(fake_value);
  fake_value->set_abstract(fake_tensor->ToAbstract());
@@ -249,8 +260,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge
  if (src_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, src_node->cast<CNodePtr>()) &&
      common::AnfAlgo::HasNodeAttr(kAttrParameterInputIndex, src_node->cast<CNodePtr>())) {
    int64_t parameter_index = common::AnfAlgo::GetNodeAttr<int64_t>(src_node, kAttrParameterInputIndex);
    auto kernel_with_index =
      common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), 0);
    auto kernel_with_index = common::AnfAlgo::VisitKernel(
      common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), kIndex0);
    auto param_node = kernel_with_index.first;
    recv_inputs.push_back(param_node);
@@ -264,7 +275,8 @@ CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge

    recv_node_abs = param_node->abstract();
  } else {
    auto mock_value = CreateFakeValueNode(true, src_node);
    // Use the same shape as origin node's.
    auto mock_value = CreateFakeValueNode(true, src_node, false);
    MS_EXCEPTION_IF_NULL(mock_value);
    recv_inputs.push_back(mock_value);
    recv_node_abs = src_node->abstract();
@@ -320,6 +332,86 @@ bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &inp
  return std::count(all_inputs.begin(), all_inputs.end(), input) != 0;
}

distributed::DistExecutionMode GenerateStrategy() {
  distributed::DistExecutionMode strategy;
  bool enable_ps = false;
  bool enable_embedding_cache = false;
#ifdef WITH_BACKEND
  enable_ps = ps::PSContext::instance()->is_ps_mode();
  enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  bool using_parallel = (parallel_mode != parallel::kStandalone) ? true : false;
  // The conditions' priority is: EmbeddingCache > Parameter Server > Parallel > General.
  if (enable_embedding_cache) {
    strategy = distributed::DistExecutionMode::kEmbeddingCacheMode;
  } else if (enable_ps) {
    strategy = distributed::DistExecutionMode::kPSMode;
  } else if (using_parallel) {
    strategy = distributed::DistExecutionMode::kParallelMode;
  } else {
    strategy = distributed::DistExecutionMode::kGeneralMode;
  }
  return strategy;
}

void TransformPrimAttrToAttr(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
  MS_EXCEPTION_IF_NULL(prim);
  if (cnode->HasPrimalAttr(distributed::kOpLabelRankId)) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'rank_id'.";
    prim->set_attr(distributed::kOpLabelRankId, cnode->GetPrimalAttr(distributed::kOpLabelRankId));
  }
  if (cnode->HasPrimalAttr(distributed::kOpLabelRole)) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'ms_role'.";
    prim->set_attr(distributed::kOpLabelRole, cnode->GetPrimalAttr(distributed::kOpLabelRole));
  }
}

bool NodeHasLabel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }

  bool has_label = false;
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim_node = cnode->input(0);
  MS_EXCEPTION_IF_NULL(prim_node);

  // As long as the node has both the 'ms_role' and 'rank_id' attributes, we consider it labeled, regardless of the
  // values of these two attributes.
  if (IsValueNode<Primitive>(prim_node)) {
    auto prim = GetValueNode<PrimitivePtr>(prim_node);
    MS_EXCEPTION_IF_NULL(prim);
    if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
      has_label = true;
    }
  } else {
    // Get the label for a 'call' node: a 'call' node has no primitive to store attributes, so read them from the cnode.
    if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
      has_label = true;
    }
  }
  return has_label;
}

bool GraphHasLabel(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph->get_return());
  // If any node has a label, this graph has a label and thus needs to be split.
  for (const auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (NodeHasLabel(node)) {
      return true;
    }
  }
  return false;
}

void ParameterServerMode::PreBuildDistributedGraph() {
  MS_LOG(INFO) << "Start pre-building distributed graph in Parameter Server mode.";
  MS_EXCEPTION_IF_NULL(node_labels_);
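To make the selection logic above easier to follow, here is a small illustrative re-expression of GenerateStrategy's priority in Python. This is a sketch, not MindSpore code; the "stand_alone" string is an assumption about the value behind parallel::kStandalone.

# Illustrative sketch of the priority implemented by GenerateStrategy() above:
# EmbeddingCache > Parameter Server > Parallel > General.
def generate_strategy(enable_embedding_cache, enable_ps, parallel_mode):
    using_parallel = parallel_mode != "stand_alone"  # assumed default parallel mode string
    if enable_embedding_cache:
        return "kEmbeddingCacheMode"
    if enable_ps:
        return "kPSMode"
    if using_parallel:
        return "kParallelMode"
    return "kGeneralMode"

assert generate_strategy(False, False, "stand_alone") == "kGeneralMode"
assert generate_strategy(False, False, "data_parallel") == "kParallelMode"
assert generate_strategy(False, True, "data_parallel") == "kPSMode"             # PS takes priority over parallel
assert generate_strategy(True, True, "data_parallel") == "kEmbeddingCacheMode"  # embedding cache wins over all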
@@ -773,6 +865,8 @@ FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
      InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
      FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
                                                              std::get<2>(node_pair), std::get<3>(node_pair));
      std::vector<FusedInterProcessOpPair> pair_list = {fused_op_pair};
      results.insert(std::make_pair(edge_with_index, pair_list));
    }
  }
  return results;
@@ -896,12 +990,9 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
      this_process_label_({rank_id, role}),
      node_labels_{},
      need_fuse_rpc_nodes_(true) {
  bool enable_embedding_cache = false;
#ifdef WITH_BACKEND
  enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
  mode_ = enable_embedding_cache ? distributed::DistExecutionMode::kEmbeddingCacheMode
                                 : distributed::DistExecutionMode::kPSMode;
  // The distributed strategy is not explicitly defined by the user. The distributed module generates the strategy and
  // the default label according to flags set by other modules.
  mode_ = GenerateStrategy();
  default_label_ = {0, distributed::kEnvRoleOfWorker};
}
@@ -1044,7 +1135,7 @@ void GraphSplitter::DyeGraph() {
    }

    // If the node's label is the same as this process's, set its label to this_process_label_.
    if (this_process_label_.LooseEqual(node_labels_[node])) {
    if (this_process_label_.LooseEqual(node_labels_[node], mode_)) {
      node_labels_[node] = this_process_label_;
    }
  });
@@ -1059,6 +1150,8 @@ void GraphSplitter::CreateExecutionMode() {
    exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
  } else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
    exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
  } else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
    exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
  }
  MS_EXCEPTION_IF_NULL(exec_mode_);
}
@@ -1170,8 +1263,10 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
    MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim_node = cnode->input(0);
  if (IsValueNode<Primitive>(prim_node)) {
    TransformPrimAttrToAttr(cnode);
    auto prim = GetValueNode<PrimitivePtr>(prim_node);
    MS_EXCEPTION_IF_NULL(prim);
    if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
@@ -53,7 +53,7 @@ struct OperatorLabel {

  // Judge whether the labels are equal but with looser conditions according to different modes. For example, this
  // method returns true when comparing the workers in PS mode.
  bool LooseEqual(const OperatorLabel &label) const;
  bool LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const;

  std::string to_string() const;
};
@@ -79,8 +79,16 @@ inline bool MatchLabelForPSMode(const OperatorLabel &label1, const OperatorLabel
  }
  return false;
}
inline bool MatchLabelForParallelMode(const OperatorLabel &label1, const OperatorLabel &label2) {
  // When parallel mode is enabled by using the MindSpore cluster, processes with the same role have the same label
  // regardless of their rank id.
  return (label1.ms_role == label2.ms_role);
}

const std::map<distributed::DistExecutionMode, LabelMatchingFunc> kLabelMatchingFuncMap = {
  {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode}};
  {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode},
  {distributed::DistExecutionMode::kEmbeddingCacheMode, MatchLabelForPSMode},
  {distributed::DistExecutionMode::kParallelMode, MatchLabelForParallelMode}};

// Split graph segment which is generated according to the topo sort of the graph.
struct SplitGraphSegment {
@@ -181,7 +189,8 @@ constexpr char kVirtualNode[] = "VirtualNode";
// This method creates a fake tensor. Its type is the same as the origin_node's output if use_origin_node is set to
// true.
// Normally it is used to connect the edges for send/recv nodes.
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr);
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr,
                                 bool use_fake_shape = true);

// Create a TupleGetItem node from a node with tuple output.
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
@@ -212,6 +221,33 @@ std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segm

bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input);

/**
 * @description: Generate the distributed strategy according to user configuration.
 * @return {distributed::DistExecutionMode}: The distributed strategy enum.
 */
distributed::DistExecutionMode GenerateStrategy();

/**
 * @description: Transform primal attributes of the cnode to normal attributes.
 * @param {CNodePtr} &cnode: The cnode which has the primal attributes.
 * @return {void}
 */
void TransformPrimAttrToAttr(const CNodePtr &cnode);

/**
 * @description: Judge whether this node has a label.
 * @param {AnfNodePtr} &node: AnfNode in a func_graph.
 * @return {bool}: Whether this node has a label.
 */
bool NodeHasLabel(const AnfNodePtr &node);

/**
 * @description: Judge whether this graph has any label.
 * @param {FuncGraphPtr} &func_graph: The func_graph.
 * @return {bool}: Whether this graph has any label.
 */
bool GraphHasLabel(const FuncGraphPtr &func_graph);

// Base class for different execution modes. It builds distributed graphs, optimizes execution performance, etc.
class DistributedExecutionMode {
 public:
@@ -334,6 +370,16 @@ class EmbeddingCacheMode : public DistributedExecutionMode {
  OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
};

// Users may want to simply split a training graph into multiple devices without any extra features. GeneralMode is
// for this scenario.
class GeneralMode : public DistributedExecutionMode {
 public:
  explicit GeneralMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                       const std::string &role)
      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
  ~GeneralMode() = default;
};

// The class is used as an action in the pipeline. It processes the graph and splits the nodes to each process in the
// cluster.
class GraphSplitter {
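As a rough user-side illustration of the GeneralMode scenario described above, the sketch below places two sub-cells on different worker processes via the Cell.place API added later in this commit. The network, the role string "MS_WORKER", and the rank ids are assumptions for illustration; only place(role, rank_id) itself comes from this diff.

# Hedged sketch: general-mode graph splitting driven purely by operator labels.
# Assumes a MindSpore cluster with two worker processes (role "MS_WORKER", ranks 0 and 1).
import mindspore.nn as nn

class SplitNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Dense(16, 16)
        self.dense2 = nn.Dense(16, 4)
        # Label every operator of each sub-cell with a target process. The graph
        # splitter dyes the graph with these labels and inserts send/recv nodes
        # on edges that cross process boundaries.
        self.dense1.place("MS_WORKER", 0)
        self.dense2.place("MS_WORKER", 1)

    def construct(self, x):
        return self.dense2(self.dense1(x))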
@@ -237,6 +237,9 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;

// Whether this process is in a MindSpore cluster.
static bool is_cluster_initialized = false;

abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_abs, bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
@@ -1379,7 +1382,7 @@ static std::vector<ActionItem> CommonPipeline() {
  (void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));

  auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
  if (!multi_graphs && pipeline::GetJitLevel() != "O0") {
  if (!is_cluster_initialized && !multi_graphs && pipeline::GetJitLevel() != "O0") {
    (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
  }
@@ -1420,6 +1423,7 @@ std::vector<ActionItem> GePipeline() {
}

std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
  is_cluster_initialized = distributed::cluster::ClusterContext::instance()->initialized();
  std::vector<ActionItem> actions;
  // If compilation cache is enabled and the cache is read successfully, only do the backend actions.
  if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {
@@ -55,7 +55,7 @@ void RpcSendKernelMod::Init(const CNodePtr &kernel_node) {
}

std::vector<KernelAttr> RpcSendKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true).AddAllOutInRef(true)};
  std::vector<KernelAttr> support_list = {KernelAttr().AddSkipCheckAttr(true)};
  return support_list;
}
@@ -2159,6 +2159,15 @@ class Cell(Cell_):
            params.append(param)
        return params

    def place(self, role, rank_id):
        """
        Set the label for all operators in this cell.
        This label tells the MindSpore compiler on which process this cell should be launched.
        """
        all_ops = self._get_prims_recursively()
        for op in all_ops:
            op.place(role, rank_id)

    def _check_compile_dynamic_shape(self, *inputs):
        """
        Check if the graph has been compiled with dynamic shape.
@@ -384,6 +384,14 @@ class Primitive(Primitive_):
        self.add_prim_attr("recompute", mode)
        return self

    def place(self, role, rank_id):
        """
        Set the label for this primitive.
        This label tells the MindSpore compiler on which process this operator should be launched.
        """
        self.add_prim_attr("ms_role", role)
        self.add_prim_attr("rank_id", rank_id)


class PrimitiveWithCheck(Primitive):
    """
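A minimal sketch of what place() records on a primitive: the two attributes it sets are exactly the ones NodeHasLabel in the graph splitter checks for. The operator, role string, and attribute inspection below are illustrative assumptions, not taken from this commit.

# Hedged sketch: place() stores the label as two ordinary primitive attributes,
# which the C++ splitter later reads ('ms_role' and 'rank_id').
import mindspore.ops as ops

matmul = ops.MatMul()
matmul.place("MS_WORKER", 1)
print(matmul.attrs.get("ms_role"), matmul.attrs.get("rank_id"))  # expected: MS_WORKER 1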
@@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_ps_context(enable_ssl=False)
init()
context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())


class Net(nn.Cell):
@@ -26,6 +26,7 @@ from mindspore.communication.management import init, get_group_size
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_ps_context(enable_ssl=False)
init()
context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())


class Net(nn.Cell):