!30469 add shard function to support part of the graph executed in auto_parallel under pynative mode
Merge pull request !30469 from wangjun/0223_pp
Commit: 0341d96dd6
@ -341,6 +341,14 @@

    .. note:: Only takes effect in full auto-parallel (AUTO_PARALLEL) mode.

    .. py:method:: shard(in_axes, out_axes, device="Ascend", level=0)

        Specifies the distribution strategies of the input/output tensors; the strategies of the remaining operators are derived by sharding propagation. In PyNative mode, this method can be used to have a given cell executed distributedly in graph mode.
        in_axes/out_axes must be tuples, where each element specifies the distribution strategy of the corresponding input/output tensor and must itself be a tuple; see the description of `mindspore.ops.Primitive.shard`. An element may also be set to None, in which case data parallelism is applied by default.

        .. note:: Requires PyNative mode, together with either full auto-parallel (AUTO_PARALLEL) with search mode sharding_propagation, or semi-auto-parallel (SEMI_AUTO_PARALLEL).

    .. py:method:: set_grad(requires_grad=True)

        Sets gradient computation for the Cell. In PyNative mode, this parameter specifies whether the Cell requires gradients. If True, the backward network used to compute gradients will be generated when the forward network is executed.
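For reference, a minimal usage sketch of the documented API (the Block sub-cell, shapes, and context values are illustrative assumptions, not part of this commit; kwarg names follow the note above and may differ across versions):

    import mindspore as ms
    from mindspore import context, nn

    context.set_context(mode=context.PYNATIVE_MODE)
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      search_mode="sharding_propagation",
                                      device_num=8)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.block1 = Block()  # hypothetical sub-cell taking one tensor
            self.block2 = Block()
            # Shard dim 0 of block2's input across 2 devices; strategies for
            # the remaining operators are derived by sharding propagation.
            self.block2.shard(in_axes=((2, 1),), out_axes=(None,))

        def construct(self, x):
            x = self.block1(x)  # executes op-by-op in PyNative mode
            x = self.block2(x)  # executes as a sharded graph
            return x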
@ -19,6 +19,7 @@
#include <vector>
#include <map>

#include "frontend/parallel/context.h"
#include "backend/graph_compiler/transform.h"
#include "backend/common/session/session_factory.h"
#include "runtime/op_builder/op_lazy_builder.h"

@ -451,10 +452,12 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(context_ptr);
  ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  real_execution_mode_ = ms_execution_mode_;
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  auto is_parallel = (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL);

  // Run in GRAPH_MODE if the func_graph is an ms_function or contains multiple subgraphs.
  if (ms_execution_mode_ == kPynativeMode &&
-     (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1)) {
+     (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1) && !is_parallel) {
    real_execution_mode_ = kGraphMode;
    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
    pipeline::SetRunMode(func_graph, this);

@ -891,7 +894,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
      graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

      // Save grad node to Bucket
-     if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
+     if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
        graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
      }
    }
@ -126,7 +126,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunOutputReplace(const CNodePtr &forwa
                                                           const FuncGraphPtr &fprop_graph,
                                                           const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
- if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient)) {
+ if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode_morph, prim::kPrimMirror)) {
    return {};
  }
  // Use manager to get the link relation among nodes.

@ -176,7 +176,7 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
    MS_EXCEPTION_IF_NULL(input_node);
    // Parameter, ValueNode and StopGradient CNode need no replacement.
    if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
-       IsPrimitiveCNode(input_node, prim::kPrimStopGradient)) {
+       IsPrimitiveCNode(input_node, prim::kPrimStopGradient) || IsPrimitiveCNode(input_node, prim::kPrimMirror)) {
      continue;
    }
    // Replace forward input node by its output value.

@ -187,6 +187,9 @@ std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bp
    MS_EXCEPTION_IF_NULL(output_vnode_i);
    output_vnode_i->set_has_new_value(true);
    manager->Replace(paras[i], output_vnode_i);
    if (IsPrimitiveCNode(cnode_i, prim::kPrimLoad)) {
      para_ref_size += 1;
    }
    MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
    // Save the forward input node when it is used in the bprop graph.
    if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
@ -61,10 +61,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                             prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
  arithmetic_simplify2_ =
    MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
- special_op_eliminate_ = MakeSubstitution(
-   std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
-   {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimCellBackwardHook,
-    prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
+ special_op_eliminate_ =
+   MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
+                    {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
+                     prim::kPrimCellBackwardHook, prim::kPrimPrintShapeType, prim::kPrimGetRefValue});
+ ad_related_special_op_eliminate_ =
+   MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "ad_related_special_op_eliminate",
+                    {prim::kPrimMirror, prim::kPrimVirtualDiv});
  pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
  zero_like_fill_zero_ =
    MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);

@ -35,6 +35,7 @@ class OptimizeIRPassLib {
  SubstitutionPtr arithmetic_simplify_;
  SubstitutionPtr arithmetic_simplify2_;
  SubstitutionPtr special_op_eliminate_;
  SubstitutionPtr ad_related_special_op_eliminate_;
  SubstitutionPtr zero_like_fill_zero_;
  SubstitutionPtr adjust_all_reduce_mul_add_;
  SubstitutionPtr float_depend_g_call_;
@ -0,0 +1,56 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/shard_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Replace a Shard CNode with the func graph it wraps; the layout information
// has already been consumed when the strategies were attached.
AnfNodePtr ExpandShard(const CNodePtr &node) {
  auto vnode = node->input(1)->cast<ValueNodePtr>();
  auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  return NewValueNode(func_graph);
}
}  // namespace internal

bool ExpandShardPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  GetShardPrim(func_graph);
  bool change = false;
  auto manager = optimizer->manager();
  for (auto shard_node : shard_nodes_) {
    auto expanded_shard = internal::ExpandShard(shard_node);
    manager->Replace(shard_node, expanded_shard);
    change = true;
  }
  return change;
}

void ExpandShardPrim::GetShardPrim(const FuncGraphPtr &func_graph) {
  shard_nodes_.clear();
  AnfNodePtr ret = func_graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimShard)) {
      shard_nodes_.push_back(node->cast<CNodePtr>());
    }
  }
}
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
@ -0,0 +1,43 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
class ExpandShardPrim {
 public:
  ExpandShardPrim() = default;
  virtual ~ExpandShardPrim() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
  void GetShardPrim(const FuncGraphPtr &func_graph);

 private:
  std::vector<CNodePtr> shard_nodes_;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SHARD_ELIMINATE_H_
@ -347,6 +347,7 @@ constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax";
constexpr char SIMPLE_MEAN[] = "SimpleMean";
constexpr char FLATTEN[] = "Flatten";
constexpr char J[] = "J";
constexpr char SHARD[] = "Shard";
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
constexpr char COS[] = "Cos";
constexpr char ACOS[] = "ACos";
@ -935,10 +935,30 @@ void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
  }
}

// Return the func graph wrapped by a Shard node if one exists; otherwise return the root.
FuncGraphPtr PynativeParallelGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  FuncGraphPtr real_graph = root;
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_shard_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (expect_shard_prim->name() != SHARD) {
      continue;
    }
    real_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
  }
  return real_graph;
}

void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  vector<std::string> last_forward_node_ids;
  vector<size_t> last_indexs;
- FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
+ auto real_graph = PynativeParallelGraph(root, all_nodes);
+ FindLastNodesUniqueId(real_graph, &last_forward_node_ids, &last_indexs);
  MS_LOG(INFO) << "There are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  for (auto &node : all_nodes) {
    // Insert the VirtualOutput node here.
@ -2189,9 +2209,9 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
    auto tuple_index = GetTupleGetItemIndex(cnode);
    auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
    if (!layout_ptr) {
-     MS_LOG(EXCEPTION)
-       << " Failure: FindPrevLayout failed, tuple_getitem before reshape, but there does not exist a parallel care "
-          "node before tuple_getitem!";
+     MS_LOG(EXCEPTION) << " Failure: FindPrevLayout failed, tuple_getitem before reshape, but there does not exist a "
+                          "parallel care node before tuple_getitem!";
    }
    return layout_ptr;
  }
@ -2485,8 +2505,8 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
-   auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-   if (expect_j_prim->name() != J) {
+   auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+   if (expect_prim->name() != J && expect_prim->name() != SHARD) {
      continue;
    }
    if (IsValueNode<FuncGraph>(cnode->input(1))) {
@ -2513,6 +2533,12 @@ void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
  }
}

bool IsPynativeParallel() {
  auto parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  return (execution_mode == kPynativeMode) && (parallel_mode == SEMI_AUTO_PARALLEL || parallel_mode == AUTO_PARALLEL);
}

// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
@ -2607,7 +2633,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
      StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
    }
    // insert backward ops
-   if (has_backward) {
+   if (has_backward || IsPynativeParallel()) {
      BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
    }
@ -3075,8 +3101,9 @@ bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
       "the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
       " to configure the input strategy.";
  }
- return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
-         current_stage == split_stage_num - 1);
+ return ((!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
+          current_stage == split_stage_num - 1) ||
+         IsPynativeParallel());
}

static void HandleGroupInfo(const FuncGraphPtr &root) {
@ -737,6 +737,17 @@ bool EliminateForwardCNode(const ResourcePtr &res) {
  return true;
}

bool EliminateAdRelatedSpecialOpNode(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
  }
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
  }
  return EliminateAdRelatedSpecialOpOptPass(res);
}

bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
@ -1395,6 +1406,9 @@ std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
    // eliminate forward cnode for grad graph
    (void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));

    // eliminate the virtual mirror node
    (void)actions.emplace_back(std::make_pair("eliminate_ad_related_special_op_node", EliminateAdRelatedSpecialOpNode));

    (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
  }
@ -49,6 +49,8 @@
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/meta_fg_eliminate.h"
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/shard_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))

@ -188,6 +190,7 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
    irpass.reshape_eliminate_,
    irpass.switch_simplify_,
    irpass.addn_zero_filter_,
    irpass.ad_related_special_op_eliminate_,
  });
  opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
  OptPassGroupMap map({
@ -366,6 +369,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
    {"virtual_dataset", virtual_dataset},
    {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
    {"shard", opt::OptPassConfig(opt::irpass::ExpandShardPrim())},
    {"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())},
    {"after_resolve", after_resolve_pass},
    {"a_after_grad", a_after_grad},
@ -723,6 +727,21 @@ bool PynativeOptPass(const ResourcePtr &res) {
  return true;
}

bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &res) {
  auto func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig ad_related_special_op_eliminate = opt::OptPassConfig({
    irpass.ad_related_special_op_eliminate_,
  });
  OptPassGroupMap map({
    {"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
  });
  auto ad_related_special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("ad_related_special_op_eliminate", res, map);
  (void)ad_related_special_op_eliminate_opt->step(func_graph, false);
  return true;
}

bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->manager());
@ -47,6 +47,7 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &res);
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
@ -137,12 +137,12 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
    if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
-   auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+   auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    FuncGraphPtr fun_graph = nullptr;
    if (!root->has_flag(mindspore::parallel::TRAINING)) {
      graph_sets.insert(root);
    }
-   if (expect_j_prim->name() == mindspore::parallel::J) {
+   if (expect_prim->name() == mindspore::parallel::J || expect_prim->name() == mindspore::parallel::SHARD) {
      if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
        fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
      } else {
@ -192,6 +192,262 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
  }
}

// For each in_axes/out_axes entry that is None, build a default layout that shards
// dimension 0 across all devices and replicates the remaining dimensions.
void GenerateDefaultStrategy(const ValueNodePtr &axes, const std::vector<AnfNodePtr> &nodes, const int64_t device_num,
                             std::vector<std::vector<int64_t>> *default_strategy) {
  auto strategies = axes->value()->cast<ValueTuplePtr>()->value();
  size_t i = 0;
  for (auto &strategy : strategies) {
    auto node = nodes[i];
    if (strategy->isa<None>()) {
      auto node_size = AnfAlgo::GetOutputInferShape(node, 0).size();
      std::vector<int64_t> current_d_strategy(node_size, 1);
      if (node_size >= 1) {
        current_d_strategy[0] = device_num;
      }
      default_strategy->push_back(current_d_strategy);
    } else {
      auto current_strategy = GetValue<std::vector<int64_t>>(strategy);
      default_strategy->push_back(current_strategy);
    }
    i += 1;
  }
}
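In other words, a None entry in in_axes/out_axes falls back to sharding only dimension 0 across every device. A minimal Python sketch of the same rule (names illustrative, not part of the commit):

    def default_strategy(rank, device_num):
        # Replicate every dimension, then shard dim 0 across all devices.
        strategy = [1] * rank
        if rank >= 1:
            strategy[0] = device_num
        return strategy

    # A 3-D tensor on 8 devices gets the layout [8, 1, 1].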
bool CheckLayout(const ValueNodePtr &axes, bool *need_default_strategy, size_t *axes_size) {
  auto strategies = axes->value()->cast<ValueTuplePtr>()->value();
  for (auto &strategy : strategies) {
    *axes_size += 1;
    if (strategy->isa<None>()) {
      *need_default_strategy = true;
      continue;
    }
    if (!strategy->isa<ValueTuple>()) {
      return false;
    }
    auto elements = strategy->cast<ValueTuplePtr>()->value();

    for (auto &element : elements) {
      if (!element->isa<Int64Imm>()) {
        return false;
      }
    }
  }
  return true;
}
void HandleStrategyForOneHot(std::vector<ValuePtr> *strategy) {
  // OneHot needs to set a layout for its output, so extend the strategy with an additional dimension.
  auto input_strategy = GetValue<std::vector<int64_t>>(strategy->at(0));
  input_strategy.push_back(1);
  strategy->at(0) = MakeValue(input_strategy);
}

void HandleStrategyForMatMul(std::vector<ValuePtr> *strategy, const CNodePtr &cnode) {
  // Align the MatMul strategies so that the contracted dimensions of both operands agree.
  auto left_matrix_strategy = GetValue<std::vector<int64_t>>(strategy->at(0));
  auto right_matrix_strategy = GetValue<std::vector<int64_t>>(strategy->at(1));
  auto index_a = left_matrix_strategy.size() - 1;
  auto index_b = index_a - 1;
  auto attrs = GetCNodePrimitive(cnode)->attrs();
  bool transpose_a = attrs[parallel::TRANSPOSE_A]->cast<BoolImmPtr>()->value();
  bool transpose_b = attrs[parallel::TRANSPOSE_B]->cast<BoolImmPtr>()->value();
  if (transpose_a) {
    index_a -= 1;
  }
  if (transpose_b) {
    index_b += 1;
  }
  if (left_matrix_strategy[index_a] != right_matrix_strategy[index_b]) {
    if (left_matrix_strategy[index_a] == 1) {
      left_matrix_strategy[index_a] = right_matrix_strategy[index_b];
    } else {
      right_matrix_strategy[index_b] = left_matrix_strategy[index_a];
    }
    strategy->at(0) = MakeValue(left_matrix_strategy);
    strategy->at(1) = MakeValue(right_matrix_strategy);
  }
}
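The rule above makes the contracted dimensions of the two MatMul operands agree by propagating whichever shard number is non-trivial. A Python sketch of the same index arithmetic (function name illustrative):

    def align_matmul_strategy(left, right, transpose_a, transpose_b):
        index_a = len(left) - 1   # contracted dim of the left operand
        index_b = index_a - 1     # contracted dim of the right operand
        if transpose_a:
            index_a -= 1
        if transpose_b:
            index_b += 1
        if left[index_a] != right[index_b]:
            if left[index_a] == 1:
                left[index_a] = right[index_b]
            else:
                right[index_b] = left[index_a]
        return left, right

    # align_matmul_strategy([4, 1], [2, 1], False, False) -> ([4, 2], [2, 1])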
void GetInputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *input_nodes) {
  auto parameters = func_graph->parameters();
  for (auto &parameter : parameters) {
    if (parameter->cast<ParameterPtr>()->name() == "u" || parameter->cast<ParameterPtr>()->name() == "io") {
      continue;
    }
    input_nodes->push_back(parameter);
  }
}

void GetOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_nodes) {
  auto return_node = func_graph->get_return();
  auto real_return_node = return_node->cast<CNodePtr>()->input(1);
  while (IsPrimitiveCNode(real_return_node, prim::kPrimDepend)) {
    real_return_node = real_return_node->cast<CNodePtr>()->input(1);
  }
  if (!IsPrimitiveCNode(real_return_node, prim::kPrimMakeTuple)) {
    output_nodes->push_back(real_return_node);
  } else {
    auto cnode = real_return_node->cast<CNodePtr>();
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      output_nodes->push_back(cnode->input(i));
    }
  }
}
bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const int64_t &device_num) {
  for (size_t i = 0; i < strategies.size(); ++i) {
    auto strategy = strategies[i];
    int64_t required_num = 1;
    std::for_each(strategy.begin(), strategy.end(), [&](int64_t const &data) { required_num *= data; });
    if (required_num > device_num) {
      MS_LOG(ERROR) << "required device number: " << required_num
                    << " is larger than available device number: " << device_num << " at index: " << i;
      return false;
    }
    if (device_num % required_num != 0) {
      MS_LOG(ERROR) << "available device number: " << device_num
                    << " is not divisible by the required device number: " << required_num << " at index: " << i;
      return false;
    }
  }
  return true;
}
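Each strategy's shard counts must multiply to a number that both fits within and divides the device count; an equivalent Python sketch (illustrative):

    from math import prod

    def check_device_num(strategies, device_num):
        for i, strategy in enumerate(strategies):
            required = prod(strategy)  # devices this layout needs
            if required > device_num or device_num % required != 0:
                return False
        return True

    # check_device_num([[8, 1], [2, 1]], 8) -> True
    # check_device_num([[3, 1]], 8) -> False (8 is not divisible by 3)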
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, const int64_t &device_num) {
  auto out_axes_tuple = out_axes->cast<ValueNodePtr>();
  bool need_default_strategy = false;
  size_t out_axes_size = 0;
  if (!IsValueNode<ValueTuple>(out_axes_tuple) ||
      !CheckLayout(out_axes_tuple, &need_default_strategy, &out_axes_size)) {
    MS_LOG(EXCEPTION) << "out_axes should be a two-dimensional tuple";
  }
  std::vector<AnfNodePtr> output_nodes;
  GetOutputNodes(func_graph, &output_nodes);
  if (output_nodes.size() != out_axes_size) {
    MS_LOG(EXCEPTION) << "Output number: " << output_nodes.size()
                      << " is not equal to out_axes number: " << out_axes_size;
  }

  std::vector<std::vector<int64_t>> output_strategy;
  if (need_default_strategy) {
    GenerateDefaultStrategy(out_axes_tuple, output_nodes, device_num, &output_strategy);
  } else {
    output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_axes_tuple->value());
  }
  MS_LOG(WARNING) << "The output strategy will be overwritten as data-parallel";

  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto node = output_nodes[i];
    auto output_shape = AnfAlgo::GetOutputInferShape(node, 0);
    if (output_shape.size() != output_strategy[i].size()) {
      MS_LOG(EXCEPTION) << "Output dimension: " << output_shape.size()
                        << " is not equal to out_axes dimension: " << output_strategy[i].size() << " at index " << i;
    }
    std::vector<ValuePtr> elements;
    elements.push_back(MakeValue(output_strategy[i]));
    auto prim = GetCNodePrimitive(node);
    auto attrs_temp = prim->attrs();
    ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
    attrs_temp[parallel::OUT_STRATEGY] = strategy;
    (void)prim->SetAttrs(attrs_temp);
  }
}
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, const int64_t &device_num) {
  auto in_axes_tuple = in_axes->cast<ValueNodePtr>();
  bool need_default_strategy = false;
  size_t in_axes_size = 0;
  if (!IsValueNode<ValueTuple>(in_axes_tuple) || !CheckLayout(in_axes_tuple, &need_default_strategy, &in_axes_size)) {
    MS_LOG(EXCEPTION) << "in_axes should be a two-dimensional tuple";
  }
  std::vector<AnfNodePtr> input_nodes;
  GetInputNodes(func_graph, &input_nodes);
  if (input_nodes.size() != in_axes_size) {
    MS_LOG(EXCEPTION) << "Input number: " << input_nodes.size()
                      << " is not equal to in_axes number: " << in_axes_size;
  }
  std::vector<std::vector<int64_t>> input_strategy;
  if (need_default_strategy) {
    GenerateDefaultStrategy(in_axes_tuple, input_nodes, device_num, &input_strategy);
  } else {
    input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_axes_tuple->value());
  }
  if (!CheckDeviceNum(input_strategy, device_num)) {
    MS_LOG(EXCEPTION) << "check device number failed";
  }
  std::set<CNodePtr> concerned_nodes;
  FuncGraphManagerPtr manager = func_graph->manager();
  auto parameters = func_graph->parameters();
  for (size_t i = 0; i < parameters.size(); ++i) {
    auto parameter = parameters[i];
    if (parameter->cast<ParameterPtr>()->name() == "u" || parameter->cast<ParameterPtr>()->name() == "io") {
      continue;
    }
    auto output_shape = AnfAlgo::GetOutputInferShape(parameter, 0);
    if (output_shape.size() != input_strategy[i].size()) {
      MS_LOG(EXCEPTION) << "Input dimension: " << output_shape.size()
                        << " is not equal to in_axes dimension: " << input_strategy[i].size() << " at index " << i;
    }
    AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
    for (auto &param_pair : param_sub_set) {
      CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
      concerned_nodes.insert(param_cnode);
    }
  }
  for (auto &cnode : concerned_nodes) {
    auto current_inputs = cnode->inputs();
    std::vector<ValuePtr> elements;
    for (size_t i = 1; i < current_inputs.size(); ++i) {
      auto current_input = current_inputs[i];
      if (current_input->isa<ValueNode>()) {
        auto current_value = current_input->cast<ValueNodePtr>()->value();
        if (!current_value->isa<mindspore::tensor::Tensor>()) {
          continue;
        }
      }
      auto iter = std::find(parameters.begin(), parameters.end(), current_input);
      if (iter != parameters.end()) {
        elements.push_back(MakeValue(input_strategy[iter - parameters.begin()]));
      } else {
        auto shape = current_input->Shape()->cast<abstract::ShapePtr>();
        auto dimension = shape->shape().size();
        std::vector<int64_t> default_strategy(dimension, 1);
        elements.push_back(MakeValue(default_strategy));
      }
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) || IsPrimitiveCNode(cnode, prim::kPrimBatchMatMul)) {
      HandleStrategyForMatMul(&elements, cnode);
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimOneHot)) {
      HandleStrategyForOneHot(&elements);
    }
    ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    auto attrs_temp = prim->attrs();
    attrs_temp[parallel::IN_STRATEGY] = strategy;
    (void)prim->SetAttrs(attrs_temp);
  }
}
void SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                         const int64_t &device_num) {
  root->set_flag("training", true);
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimShard)) {
      root->set_flag("auto_parallel", true);
      auto cnode = node->cast<CNodePtr>();
      auto vnode = cnode->input(1)->cast<ValueNodePtr>();
      auto in_axes = cnode->input(2);
      auto out_axes = cnode->input(3);
      ScopeGuard scope_guard(vnode->scope());
      auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
      MS_EXCEPTION_IF_NULL(func_graph);
      SetInputLayout(func_graph, in_axes, device_num);
      SetOutputLayout(func_graph, out_axes, device_num);
    }
  }
}

// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
@ -205,11 +461,22 @@ bool PipelineSplit(const ResourcePtr &res) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
    return true;
  }

  auto manager = res->manager();
  auto root = res->func_graph();
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  if ((execution_mode == kPynativeMode) &&
      (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL)) {
    if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
      MS_LOG(EXCEPTION) << "device_num must be set when using the shard function";
    }
    auto device_num_shard = parallel::ParallelContext::GetInstance()->device_num();
    SetStrategyForShard(root, all_nodes, device_num_shard);
  }

  if (!HasVirtualDataset(all_nodes)) {
    InsertVirtualDataset(root, all_nodes);
  }
@ -218,6 +485,7 @@ bool PipelineSplit(const ResourcePtr &res) {
    MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
    return true;
  }

  auto global_rank = GetRank();
  auto world_group = GetWorldGroup();
  uint32_t world_rank_size = 0;

@ -231,6 +499,7 @@ bool PipelineSplit(const ResourcePtr &res) {
  } else {
    device_num = parallel::ParallelContext::GetInstance()->device_num();
  }

  if (device_num < 1) {
    MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
                     "but got the value of device_num: "
@ -861,11 +861,6 @@ void CheckPyNativeContext() {
  MS_EXCEPTION_IF_NULL(parallel_context);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
- const auto &parallel_mode = parallel_context->parallel_mode();
- if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
-     ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
-   MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
- }
}

py::object GetDstType(const TypeId &type_id) {
@ -2672,6 +2667,34 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
  return cell_id;
}

void GradExecutor::MarkMsFunctionNodes(const pipeline::ResourcePtr &resource) {
  auto func_graph = resource->func_graph();
  std::vector<size_t> in_ms_function;
  auto parameters = func_graph->parameters();
  for (size_t i = 0; i < parameters.size(); i++) {
    auto param = parameters[i]->cast<ParameterPtr>();
    if (!param->has_default()) {
      continue;
    }
    auto iter = std::find(ms_function_params_.begin(), ms_function_params_.end(), param->name());
    if (iter != ms_function_params_.end()) {
      in_ms_function.push_back(1);
    } else {
      in_ms_function.push_back(0);
    }
  }

  auto ret = func_graph->get_return();
  auto ret_cnode = ret->cast<CNodePtr>();
  auto grads = ret_cnode->input(1)->cast<CNodePtr>();
  for (size_t i = 1; i < grads->inputs().size(); i++) {
    if (in_ms_function[i - 1]) {
      auto cnode = grads->input(i)->cast<CNodePtr>();
      cnode->set_parallel(true);
    }
  }
}

void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                                const py::object &weights, const py::object &grad_position, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
@ -2716,6 +2739,10 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
  compile::SetMindRTEnable();
  resource->SetResult(pipeline::kBackend, compile::CreateBackend());
  MS_LOG(DEBUG) << "Start task emit action";
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL) {
    MarkMsFunctionNodes(resource);
  }
  TaskEmitAction(resource);
  MS_LOG(DEBUG) << "Start execute action";
  ExecuteAction(resource);
@ -3264,6 +3291,15 @@ py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &a
  FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
  MS_EXCEPTION_IF_NULL(grad_graph);
  GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL) {
    for (auto &parameter : ms_func_graph->parameters()) {
      auto param = parameter->cast<ParameterPtr>();
      if (param->has_default()) {
        ms_function_params_.push_back(param->name());
      }
    }
  }
  set_graph_phase("");
  return ret;
}
@ -292,6 +292,7 @@ class GradExecutor {
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->node_map[id] = std::make_pair(node, index);
  }
  void MarkMsFunctionNodes(const pipeline::ResourcePtr &resource);

 private:
  bool grad_flag_{false};

@ -311,6 +312,8 @@ class GradExecutor {
  TopCellInfoPtr top_cell_{nullptr};
  // Records the forward cell; the bottom is the top cell
  std::stack<std::string> cell_stack_;
  // Stores parameters in ms_function
  std::vector<std::string> ms_function_params_;
  // For high grad of bprop
  std::stack<std::pair<std::string, bool>> bprop_grad_stack_;
  std::vector<std::string> bprop_cell_list_;
@ -454,7 +454,8 @@ bool HcclAdapter::FinalizeHcclComm() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
- if (!task_sink) {
+ auto execution_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
+ if (!task_sink && execution_mode == kGraphMode) {
    HcclCollectiveGroup::instance().DestroyCommGroup();
  }
  if (hccl_comm_ == nullptr) {
@ -703,6 +703,16 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
  /// \param node A node debug info of an anf node.
  void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);

  /// \brief Check whether this node is in ms_function in PyNative mode.
  ///
  /// \return True if in ms_function, otherwise false.
  bool is_parallel() const { return is_parallel_; }

  /// \brief Set is_parallel_ for this CNode.
  ///
  /// \param[in] parallel Boolean.
  void set_parallel(bool parallel) { is_parallel_ = parallel; }

 private:
  std::vector<AnfNodePtr> inputs_;
  VarPtr func_graph_as_var_;

@ -710,6 +720,8 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
  bool in_forward_flag_ = false;
  bool effect_handled_ = false;
  bool is_load_ = false;
  // is_parallel_ indicates whether this cnode lies in ms_function in PyNative mode
  bool is_parallel_ = false;
  // inputs_value_ stores cnode input values and ids in pynative mode
  // output_value_ stores the cnode value and id in pynative mode
  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
@ -23,6 +23,7 @@ import ast
import importlib
from collections import OrderedDict
from functools import wraps
import numpy as np

from mindspore import context
from mindspore import log as logger

@ -30,6 +31,7 @@ from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as MsTensor
from .tensor import CSRTensor as MsCSRTensor
from .tensor import COOTensor as MsCOOTensor
from .initializer import initializer
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, COOTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
@ -200,6 +202,7 @@ class _MindsporeFunctionExecutor:
        self.obj = None
        if obj and hasattr(obj, fn.__name__):
            self.obj = obj
        self.shard_parent_obj = obj
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = ms_create_time
@ -223,6 +226,36 @@ class _MindsporeFunctionExecutor:
        if enable_compile_cache is True or enable_compile_cache == "1":
            self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

    def _parallel_process_for_ms_function(self, phase):
        """Set parameter and optimizer state data according to the sliced shapes produced by shard."""
        obj = self.shard_parent_obj if self.obj is None else self.obj
        obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
        obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=True)
        new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
        self._graph_executor.updata_param_node_default_input(phase, new_param)
        obj.load_parameter_slice(None)

        if _pynative_executor.get_optimizer():
            params = obj.trainable_params()
            opt_params = _pynative_executor.get_optimizer().trainable_params()
            opt_states = []
            for opt_param in opt_params:
                for param in params:
                    if opt_param.name.find(param.name) > 0:
                        opt_states.append(opt_param)
                        obj.parameter_layout_dict[opt_param.name] = obj.parameter_layout_dict[param.name]
                        continue

            states_tuple = (opt_states[:len(params)], opt_states[len(params):]) if len(opt_states) != len(params) \
                else (opt_states[:len(params)],)
            for states in states_tuple:
                for param, state in zip(params, states):
                    if param.shape != state.shape:
                        state.set_data(initializer(state.init, param.shape), True)
            _pynative_executor.get_top_cell().parameter_layout_dict = obj.parameter_layout_dict

    def compile(self, args_list, method_name):
        """Returns the pipeline for the given args."""
        # Verify the signature for both function and method
@ -271,6 +304,10 @@ class _MindsporeFunctionExecutor:
        else:
            self._graph_executor.set_weights_values(self.obj.parameters_dict())
        is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)

        if is_pynative_parallel():
            self._parallel_process_for_ms_function(phase)

        if not is_compile:
            raise RuntimeError("Executor compile failed.")
        if context.get_context("enable_ge"):
@ -284,7 +321,18 @@ class _MindsporeFunctionExecutor:
        if self.obj is not None:
            args_list = args_list[1:]

-       phase = self.compile(args_list, self.fn.__name__)
+       if is_pynative_parallel() and not hasattr(self.shard_parent_obj, "keep_input_unchanged"):
+           device_num = context.get_auto_parallel_context('device_num')
+           new_args_list = ()
+           for arg in args_list:
+               if isinstance(arg, MsTensor):
+                   # Compile against the full-batch shape: dim 0 is scaled by device_num.
+                   new_shape = (arg.shape[0] * device_num,) + arg.shape[1:]
+                   new_args_list += (MsTensor(np.zeros(shape=new_shape), arg.dtype),)
+               else:
+                   new_args_list += (arg,)
+           phase = self.compile(new_args_list, self.fn.__name__)
+       else:
+           phase = self.compile(args_list, self.fn.__name__)

        if context.get_context("precompile_only"):
            return None
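Each process feeds its per-device slice, so compilation uses a zero-filled placeholder whose batch dimension is restored to the global size. A small illustration of the shape arithmetic (shapes hypothetical):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    device_num = 8
    arg = Tensor(np.zeros((4, 16)), ms.float32)               # per-device slice
    new_shape = (arg.shape[0] * device_num,) + arg.shape[1:]  # full batch
    placeholder = Tensor(np.zeros(new_shape), arg.dtype)      # (32, 16); data unused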
@ -371,6 +419,8 @@ def ms_function(fn=None, obj=None, input_signature=None):
        process_obj = None
        if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
            process_obj = args[0]
        if process_obj is None and is_pynative_parallel():
            process_obj = obj
        out = _MindsporeFunctionExecutor(func, ms_create_time, input_signature, process_obj)(*args)
        return out

@ -380,6 +430,11 @@ def ms_function(fn=None, obj=None, input_signature=None):
        return wrap_mindspore(fn)
    return wrap_mindspore

def is_pynative_parallel():
    run_mode = context.get_context('mode')
    parallel_mode = context.get_auto_parallel_context('parallel_mode')
    return run_mode == context.PYNATIVE_MODE and parallel_mode in (
        context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)

def _get_auto_split_param_names(parameter_layout_dict):
    auto_split_param_names = []
@ -443,6 +498,8 @@ class _PynativeExecutor:
        self._executor = PynativeExecutor_.get_instance()
        self._executor.set_py_exe_path(sys.executable)
        self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
        self._optimizer = None
        self._top_cell = None

    def new_graph(self, obj, *args, **kwargs):
        self._executor.new_graph(obj, *args, *(kwargs.values()))

@ -508,6 +565,12 @@ class _PynativeExecutor:
    def set_hook_changed(self, cell):
        self._executor.set_hook_changed(cell)

    def get_optimizer(self):
        return self._optimizer

    def get_top_cell(self):
        return self._top_cell

    def __call__(self, obj, *args, **kwargs):
        args = args + tuple(kwargs.values())
        return self._executor(obj, args)
@ -26,6 +26,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common.hook_handle import HookHandle
from mindspore.context import ParallelMode
from mindspore.ops.composite import Shard
from .. import context
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from .._checkparam import Validator

@ -380,10 +381,13 @@ class Cell(Cell_):
        self.parameter_broadcast_done = True

    def run_construct(self, cast_inputs, kwargs):
        """Run the construct function."""
        if self._enable_forward_pre_hook:
            cast_inputs = self.run_forward_pre_hook(cast_inputs)
        if self.enable_backward_hook:
            output = self.run_backward_hook(*cast_inputs)
        elif hasattr(self, "_shard_fn"):
            output = self._shard_fn(*cast_inputs, **kwargs)
        else:
            output = self.construct(*cast_inputs, **kwargs)
        if self._enable_forward_hook:
@ -443,6 +447,33 @@ class Cell(Cell_):
        for prim in all_prims:
            prim.add_prim_attr("strategy_gen_mode", "data_parallel")

    def shard(self, in_axes, out_axes, device="Ascend", level=0):
        """
        Defines the input and output layouts of this cell; the parallel strategies of the remaining
        operators are generated by sharding propagation. in_axes and out_axes define the input and
        output layouts, respectively. Each must be a tuple in which every element is the desired
        layout of the corresponding input/output tensor; None means data parallelism.

        Note:
            Only effective in PYNATIVE_MODE, with auto_parallel_context set to either
            ParallelMode.AUTO_PARALLEL with search_mode = sharding_propagation, or
            ParallelMode.SEMI_AUTO_PARALLEL.

        Examples:
            >>> from mindspore.ops import functional as F
            >>> class Example(nn.Cell):
            >>>     def __init__(self):
            >>>         super(Example, self).__init__()
            >>>         self.block1 = Block()
            >>>         self.block2 = Block()
            >>>         self.block2.shard(in_axes=(None, (2, 1)), out_axes=(None,))
            >>>         # self.parallel_block = F.shard(self.block2, in_axes=(None, (2, 1)), out_axes=(None,))
            >>>     def construct(self, x):
            >>>         x = self.block1(x)
            >>>         x = self.block2(x)
            >>>         return x
        """
        shard_fn = Shard()
        fn = shard_fn(self, in_axes, out_axes, device, level)
        object.__setattr__(self, "_shard_fn", fn)
        return self

    class CellGuard:
        def __enter__(self):
            _pynative_executor.set_lazy_build(True)
@ -485,6 +516,8 @@ class Cell(Cell_):
        # Run in PyNative mode.
        if _pynative_executor.is_top_cell():
            _pynative_executor.set_lazy_build(True)
            _pynative_executor._optimizer = getattr(self, "optimizer", None)
            _pynative_executor._top_cell = self
            # There are many Casts in parameter_broadcast. Enable lazy_build to build faster.
            self._do_parameter_broadcast()
|
@ -754,7 +754,10 @@ class Shard(Shard_):
|
|||
self.device = None
|
||||
self.level = None
|
||||
|
||||
def __call__(self, fn, in_axes, out_axes, device, level=0):
|
||||
def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0):
|
||||
if context.get_context("mode") != context.PYNATIVE_MODE or \
|
||||
context.get_auto_parallel_context("parallel_mode") not in ["semi_auto_parallel", "auto_parallel"]:
|
||||
raise AssertionError(f"'Shard' only supports semi_auto/auto parallel under PyNative mode")
|
||||
if not isinstance(in_axes, tuple):
|
||||
raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
|
||||
if not isinstance(out_axes, tuple):
|
||||
|
@ -771,7 +774,7 @@ class Shard(Shard_):
|
|||
return self.shard_fn
|
||||
shard_ = Shard()
|
||||
|
||||
@ms_function
|
||||
@ms_function(obj=fn)
|
||||
def after_shard(*args):
|
||||
return shard_(fn, in_axes, out_axes, device, level)(*args)
|
||||
|
||||
|
|
|
@ -345,7 +345,7 @@ def vjp(fn, inputs, v):
    return wrap_container(inputs, v)

shard_fn = Shard()
- def shard(fn, in_axes, out_axes, device, level=0):
+ def shard(fn, in_axes, out_axes, device="Ascend", level=0):
    return shard_fn(fn, in_axes, out_axes, device, level)

def narrow(inputs, axis, start, length):
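A hedged sketch of calling the functional form added here, mirroring the Cell.shard example above (block, x, and shapes are hypothetical):

    from mindspore.ops import functional as F

    # Wrap an existing cell or function; in_axes/out_axes follow the same
    # rules as Cell.shard.
    parallel_block = F.shard(block, in_axes=((8, 1),), out_axes=(None,))
    out = parallel_block(x)  # runs the block as a sharded graph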
@ -21,6 +21,7 @@ from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.common.parameter import Parameter
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig

@ -283,6 +284,9 @@ class Primitive(Primitive_):

    def __call__(self, *args):
        should_elim, output = self.check_elim(*args)
        for arg in args:
            if isinstance(arg, Parameter) and arg.has_init:
                arg.init_data()
        if should_elim:
            return output
        return _run_op(self, self.name, args)
@ -355,11 +355,6 @@ class _AutoParallelContext:
            ValueError: If parallel mode is not supported.
        """
        self.check_context_handle()
-       run_mode = context.get_context("mode")
-       if run_mode == context.PYNATIVE_MODE and parallel_mode not in (
-               context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE):
-           raise ValueError(f"Pynative Only support STAND_ALONE and DATA_PARALLEL for ParallelMode, "
-                            f"but got {parallel_mode.upper()}.")
        ret = self._context_handle.set_parallel_mode(parallel_mode)
        if ret is False:
            raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', "
@ -17,6 +17,7 @@
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
from ..common import ms_function

_allgather_cell = None

@ -31,6 +32,7 @@ class AllGatherCell(Cell):

        self.allgather = AllGather(group)

    @ms_function()
    def construct(self, x):
        x = self.allgather(x)
@ -704,6 +704,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
    # pipeline parallel needs to be supported here later
    if mp_weight:
        allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
        object.__setattr__(allgather_net, "keep_input_unchanged", True)
    elif opt_shard_group:
        allgather_net = get_allgather_cell(opt_shard_group, False)
    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
@ -815,7 +816,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
        enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
        enc_mode = 'AES-GCM'
        if 'enc_mode' in kwargs.keys():
-           enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
+           enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
        dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None
        _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
    else:

@ -961,8 +962,8 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
            write_data = raw_data + bytes(append_size)
            offset += (data_length + append_size)
            if is_encrypt():
-               write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
-                                     len(kwargs['enc_key']), kwargs['enc_mode'])
+               write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
+                                     len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
            f.write(write_data)

    # save graph

@ -973,9 +974,9 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
        os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
        model_string = model.SerializeToString()
        if is_encrypt():
-           model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
-                                   len(kwargs['enc_key']),
-                                   kwargs['enc_mode'])
+           model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
+                                   len(kwargs.get('enc_key')),
+                                   kwargs.get('enc_mode'))
        model_file.write(model_string)
        os.chmod(graph_file_name, stat.S_IRUSR)

@ -1000,7 +1001,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
    net_dict = net.parameters_dict()
    model.ParseFromString(mindir_stream)

-   if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
+   if 'dataset' in kwargs.keys() and kwargs.get('dataset') is not None:
        check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
        dataset = kwargs['dataset']
        _save_dataset_to_mindir(model, dataset)

@ -1035,8 +1036,8 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
    os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
    model_string = model.SerializeToString()
    if is_encrypt():
-       model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
-                               kwargs['enc_mode'])
+       model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'), len(kwargs.get('enc_key')),
+                               kwargs.get('enc_mode'))
    f.write(model_string)
    os.chmod(file_name, stat.S_IRUSR)

@ -1109,8 +1110,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
    quant_net = copy.deepcopy(network)
    quant_net._create_time = int(time.time() * 1e9)

-   mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
-   std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
+   mean = 127.5 if kwargs.get('mean', None) is None else kwargs.get('mean')
+   std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs.get('std_dev')
    mean = Validator.check_value_type("mean", mean, (int, float))
    std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))
@ -117,7 +117,7 @@ class TestSharedParameterCast:
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1)
-       self.cat_fp16_from_ir(target_count=27)
+       self.cat_fp16_from_ir(target_count=28)

    def test_optimizer_fp16_micro_batch(self):
        """

@ -127,7 +127,7 @@ class TestSharedParameterCast:
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=2)
-       self.cat_fp16_from_ir(target_count=41)
+       self.cat_fp16_from_ir(target_count=42)

    def test_optimizer_fp16_pipeline(self):
        """

@ -138,7 +138,7 @@ class TestSharedParameterCast:
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1,
                                  stages=1, micro_size=1)
-       self.cat_fp16_from_ir(target_count=27)
+       self.cat_fp16_from_ir(target_count=28)

    def test_optimizer_fp16_pipeline_micro_batch(self):
        """

@ -149,4 +149,4 @@ class TestSharedParameterCast:
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=2,
                                  stages=1, micro_size=1)
-       self.cat_fp16_from_ir(target_count=41)
+       self.cat_fp16_from_ir(target_count=42)