!30469 add shard function to support part of the graph executed in auto_parallel under pynative mode

Merge pull request !30469 from wangjun/0223_pp
This commit is contained in:
i-robot 2022-02-25 06:52:24 +00:00 committed by Gitee
commit 0341d96dd6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
26 changed files with 650 additions and 49 deletions

View File

@ -341,6 +341,14 @@
.. note:: 仅在全自动并行(AUTO_PARALLEL)模式下生效。
.. py:method:: shard(in_axes, out_axes, device="Ascend", level=0)
指定输入/输出tensor的分布策略，其余算子的策略推导得到。在PyNative模式下，可以利用此方法指定某个cell以图模式进行分布式执行。
in_axes/out_axes需要为元组类型，其中的每一个元素指定对应的输入/输出的tensor分布策略，其类型需要为元组，
可参考：`mindspore.ops.Primitive.shard` 的描述；也可以设置为None，会默认以数据并行执行。
.. note:: 需设置为PyNative模式，并且为全自动并行(AUTO_PARALLEL)同时search mode为sharding_propagation，或半自动并行(SEMI_AUTO_PARALLEL)。
.. py:method:: set_grad(requires_grad=True)
Cell的梯度设置。在PyNative模式下该参数指定Cell是否需要梯度。如果为True则在执行正向网络时将生成需要计算梯度的反向网络。

View File

@ -19,6 +19,7 @@
#include <vector>
#include <map>
#include "frontend/parallel/context.h"
#include "backend/graph_compiler/transform.h"
#include "backend/common/session/session_factory.h"
#include "runtime/op_builder/op_lazy_builder.h"
@ -451,10 +452,12 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(context_ptr);
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
real_execution_mode_ = ms_execution_mode_;
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
auto is_parallel = (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL);
// Run in GRAPH_MODE if the func_graph is ms_function or the func_graph contain multi-subgraph.
if (ms_execution_mode_ == kPynativeMode &&
(!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1)) {
(!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) && !is_parallel) {
real_execution_mode_ = kGraphMode;
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
pipeline::SetRunMode(func_graph, this);
@ -891,7 +894,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
}
}

View File

@ -126,7 +126,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunOutputReplace(const CNodePtr &forwa
const FuncGraphPtr &fprop_graph,
const CNodePtr &cnode_morph) {
MS_EXCEPTION_IF_NULL(cnode_morph);
if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient)) {
if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode_morph, prim::kPrimMirror)) {
return {};
}
// Use manager to get the link relation among nodes.
@ -176,7 +176,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
MS_EXCEPTION_IF_NULL(input_node);
// Parameter, ValueNode and StopGradient CNode no need to replace.
if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
IsPrimitiveCNode(input_node, prim::kPrimStopGradient)) {
IsPrimitiveCNode(input_node, prim::kPrimStopGradient) || IsPrimitiveCNode(input_node, prim::kPrimMirror)) {
continue;
}
// Replace forward input node by its output value.
@ -187,6 +187,9 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
MS_EXCEPTION_IF_NULL(output_vnode_i);
output_vnode_i->set_has_new_value(true);
manager->Replace(paras[i], output_vnode_i);
if (IsPrimitiveCNode(cnode_i, prim::kPrimLoad)) {
para_ref_size += 1;
}
MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
// Save forward input node when it used in bprop graph.
if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {

View File

@ -61,10 +61,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ = MakeSubstitution(
std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimCellBackwardHook,
prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
special_op_eliminate_ =
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimCellBackwardHook, prim::kPrimPrintShapeType, prim::kPrimGetRefValue});
ad_related_special_op_eliminate_ =
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "ad_related_special_op_eliminate",
{prim::kPrimMirror, prim::kPrimVirtualDiv});
pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
zero_like_fill_zero_ =
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);

View File

@ -35,6 +35,7 @@ class OptimizeIRPassLib {
SubstitutionPtr arithmetic_simplify_;
SubstitutionPtr arithmetic_simplify2_;
SubstitutionPtr special_op_eliminate_;
SubstitutionPtr ad_related_special_op_eliminate_;
SubstitutionPtr zero_like_fill_zero_;
SubstitutionPtr adjust_all_reduce_mul_add_;
SubstitutionPtr float_depend_g_call_;

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/irpass/shard_eliminate.h"
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Builds the replacement for a Shard cnode: a fresh value node that wraps the
// func graph held by the cnode's input 1. Raises if input 1 does not carry a
// func graph.
AnfNodePtr ExpandShard(const CNodePtr &node) {
  const auto &graph_input = node->input(1);
  auto graph_vnode = graph_input->cast<ValueNodePtr>();
  auto shard_graph = GetValueNode<FuncGraphPtr>(graph_vnode);
  MS_EXCEPTION_IF_NULL(shard_graph);
  return NewValueNode(shard_graph);
}
} // namespace internal
// Collects every Shard cnode reachable from `func_graph` and replaces each of
// them, through the optimizer's manager, with a value node of the wrapped graph.
// Returns true when at least one node was replaced.
bool ExpandShardPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  GetShardPrim(func_graph);
  auto graph_manager = optimizer->manager();
  bool replaced_any = false;
  for (auto it = shard_nodes_.begin(); it != shard_nodes_.end(); ++it) {
    auto replacement = internal::ExpandShard(*it);
    graph_manager->Replace(*it, replacement);
    replaced_any = true;
  }
  return replaced_any;
}
void ExpandShardPrim::GetShardPrim(const FuncGraphPtr &func_graph) {
shard_nodes_.clear();
AnfNodePtr ret = func_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimShard)) {
shard_nodes_.push_back(node->cast<CNodePtr>());
}
}
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// Optimizer pass functor that expands Shard primitive call sites.
// Calling the object on a graph gathers the Shard cnodes and asks the
// optimizer's manager to replace each of them.
class ExpandShardPrim {
public:
ExpandShardPrim() = default;
virtual ~ExpandShardPrim() = default;
// Runs the pass on `func_graph`; returns whether any node was replaced.
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
// Rebuilds the internal list of Shard cnodes reachable from `func_graph`.
void GetShardPrim(const FuncGraphPtr &func_graph);
private:
// Shard cnodes found by the most recent GetShardPrim() call.
std::vector<CNodePtr> shard_nodes_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_

View File

@ -347,6 +347,7 @@ constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax";
constexpr char SIMPLE_MEAN[] = "SimpleMean";
constexpr char FLATTEN[] = "Flatten";
constexpr char J[] = "J";
constexpr char SHARD[] = "Shard";
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
constexpr char COS[] = "Cos";
constexpr char ACOS[] = "ACos";

View File

@ -935,10 +935,30 @@ void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
}
}
// Returns the func graph wrapped by a Shard cnode found in `all_nodes`, or `root`
// when no usable Shard cnode exists. If several Shard cnodes are present, the
// last one visited wins (same as before).
//
// A Shard cnode is only honoured when it actually has an operand at input 1 and
// that operand is a func graph value node; otherwise it is skipped so the
// function never hands a null graph to its caller.
FuncGraphPtr PynativeParallelGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  FuncGraphPtr real_graph = root;
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    // Need at least the primitive (input 0) and the graph operand (input 1).
    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_shard_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (expect_shard_prim->name() != SHARD) {
      continue;
    }
    auto shard_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
    if (shard_graph == nullptr) {
      // Operand is not a func graph value node; keep the current choice.
      continue;
    }
    real_graph = shard_graph;
  }
  return real_graph;
}
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
vector<std::string> last_forward_node_ids;
vector<size_t> last_indexs;
FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
auto real_graph = PynativeParallelGraph(root, all_nodes);
FindLastNodesUniqueId(real_graph, &last_forward_node_ids, &last_indexs);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
for (auto &node : all_nodes) {
// here insert virtualoutput node
@ -2189,9 +2209,9 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
auto tuple_index = GetTupleGetItemIndex(cnode);
auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
if (!layout_ptr) {
MS_LOG(EXCEPTION)
<< " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a parallel care node "
"before tuple_getitem!";
MS_LOG(EXCEPTION) << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a "
"parallel care node "
"before tuple_getitem!";
}
return layout_ptr;
}
@ -2485,8 +2505,8 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (expect_j_prim->name() != J) {
auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (expect_prim->name() != J && expect_prim->name() != SHARD) {
continue;
}
if (IsValueNode<FuncGraph>(cnode->input(1))) {
@ -2513,6 +2533,12 @@ void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
}
}
// True when the execution mode is PyNative and the parallel mode is either
// semi-auto-parallel or auto-parallel.
bool IsPynativeParallel() {
  const auto mode = ParallelContext::GetInstance()->parallel_mode();
  const auto exec_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (exec_mode != kPynativeMode) {
    return false;
  }
  return mode == SEMI_AUTO_PARALLEL || mode == AUTO_PARALLEL;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
@ -2607,7 +2633,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
}
// insert backward ops
if (has_backward) {
if (has_backward || IsPynativeParallel()) {
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
}
@ -3075,8 +3101,9 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
"the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
" to configure the input strategy.";
}
return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
current_stage == split_stage_num - 1);
return ((!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
current_stage == split_stage_num - 1) ||
IsPynativeParallel());
}
static void HandleGroupInfo(const FuncGraphPtr &root) {

View File

@ -737,6 +737,17 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
return true;
}
bool EliminateAdRelatedSpecialOpNode(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
}
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
}
return EliminateAdRelatedSpecialOpOptPass(res);
}
bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
for (const auto &node : all_nodes) {
if (!node->isa<CNode>()) {
@ -1395,6 +1406,9 @@ std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
// eliminate forward cnode for grad graph
(void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
// eliminate the virtual mirror node
(void)actions.emplace_back(std::make_pair("eliminate_ad_related_special_op_node", EliminateAdRelatedSpecialOpNode));
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
}

View File

@ -49,6 +49,8 @@
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/meta_fg_eliminate.h"
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/shard_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
@ -188,6 +190,7 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
irpass.reshape_eliminate_,
irpass.switch_simplify_,
irpass.addn_zero_filter_,
irpass.ad_related_special_op_eliminate_,
});
opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
OptPassGroupMap map({
@ -366,6 +369,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
{"virtual_dataset", virtual_dataset},
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
{"shard", opt::OptPassConfig(opt::irpass::ExpandShardPrim())},
{"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())},
{"after_resolve", after_resolve_pass},
{"a_after_grad", a_after_grad},
@ -723,6 +727,21 @@ bool PynativeOptPass(const ResourcePtr &res) {
return true;
}
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &res) {
auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
opt::irpass::OptimizeIRPassLib irpass;
opt::OptPassConfig ad_related_special_op_eliminate = opt::OptPassConfig({
irpass.ad_related_special_op_eliminate_,
});
OptPassGroupMap map({
{"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
});
auto ad_related_special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("ad_related_special_op_eliminate", res, map);
(void)ad_related_special_op_eliminate_opt->step(func_graph, false);
return true;
}
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(func_graph->manager());

View File

@ -47,6 +47,7 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &res);
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);

View File

@ -137,12 +137,12 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
FuncGraphPtr fun_graph = nullptr;
if (!root->has_flag(mindspore::parallel::TRAINING)) {
graph_sets.insert(root);
}
if (expect_j_prim->name() == mindspore::parallel::J) {
if (expect_prim->name() == mindspore::parallel::J || expect_prim->name() == mindspore::parallel::SHARD) {
if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
} else {
@ -192,6 +192,262 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
}
}
void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<AnfNodePtr> &nodes, const int64_t device_num,
std::vector<std::vector<int64_t>> *default_strategy) {
auto strategies = axes->value()->cast<ValueTuplePtr>()->value();
size_t i = 0;
for (auto &strategy : strategies) {
auto node = nodes[i];
if (strategy->isa<None>()) {
auto node_size = AnfAlgo::GetOutputInferShape(node, 0).size();
std::vector<int64_t> current_d_strategy(node_size, 1);
if (node_size >= 1) {
current_d_strategy[0] = device_num;
}
default_strategy->push_back(current_d_strategy);
} else {
auto current_strategy = GetValue<std::vector<int64_t>>(strategy);
default_strategy->push_back(current_strategy);
}
i += 1;
}
}
// Validates that every entry of the `axes` tuple is either None or a tuple made
// only of Int64 immediates.
//   - `*axes_size` is incremented once for every entry visited (including the
//     one that fails validation).
//   - `*need_default_strategy` is set to true as soon as a None entry is seen.
// Returns false on the first malformed entry, true otherwise.
bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, size_t *axes_size) {
  const auto axis_values = axes->value()->cast<ValueTuplePtr>()->value();
  for (const auto &axis : axis_values) {
    ++(*axes_size);
    if (axis->isa<None>()) {
      *need_default_strategy = true;
    } else if (!axis->isa<ValueTuple>()) {
      return false;
    } else {
      const auto dims = axis->cast<ValueTuplePtr>()->value();
      for (const auto &dim : dims) {
        if (!dim->isa<Int64Imm>()) {
          return false;
        }
      }
    }
  }
  return true;
}
// OneHot needs a layout for its output, so the strategy of its first input is
// extended by one trailing dimension that is not split (value 1).
void HandleStrategyForOneHot(std::vector<ValuePtr> *strategy) {
  auto &first_entry = strategy->at(0);
  auto extended = GetValue<std::vector<int64_t>>(first_entry);
  extended.push_back(1);
  first_entry = MakeValue(extended);
}
// Makes the strategies of the two MatMul/BatchMatMul operands agree on the
// contracted dimension.
//
// The contracted dimension is the last one of the left operand (second to last
// when transpose_a is set) and the second to last one of the right operand
// (last when transpose_b is set). When the two split factors differ, the side
// whose factor is 1 adopts the other side's factor; otherwise the right side is
// overwritten with the left side's factor.
//
// The right operand's index is computed from the right operand's own rank, so
// operands of different rank are handled, and operands with fewer than two
// dimensions are left untouched instead of underflowing the index arithmetic.
void HandleStrategyForMatMul(std::vector<ValuePtr> *strategy, const CNodePtr &cnode) {
  auto left_matrix_strategy = GetValue<std::vector<int64_t>>(strategy->at(0));
  auto right_matrix_strategy = GetValue<std::vector<int64_t>>(strategy->at(1));
  if (left_matrix_strategy.size() < 2 || right_matrix_strategy.size() < 2) {
    // Not a matrix strategy; nothing to align here.
    return;
  }
  auto index_a = left_matrix_strategy.size() - 1;
  auto index_b = right_matrix_strategy.size() - 2;
  auto attrs = GetCNodePrimitive(cnode)->attrs();
  bool transpose_a = attrs[parallel::TRANSPOSE_A]->cast<BoolImmPtr>()->value();
  bool transpose_b = attrs[parallel::TRANSPOSE_B]->cast<BoolImmPtr>()->value();
  if (transpose_a) {
    index_a -= 1;
  }
  if (transpose_b) {
    index_b += 1;
  }
  if (left_matrix_strategy[index_a] != right_matrix_strategy[index_b]) {
    if (left_matrix_strategy[index_a] == 1) {
      left_matrix_strategy[index_a] = right_matrix_strategy[index_b];
    } else {
      right_matrix_strategy[index_b] = left_matrix_strategy[index_a];
    }
    strategy->at(0) = MakeValue(left_matrix_strategy);
    strategy->at(1) = MakeValue(right_matrix_strategy);
  }
}
// Appends to `input_nodes` every parameter of `func_graph` except those named
// "u" or "io" (presumably the monad state parameters -- confirm).
void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *input_nodes) {
  for (const auto &candidate : func_graph->parameters()) {
    const auto param = candidate->cast<ParameterPtr>();
    const bool skipped = (param->name() == "u") || (param->name() == "io");
    if (!skipped) {
      input_nodes->push_back(candidate);
    }
  }
}
// Collects the real output nodes of `func_graph` into `output_nodes`.
// Starting from the operand of the return node, Depend wrappers are peeled off
// (following their first operand). If what remains is a MakeTuple, each of its
// elements is an output; otherwise the node itself is the single output.
void GetOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_nodes) {
  auto output = func_graph->get_return()->cast<CNodePtr>()->input(1);
  while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
    output = output->cast<CNodePtr>()->input(1);
  }
  if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
    auto tuple_cnode = output->cast<CNodePtr>();
    const size_t input_count = tuple_cnode->inputs().size();
    // Input 0 is the MakeTuple primitive itself, elements start at 1.
    for (size_t pos = 1; pos < input_count; ++pos) {
      output_nodes->push_back(tuple_cnode->input(pos));
    }
    return;
  }
  output_nodes->push_back(output);
}
bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t &device_num) {
for (size_t i = 0; i < strategies.size(); ++i) {
auto strategy = strategies[i];
int64_t required_num = 1;
std::for_each(strategy.begin(), strategy.end(), [&](int64_t const &data) { required_num *= data; });
if (required_num > device_num) {
MS_LOG(ERROR) << "required device number: " << required_num
<< " is larger than available device number: " << device_num << " at index: " << i;
return false;
}
if (device_num % required_num != 0) {
MS_LOG(ERROR) << "required device number: " << required_num
<< " is not divisible by device number: " << device_num << " at index: " << i;
return false;
}
}
return true;
}
// Stores the user supplied `out_axes` as the OUT_STRATEGY attribute of the
// primitive that produces each output of `func_graph`.
//
// Raises when:
//   - `out_axes` is not a value node holding a tuple accepted by CheckLayout;
//   - the number of outputs differs from the number of `out_axes` entries;
//   - an output's inferred rank differs from the length of its strategy;
//   - an output node is not produced by a primitive cnode (there is then no
//     primitive to attach the attribute to).
// None entries in `out_axes` are expanded by GenerateDefaultStrategy.
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, const int64_t &device_num) {
  auto out_axes_tuple = out_axes->cast<ValueNodePtr>();
  bool need_default_strategy = false;
  size_t out_axes_size = 0;
  if (!IsValueNode<ValueTuple>(out_axes_tuple) ||
      !CheckLayout(out_axes_tuple, &need_default_strategy, &out_axes_size)) {
    MS_LOG(EXCEPTION) << "out_axes should be a two-dimension tuple";
  }
  std::vector<AnfNodePtr> output_nodes;
  GetOutputNodes(func_graph, &output_nodes);
  if (output_nodes.size() != out_axes_size) {
    MS_LOG(EXCEPTION) << "Output number: " << output_nodes.size()
                      << " is not equal to out_axes number: " << out_axes_size;
  }

  std::vector<std::vector<int64_t>> output_strategy;
  if (need_default_strategy) {
    GenerateDefaultStrategy(out_axes_tuple, output_nodes, device_num, &output_strategy);
  } else {
    output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_axes_tuple->value());
  }
  MS_LOG(WARNING) << "The output strategy will be overwritten as data-parallel";

  for (size_t i = 0; i < output_nodes.size(); ++i) {
    const auto &node = output_nodes[i];
    auto output_shape = AnfAlgo::GetOutputInferShape(node, 0);
    if (output_shape.size() != output_strategy[i].size()) {
      MS_LOG(EXCEPTION) << "Output dimension: " << output_shape.size()
                        << " is not equal to out_axes dimension: " << output_strategy[i].size() << " at index " << i;
    }
    auto prim = GetCNodePrimitive(node);
    if (prim == nullptr) {
      // E.g. the output is a parameter or the result of a sub-graph call.
      MS_LOG(EXCEPTION) << "Output at index " << i << " is not produced by a primitive, out_axes can not be set on it";
    }
    std::vector<ValuePtr> elements;
    elements.push_back(MakeValue(output_strategy[i]));
    ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
    auto attrs_temp = prim->attrs();
    attrs_temp[parallel::OUT_STRATEGY] = strategy;
    (void)prim->SetAttrs(attrs_temp);
  }
}
// Derives an IN_STRATEGY attribute from the user supplied `in_axes` for every
// cnode that directly consumes a (non "u"/"io") parameter of `func_graph`.
//
// Steps:
//   1. Validate `in_axes` (value node holding a tuple accepted by CheckLayout) and
//      check that it has one entry per input returned by GetInputNodes.
//   2. Build the per-input strategies, expanding None entries through
//      GenerateDefaultStrategy, and validate them with CheckDeviceNum.
//   3. Check each parameter's inferred rank against its strategy length and
//      gather the users of the parameters ("concerned" cnodes).
//   4. For each concerned cnode build one strategy element per input and store
//      the resulting tuple on the cnode's primitive.
// Raises on any validation failure.
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, const int64_t &device_num) {
auto in_axes_tuple = in_axes->cast<ValueNodePtr>();
bool need_default_strategy = false;
size_t in_axes_size = 0;
if (!IsValueNode<ValueTuple>(in_axes_tuple) || !CheckLayout(in_axes_tuple, &need_default_strategy, &in_axes_size)) {
MS_LOG(EXCEPTION) << "in_axes should be a two-dimension tuple";
}
std::vector<AnfNodePtr> input_nodes;
GetInputNodes(func_graph, &input_nodes);
if (input_nodes.size() != in_axes_size) {
MS_LOG(EXCEPTION) << "Input numbers: " << input_nodes.size()
<< " is not equal to in_axes numbers: " << in_axes_size;
}
std::vector<std::vector<int64_t>> input_strategy;
if (need_default_strategy) {
GenerateDefaultStrategy(in_axes_tuple, input_nodes, device_num, &input_strategy);
} else {
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_axes_tuple->value());
}
if (!CheckDeviceNum(input_strategy, device_num)) {
MS_LOG(EXCEPTION) << "check device number failed";
}
std::set<CNodePtr> concerned_nodes;
FuncGraphManagerPtr manager = func_graph->manager();
auto parameters = func_graph->parameters();
// NOTE(review): input_strategy has one entry per parameter that is NOT named
// "u"/"io", but below it is indexed with the position in the full parameter
// list. This presumably relies on the skipped parameters coming after all
// regular ones -- confirm against the callers.
for (size_t i = 0; i < parameters.size(); ++i) {
auto parameter = parameters[i];
if (parameter->cast<ParameterPtr>()->name() == "u" || parameter->cast<ParameterPtr>()->name() == "io") {
continue;
}
auto output_shape = AnfAlgo::GetOutputInferShape(parameter, 0);
if (output_shape.size() != input_strategy[i].size()) {
MS_LOG(EXCEPTION) << "Input dimension: " << output_shape.size()
<< " is not equal to in_axes dimension: " << input_strategy[i].size() << " at index " << i;
}
// Every direct user of this parameter needs an input strategy.
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
for (auto &param_pair : param_sub_set) {
CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
concerned_nodes.insert(param_cnode);
}
}
for (auto &cnode : concerned_nodes) {
auto current_inputs = cnode->inputs();
std::vector<ValuePtr> elements;
// Input 0 is the callee, so real operands start at index 1.
for (size_t i = 1; i < current_inputs.size(); ++i) {
auto current_input = current_inputs[i];
// Value node operands only get a strategy element when they hold a tensor.
if (current_input->isa<ValueNode>()) {
auto current_value = current_input->cast<ValueNodePtr>()->value();
if (!current_value->isa<mindspore::tensor::Tensor>()) {
continue;
}
}
auto iter = std::find(parameters.begin(), parameters.end(), current_input);
if (iter != parameters.end()) {
// Operand is a graph parameter: use the strategy stored at its position.
// NOTE(review): a "u"/"io" parameter used as an operand would also land
// here and index input_strategy with its position -- verify it is in range.
elements.push_back(MakeValue(input_strategy[iter - parameters.begin()]));
} else {
// Any other operand is left unsplit: all ones, one per dimension.
// NOTE(review): assumes Shape() is an abstract::Shape (cast result is not checked).
auto shape = current_input->Shape()->cast<abstract::ShapePtr>();
auto dimension = shape->shape().size();
std::vector<int64_t> default_strategy(dimension, 1);
elements.push_back(MakeValue(default_strategy));
}
}
// Operator specific adjustments of the collected strategy elements.
if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) || IsPrimitiveCNode(cnode, prim::kPrimBatchMatMul)) {
HandleStrategyForMatMul(&elements, cnode);
}
if (IsPrimitiveCNode(cnode, prim::kPrimOneHot)) {
HandleStrategyForOneHot(&elements);
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
// NOTE(review): assumes input 0 of every concerned cnode is a primitive value
// node; GetValueNode would yield null for e.g. a sub-graph call -- confirm.
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
auto attrs_temp = prim->attrs();
attrs_temp[parallel::IN_STRATEGY] = strategy;
(void)prim->SetAttrs(attrs_temp);
}
}
// Applies the in_axes/out_axes carried by every Shard cnode in `all_nodes` to the
// func graph that the cnode wraps.
//
// `root` is always flagged "training"; it is additionally flagged
// "auto_parallel" as soon as one Shard cnode is found. For a Shard cnode,
// input 1 is the wrapped graph, input 2 the in_axes and input 3 the out_axes.
// Raises when input 1 is not a value node holding a func graph; the value node
// is checked before its scope is read so a malformed Shard call reports an
// exception instead of dereferencing a null pointer.
void SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                         const int64_t &device_num) {
  root->set_flag("training", true);
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimShard)) {
      continue;
    }
    root->set_flag("auto_parallel", true);
    auto cnode = node->cast<CNodePtr>();
    auto vnode = cnode->input(1)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(vnode);
    auto in_axes = cnode->input(2);
    auto out_axes = cnode->input(3);
    ScopeGuard scope_guard(vnode->scope());
    auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
    MS_EXCEPTION_IF_NULL(func_graph);
    SetInputLayout(func_graph, in_axes, device_num);
    SetOutputLayout(func_graph, out_axes, device_num);
  }
}
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
@ -205,11 +461,22 @@ bool PipelineSplit(const ResourcePtr &res) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
if ((execution_mode == kPynativeMode) &&
(parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL)) {
if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
MS_LOG(EXCEPTION) << "device_num must be set when use shard function";
}
auto device_num_shard = parallel::ParallelContext::GetInstance()->device_num();
SetStrategyForShard(root, all_nodes, device_num_shard);
}
if (!HasVirtualDataset(all_nodes)) {
InsertVirtualDataset(root, all_nodes);
}
@ -218,6 +485,7 @@ bool PipelineSplit(const ResourcePtr &res) {
MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
return true;
}
auto global_rank = GetRank();
auto world_group = GetWorldGroup();
uint32_t world_rank_size = 0;
@ -231,6 +499,7 @@ bool PipelineSplit(const ResourcePtr &res) {
} else {
device_num = parallel::ParallelContext::GetInstance()->device_num();
}
if (device_num < 1) {
MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
"but got the value of device_num: "

View File

@ -861,11 +861,6 @@ void CheckPyNativeContext() {
MS_EXCEPTION_IF_NULL(parallel_context);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
const auto &parallel_mode = parallel_context->parallel_mode();
if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
}
}
py::object GetDstType(const TypeId &type_id) {
@ -2672,6 +2667,34 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
return cell_id;
}
// Marks gradient cnodes as "parallel" when they belong to a parameter recorded
// in ms_function_params_.
//
// One flag is built per parameter of the resource's func graph that has a
// default value, in parameter order (1 when its name is in ms_function_params_,
// 0 otherwise). The k-th element of the tuple returned by the graph is matched
// with the k-th flag, and flagged elements get set_parallel(true).
// NOTE(review): the one-to-one pairing between returned elements and parameters
// with default values is taken from the original code -- confirm with callers.
//
// Elements without a matching flag, elements that are not cnodes, and a return
// operand that is not a cnode are skipped rather than read out of bounds or
// dereferenced as null.
void GradExecutor::MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) {
  auto func_graph = resource->func_graph();
  std::vector<size_t> in_ms_function;
  auto parameters = func_graph->parameters();
  for (size_t i = 0; i < parameters.size(); i++) {
    auto param = parameters[i]->cast<ParameterPtr>();
    if (!param->has_default()) {
      continue;
    }
    auto iter = std::find(ms_function_params_.begin(), ms_function_params_.end(), param->name());
    if (iter != ms_function_params_.end()) {
      in_ms_function.push_back(1);
    } else {
      in_ms_function.push_back(0);
    }
  }

  auto ret = func_graph->get_return();
  auto ret_cnode = ret->cast<CNodePtr>();
  auto grads = ret_cnode->input(1)->cast<CNodePtr>();
  if (grads == nullptr) {
    // The graph does not return a cnode, so there is nothing to mark.
    return;
  }
  // Input 0 of `grads` is its callee; gradient elements start at index 1.
  for (size_t i = 1; i < grads->inputs().size(); i++) {
    if (i - 1 >= in_ms_function.size()) {
      break;
    }
    if (!in_ms_function[i - 1]) {
      continue;
    }
    auto cnode = grads->input(i)->cast<CNodePtr>();
    if (cnode != nullptr) {
      cnode->set_parallel(true);
    }
  }
}
void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
const py::object &weights, const py::object &grad_position, const py::args &args) {
MS_EXCEPTION_IF_NULL(ret);
@ -2716,6 +2739,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
compile::SetMindRTEnable();
resource->SetResult(pipeline::kBackend, compile::CreateBackend());
MS_LOG(DEBUG) << "Start task emit action";
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL) {
MarkMsFunctionNodes(resource);
}
TaskEmitAction(resource);
MS_LOG(DEBUG) << "Start execute action";
ExecuteAction(resource);
@ -3264,6 +3291,15 @@ py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &a
FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
MS_EXCEPTION_IF_NULL(grad_graph);
GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL) {
for (auto &parameter : ms_func_graph->parameters()) {
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
ms_function_params_.push_back(param->name());
}
}
}
set_graph_phase("");
return ret;
}

View File

@ -292,6 +292,7 @@ class GradExecutor {
MS_EXCEPTION_IF_NULL(graph_info);
graph_info->node_map[id] = std::make_pair(node, index);
}
void MarkMsFunctionNodes(const pipeline::ResourcePtr &resource);
private:
bool grad_flag_{false};
@ -311,6 +312,8 @@ class GradExecutor {
TopCellInfoPtr top_cell_{nullptr};
// Records forward cell, the bottom is top cell
std::stack<std::string> cell_stack_;
// Stores the names of ms_function graph parameters that have a default value
std::vector<std::string> ms_function_params_;
// For high grad of bprop
std::stack<std::pair<std::string, bool>> bprop_grad_stack_;
std::vector<std::string> bprop_cell_list_;

View File

@ -454,7 +454,8 @@ bool HcclAdapter::FinalizeHcclComm() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!task_sink) {
auto execution_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
if (!task_sink && execution_mode == kGraphMode) {
HcclCollectiveGroup::instance().DestroyCommGroup();
}
if (hccl_comm_ == nullptr) {

View File

@ -703,6 +703,16 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \param node A node debug info of an anf node.
void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);
/// \brief Check whether this node lies in ms_function in PyNative Mode.
///
/// \return The stored is_parallel_ flag: true if the node was marked as being in ms_function, otherwise false.
bool is_parallel() const { return is_parallel_; }
/// \brief Set is_parallel_ for CNode.
///
/// \param[in] parallel Boolean, true if this node lies in ms_function in PyNative Mode.
void set_parallel(bool parallel) { is_parallel_ = parallel; }
private:
std::vector<AnfNodePtr> inputs_;
VarPtr func_graph_as_var_;
@ -710,6 +720,8 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
bool in_forward_flag_ = false;
bool effect_handled_ = false;
bool is_load_ = false;
// is_parallel represents whether this cnode lies in ms_function or not in PyNative Mode
bool is_parallel_ = false;
// inputs_value_ store cnode input value and id in pynative mode
// output_value_ store cnode value and id in pynative mode
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;

View File

@ -23,6 +23,7 @@ import ast
import importlib
from collections import OrderedDict
from functools import wraps
import numpy as np
from mindspore import context
from mindspore import log as logger
@ -30,6 +31,7 @@ from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as MsTensor
from .tensor import CSRTensor as MsCSRTensor
from .tensor import COOTensor as MsCOOTensor
from .initializer import initializer
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
@ -200,6 +202,7 @@ class _MindsporeFunctionExecutor:
self.obj = None
if obj and hasattr(obj, fn.__name__):
self.obj = obj
self.shard_parent_obj = obj
self._graph_executor = GraphExecutor_.get_instance()
self._create_time = ms_create_time
@ -223,6 +226,36 @@ class _MindsporeFunctionExecutor:
if enable_compile_cache is True or enable_compile_cache == "1":
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
def _parallel_process_for_ms_function(self, phase):
"""Set parameter and optimizer states data according to sliced shape for shard"""
# Use the shard parent object when this executor is not bound to an object of its own.
obj = self.shard_parent_obj if self.obj is None else self.obj
# Pull the parameter layout and the parallel parameter names computed for this phase.
obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
# Re-initialize parameter data in auto-parallel mode; only entries whose replacement is a
# different object than the original are forwarded to the graph executor.
replace = obj.init_parameters_data(auto_parallel_mode=True)
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
self._graph_executor.updata_param_node_default_input(phase, new_param)
obj.load_parameter_slice(None)
# Optimizer states are only processed when the pynative executor has recorded an optimizer.
if _pynative_executor.get_optimizer():
params = obj.trainable_params()
opt_params = _pynative_executor.get_optimizer().trainable_params()
opt_states = []
# Collect optimizer parameters whose name contains a trainable parameter's name.
# str.find returns -1 when absent and 0 for a prefix match, so `> 0` only accepts names
# that embed the parameter name after at least one leading character.
for opt_param in opt_params:
for param in params:
if opt_param.name.find(param.name) > 0:
opt_states.append(opt_param)
# The matched optimizer state reuses the layout of the parameter it belongs to.
obj.parameter_layout_dict[opt_param.name] = obj.parameter_layout_dict[param.name]
# NOTE(review): this `continue` appears to be the last statement of the loop body, so it
# has no effect; confirm whether `break` was intended.
continue
# When the number of collected states differs from the number of params, split them into the
# first len(params) states and the remainder; otherwise treat them as a single group.
states_tuple = (opt_states[:len(params)], opt_states[len(params):]) if len(opt_states) != len(params) \
else (opt_states[:len(params)],)
# Pair each group positionally with params and re-create any state whose shape no longer
# matches its parameter, using the state's own `init` and the parameter's shape.
for states in states_tuple:
for param, state in zip(params, states):
if param.shape != state.shape:
state.set_data(initializer(state.init, param.shape), True)
# Publish the updated layout dict on the recorded top cell.
_pynative_executor.get_top_cell().parameter_layout_dict = obj.parameter_layout_dict
def compile(self, args_list, method_name):
"""Returns pipeline for the given args."""
# Verify the signature for both function and method
@ -271,6 +304,10 @@ class _MindsporeFunctionExecutor:
else:
self._graph_executor.set_weights_values(self.obj.parameters_dict())
is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)
if is_pynative_parallel():
self._parallel_process_for_ms_function(phase)
if not is_compile:
raise RuntimeError("Executor compile failed.")
if context.get_context("enable_ge"):
@ -284,7 +321,18 @@ class _MindsporeFunctionExecutor:
if self.obj is not None:
args_list = args_list[1:]
phase = self.compile(args_list, self.fn.__name__)
if is_pynative_parallel() and not hasattr(self.shard_parent_obj, "keep_input_unchanged"):
device_num = context.get_auto_parallel_context('device_num')
new_args_list = ()
for arg in args_list:
if isinstance(arg, MsTensor):
new_shape = (arg.shape[0] * device_num,) + arg.shape[1:]
new_args_list += (MsTensor(np.zeros(shape=new_shape), arg.dtype),)
else:
new_args_list += (arg,)
phase = self.compile(new_args_list, self.fn.__name__)
else:
phase = self.compile(args_list, self.fn.__name__)
if context.get_context("precompile_only"):
return None
@ -371,6 +419,8 @@ def ms_function(fn=None, obj=None, input_signature=None):
process_obj = None
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
if process_obj is None and is_pynative_parallel():
process_obj = obj
out = _MindsporeFunctionExecutor(func, ms_create_time, input_signature, process_obj)(*args)
return out
@ -380,6 +430,11 @@ def ms_function(fn=None, obj=None, input_signature=None):
return wrap_mindspore(fn)
return wrap_mindspore
def is_pynative_parallel():
    """Tell whether PyNative mode is active together with semi-auto or auto parallel.

    Returns:
        bool, True if the context mode is PYNATIVE_MODE and the auto-parallel context's
        parallel_mode is SEMI_AUTO_PARALLEL or AUTO_PARALLEL; False otherwise.
    """
    # Both context values are read up front, in the same order as before.
    current_mode = context.get_context('mode')
    current_parallel = context.get_auto_parallel_context('parallel_mode')
    if current_mode != context.PYNATIVE_MODE:
        return False
    supported = (context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
    return current_parallel in supported
def _get_auto_split_param_names(parameter_layout_dict):
auto_split_param_names = []
@ -443,6 +498,8 @@ class _PynativeExecutor:
self._executor = PynativeExecutor_.get_instance()
self._executor.set_py_exe_path(sys.executable)
self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
self._optimizer = None
self._top_cell = None
def new_graph(self, obj, *args, **kwargs):
self._executor.new_graph(obj, *args, *(kwargs.values()))
@ -508,6 +565,12 @@ class _PynativeExecutor:
def set_hook_changed(self, cell):
self._executor.set_hook_changed(cell)
def get_optimizer(self):
# Returns the value stored in `_optimizer` (None until something assigns it).
return self._optimizer
def get_top_cell(self):
# Returns the value stored in `_top_cell` (None until something assigns it).
return self._top_cell
def __call__(self, obj, *args, **kwargs):
args = args + tuple(kwargs.values())
return self._executor(obj, args)

View File

@ -26,6 +26,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common.hook_handle import HookHandle
from mindspore.context import ParallelMode
from mindspore.ops.composite import Shard
from .. import context
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from .._checkparam import Validator
@ -380,10 +381,13 @@ class Cell(Cell_):
self.parameter_broadcast_done = True
def run_construct(self, cast_inputs, kwargs):
"""Run the construct function"""
if self._enable_forward_pre_hook:
cast_inputs = self.run_forward_pre_hook(cast_inputs)
if self.enable_backward_hook:
output = self.run_backward_hook(*cast_inputs)
elif hasattr(self, "_shard_fn"):
output = self._shard_fn(*cast_inputs, **kwargs)
else:
output = self.construct(*cast_inputs, **kwargs)
if self._enable_forward_hook:
@ -443,6 +447,33 @@ class Cell(Cell_):
for prim in all_prims:
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
def shard(self, in_axes, out_axes, device="Ascend", level=0):
    """
    Attach a `Shard` function to this cell so it runs with the given input and output layouts.

    In_axes and out_axes define the layouts of the cell's inputs and outputs respectively; the parallel
    strategies of the remaining ops are generated by sharding propagation. The function returned by `Shard`
    is stored on the cell as `_shard_fn`.

    Note:
        Only effective in PYNATIVE_MODE with auto_parallel_context set to either ParallelMode.AUTO_PARALLEL
        with search_mode = sharding_propagation, or ParallelMode.SEMI_AUTO_PARALLEL.

    Args:
        in_axes (tuple): One element per input, giving the desired layout of that input.
            None represents data_parallel.
        out_axes (tuple): One element per output, giving the desired layout of that output.
            None represents data_parallel.
        device (str): Passed through to `Shard` unchanged. Default: "Ascend".
        level (int): Passed through to `Shard` unchanged. Default: 0.

    Returns:
        Cell, the cell itself.

    Examples:
        >>> from mindspore.ops import functional as F
        >>> class example(nn.Cell):
        >>>     def __init__(self):
        >>>         self.block1 = Block()
        >>>         self.block2 = Block()
        >>>         self.block2.shard(in_axes=(None, (2, 1)), out_axes=(None,))
        >>>         # self.parallel_block = F.shard(self.block2, in_axes=(None, (2, 1)), out_axes=(None,))
        >>>     def construct(self, x):
        >>>         x = self.block1(x)
        >>>         x = self.block2(x)
        >>>         return x
    """
    # object.__setattr__ is used directly, as in the original, rather than plain attribute assignment.
    object.__setattr__(self, "_shard_fn", Shard()(self, in_axes, out_axes, device, level))
    return self
class CellGuard:
def __enter__(self):
_pynative_executor.set_lazy_build(True)
@ -485,6 +516,8 @@ class Cell(Cell_):
# Run in PyNative mode.
if _pynative_executor.is_top_cell():
_pynative_executor.set_lazy_build(True)
_pynative_executor._optimizer = getattr(self, "optimizer", None)
_pynative_executor._top_cell = self
# There many Casts in parameter_broadcast. Enable lazy_build and build faster.
self._do_parameter_broadcast()

View File

@ -754,7 +754,10 @@ class Shard(Shard_):
self.device = None
self.level = None
def __call__(self, fn, in_axes, out_axes, device, level=0):
def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0):
if context.get_context("mode") != context.PYNATIVE_MODE or \
context.get_auto_parallel_context("parallel_mode") not in ["semi_auto_parallel", "auto_parallel"]:
raise AssertionError(f"'Shard' only supports semi_auto/auto parallel under PyNative mode")
if not isinstance(in_axes, tuple):
raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
if not isinstance(out_axes, tuple):
@ -771,7 +774,7 @@ class Shard(Shard_):
return self.shard_fn
shard_ = Shard()
@ms_function
@ms_function(obj=fn)
def after_shard(*args):
return shard_(fn, in_axes, out_axes, device, level)(*args)

View File

@ -345,7 +345,7 @@ def vjp(fn, inputs, v):
return wrap_container(inputs, v)
shard_fn = Shard()
def shard(fn, in_axes, out_axes, device, level=0):
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
# Thin functional wrapper: forwards every argument to the module-level `shard_fn` and returns its result.
return shard_fn(fn, in_axes, out_axes, device, level)
def narrow(inputs, axis, start, length):

View File

@ -21,6 +21,7 @@ from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.common.parameter import Parameter
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig
@ -283,6 +284,9 @@ class Primitive(Primitive_):
def __call__(self, *args):
    """Run this primitive on `args`.

    Any `Parameter` argument whose `has_init` is truthy gets `init_data()` called before the
    op is dispatched. If `check_elim` reports that the op should be eliminated, its output is
    returned directly; otherwise the op is executed through `_run_op`.
    """
    should_elim, output = self.check_elim(*args)
    # Arguments are visited one by one so each check sees the effect of earlier init_data() calls.
    for candidate in args:
        if isinstance(candidate, Parameter) and candidate.has_init:
            candidate.init_data()
    return output if should_elim else _run_op(self, self.name, args)

View File

@ -355,11 +355,6 @@ class _AutoParallelContext:
ValueError: If parallel mode is not supported.
"""
self.check_context_handle()
run_mode = context.get_context("mode")
if run_mode == context.PYNATIVE_MODE and parallel_mode not in (
context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE):
raise ValueError(f"Pynative Only support STAND_ALONE and DATA_PARALLEL for ParallelMode, "
f"but got {parallel_mode.upper()}.")
ret = self._context_handle.set_parallel_mode(parallel_mode)
if ret is False:
raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', "

View File

@ -17,6 +17,7 @@
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
from ..common import ms_function
_allgather_cell = None
@ -31,6 +32,7 @@ class AllGatherCell(Cell):
self.allgather = AllGather(group)
@ms_function()
def construct(self, x):
x = self.allgather(x)

View File

@ -704,6 +704,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
# pipeline parallel need to be supported here later
if mp_weight:
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
object.__setattr__(allgather_net, "keep_input_unchanged", True)
elif opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, False)
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
@ -815,7 +816,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
enc_mode = 'AES-GCM'
if 'enc_mode' in kwargs.keys():
enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
else:
@ -961,8 +962,8 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
write_data = raw_data + bytes(append_size)
offset += (data_length + append_size)
if is_encrypt():
write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
f.write(write_data)
# save graph
@ -973,9 +974,9 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
len(kwargs['enc_key']),
kwargs['enc_mode'])
model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
len(kwargs.get('enc_key')),
kwargs.get('enc_mode'))
model_file.write(model_string)
os.chmod(graph_file_name, stat.S_IRUSR)
@ -1000,7 +1001,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
net_dict = net.parameters_dict()
model.ParseFromString(mindir_stream)
if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
if 'dataset' in kwargs.keys() and kwargs.get('dataset') is not None:
check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
dataset = kwargs['dataset']
_save_dataset_to_mindir(model, dataset)
@ -1035,8 +1036,8 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'), len(kwargs.get('enc_key')),
kwargs.get('enc_mode'))
f.write(model_string)
os.chmod(file_name, stat.S_IRUSR)
@ -1109,8 +1110,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
quant_net = copy.deepcopy(network)
quant_net._create_time = int(time.time() * 1e9)
mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
mean = 127.5 if kwargs.get('mean', None) is None else kwargs.get('mean')
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs.get('std_dev')
mean = Validator.check_value_type("mean", mean, (int, float))
std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))

View File

@ -117,7 +117,7 @@ class TestSharedParameterCast:
"""
auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
interleaved_batch=1)
self.cat_fp16_from_ir(target_count=27)
self.cat_fp16_from_ir(target_count=28)
def test_optimizer_fp16_micro_batch(self):
"""
@ -127,7 +127,7 @@ class TestSharedParameterCast:
"""
auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
interleaved_batch=2)
self.cat_fp16_from_ir(target_count=41)
self.cat_fp16_from_ir(target_count=42)
def test_optimizer_fp16_pipeline(self):
"""
@ -138,7 +138,7 @@ class TestSharedParameterCast:
auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
interleaved_batch=1,
stages=1, micro_size=1)
self.cat_fp16_from_ir(target_count=27)
self.cat_fp16_from_ir(target_count=28)
def test_optimizer_fp16_pipeline_micro_batch(self):
"""
@ -149,4 +149,4 @@ class TestSharedParameterCast:
auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
interleaved_batch=2,
stages=1, micro_size=1)
self.cat_fp16_from_ir(target_count=41)
self.cat_fp16_from_ir(target_count=42)