forked from mindspore-Ecosystem/mindspore
!16478 handle load op in step parallel
From: @gong_zi_yan Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
1c8fda25ef
|
@ -139,30 +139,6 @@ std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node,
|
|||
return new_node_input;
|
||||
}
|
||||
|
||||
void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
|
||||
const FuncGraphPtr &func_graph, const std::string &instance_name) {
|
||||
// insert new node before the node
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
ScopePtr scope = node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name);
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
|
||||
new_node->set_in_forward_flag(true); // mark forward flag
|
||||
}
|
||||
auto new_node_value = node_input[0]->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_node_value);
|
||||
PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
|
||||
new_node_prim->set_instance_name(instance_name);
|
||||
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
|
||||
new_node->set_scope(scope);
|
||||
node_input[0]->set_scope(scope);
|
||||
manager->SetEdge(node, SizeToLong(index), new_node);
|
||||
MS_LOG(INFO) << "Insert " << instance_name << " success";
|
||||
}
|
||||
|
||||
bool ParameterIsCloned(const AnfNodePtr ¶meter_node) {
|
||||
MS_EXCEPTION_IF_NULL(parameter_node);
|
||||
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
|
||||
|
@ -256,15 +232,20 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
|
|||
return new_node_input;
|
||||
}
|
||||
|
||||
void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index,
|
||||
const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name,
|
||||
const std::string ¶m_name) {
|
||||
void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
|
||||
const FuncGraphPtr &func_graph, const std::string &instance_name, const std::string ¶m_name = "",
|
||||
const FuncGraphPtr &root = nullptr) {
|
||||
// insert new node before the node
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
ScopePtr scope = node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
|
||||
std::vector<AnfNodePtr> node_input;
|
||||
if (root && !param_name.empty()) {
|
||||
node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
|
||||
} else {
|
||||
node_input = CreateInput(op, pre_node, instance_name);
|
||||
}
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
|
||||
|
@ -283,38 +264,19 @@ void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodeP
|
|||
|
||||
// Replace pre_node with pre_node->op
|
||||
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
|
||||
const std::string &instance_name) {
|
||||
const std::string &instance_name, const std::string ¶m_name = "",
|
||||
const FuncGraphPtr &root = nullptr) {
|
||||
// insert new node before the node
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
ScopePtr scope = pre_node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name);
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
|
||||
new_node->set_in_forward_flag(true); // mark forward flag
|
||||
std::vector<AnfNodePtr> node_input;
|
||||
if (root && !param_name.empty()) {
|
||||
node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
|
||||
} else {
|
||||
node_input = CreateInput(op, pre_node, instance_name);
|
||||
}
|
||||
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
|
||||
new_node_prim->set_instance_name(instance_name);
|
||||
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
|
||||
new_node->set_scope(scope);
|
||||
node_input[0]->set_scope(scope);
|
||||
manager->Replace(pre_node, new_node);
|
||||
MS_LOG(INFO) << "Insert " << instance_name << " success";
|
||||
return new_node;
|
||||
}
|
||||
|
||||
// Replace pre_node with pre_node->op
|
||||
static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node,
|
||||
const FuncGraphPtr &func_graph, const std::string &instance_name,
|
||||
const std::string ¶m_name) {
|
||||
// insert new node before the node
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
ScopePtr scope = pre_node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
|
||||
|
@ -918,6 +880,9 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|||
}
|
||||
|
||||
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
||||
if (!cnode) {
|
||||
return false;
|
||||
}
|
||||
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
|
||||
|
@ -1102,10 +1067,11 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
|
||||
if (!FindParameter(cnode->input(index), func_graph).first) {
|
||||
auto res = FindParameter(cnode->input(index), func_graph);
|
||||
if (!res.first) {
|
||||
continue;
|
||||
}
|
||||
return FindParameter(cnode->input(index), func_graph);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1126,10 +1092,11 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|||
if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
|
||||
continue;
|
||||
}
|
||||
if (!FindParameter(cnode->input(index), func_graph).first) {
|
||||
auto res = FindParameter(cnode->input(index), func_graph);
|
||||
if (!res.first) {
|
||||
continue;
|
||||
}
|
||||
return FindParameter(cnode->input(index), func_graph);
|
||||
return res;
|
||||
}
|
||||
return std::make_pair(nullptr, false);
|
||||
}
|
||||
|
@ -1160,11 +1127,14 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
|
|||
}
|
||||
MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph";
|
||||
}
|
||||
if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(use_apply)) {
|
||||
return FindCNode(node_pair.first, name, func_graph);
|
||||
}
|
||||
}
|
||||
return std::make_pair(result, cnode_return);
|
||||
}
|
||||
|
||||
bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
|
||||
bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
|
||||
// only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
|
||||
if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
|
||||
return false;
|
||||
|
@ -1175,11 +1145,10 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
|
|||
if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
auto pre_value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_value_node);
|
||||
auto pre_prim = pre_value_node->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_prim);
|
||||
if (pre_prim->name() != CAST) {
|
||||
if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(cnode)) {
|
||||
pre_node = cnode->input(1);
|
||||
}
|
||||
if (!IsPrimitiveCNode(pre_node, prim::kPrimCast)) {
|
||||
return false;
|
||||
}
|
||||
auto node_type = pre_node->Type();
|
||||
|
@ -1213,6 +1182,17 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
|
|||
return true;
|
||||
}
|
||||
|
||||
// only used for InsertMirrorOps
|
||||
CNodePtr SkipTrivialNodes(CNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
while (!IsSomePrimitive(node, LOAD)) {
|
||||
if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
|
||||
node = node->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t node_size = node->inputs().size();
|
||||
|
@ -1242,11 +1222,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|||
|
||||
auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
|
||||
std::string param_name;
|
||||
bool is_shared_param = false;
|
||||
if (param_ptr) {
|
||||
param_name = param_ptr->name();
|
||||
if (!param_ptr->param_info() || !param_ptr->param_info()->requires_grad()) {
|
||||
MS_LOG(INFO) << param_name << " do not need gradient. Skip inserting mirror.";
|
||||
continue;
|
||||
}
|
||||
std::string opt_shard_mirror_group;
|
||||
if (param_ptr->user_data<TensorLayout>()) {
|
||||
opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
|
||||
is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
|
||||
}
|
||||
if (!opt_shard_mirror_group.empty()) {
|
||||
// mirror ops is covered in not fully use opt shard case
|
||||
|
@ -1254,51 +1240,53 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|||
}
|
||||
}
|
||||
// not a RefKey
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
std::string mirror_op_name;
|
||||
if (grad_accumulation_step > 1) {
|
||||
mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
|
||||
} else {
|
||||
mirror_op_name = MIRROR_OPERATOR;
|
||||
}
|
||||
AnfNodePtr pre_node = node->input(index);
|
||||
if (!param_node_pair.second) {
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
std::string mirror_op_name;
|
||||
if (grad_accumulation_step > 1) {
|
||||
mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
|
||||
} else {
|
||||
mirror_op_name = MIRROR_OPERATOR;
|
||||
}
|
||||
auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
|
||||
// if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
|
||||
if (next_cnode.first) {
|
||||
MS_EXCEPTION_IF_NULL(next_cnode.second);
|
||||
// param->cast->op, insert mirror before cast
|
||||
if (node->input(index)->isa<CNode>()) {
|
||||
auto pre_cnode = node->input(index)->cast<CNodePtr>();
|
||||
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
if ((pre_prim->name() == CAST) || (pre_prim->name() == LOAD)) {
|
||||
manager->SetEdge(pre_cnode, 1, next_cnode.second);
|
||||
continue;
|
||||
}
|
||||
// assume Load is inserted next to parameter
|
||||
// skip Load moving up and insert mirror next to the parameter
|
||||
if (pre_node->cast<CNodePtr>()) {
|
||||
CNodePtr load_node = SkipTrivialNodes(node->input(index)->cast<CNodePtr>());
|
||||
manager->SetEdge(load_node, 1, next_cnode.second);
|
||||
} else {
|
||||
manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
|
||||
}
|
||||
manager->SetEdge(node, SizeToLong(index), next_cnode.second);
|
||||
MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
|
||||
<< " and share the mirror.";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp
|
||||
// only one MirrorOp in backward_op
|
||||
if (backward_op.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
|
||||
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
|
||||
}
|
||||
std::string instance_name = MIRROR_OP;
|
||||
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
|
||||
auto op = backward_op[0];
|
||||
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
|
||||
// insert new node before the node
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AnfNodePtr pre_node = cnode->input(1);
|
||||
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
|
||||
auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
|
||||
if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) {
|
||||
// assume Load is inserted next to parameter
|
||||
// skip Load moving up and insert mirror next to the parameter
|
||||
CNodePtr load_node = SkipTrivialNodes(pre_node->cast<CNodePtr>());
|
||||
InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
|
||||
auto comm_op = load_node->input(1)->cast<CNodePtr>();
|
||||
// add fusion flag
|
||||
AddCommOpFusionType(comm_op, param_node_pair.first);
|
||||
MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
|
||||
<< " and insert mirror before Load";
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr pre_node = node->input(index);
|
||||
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
|
||||
InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
|
||||
MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
|
||||
<< " and insert mirror before the node";
|
||||
auto comm_op = node->input(index)->cast<CNodePtr>();
|
||||
// add fusion flag
|
||||
// pipeline mirror would not be set, which should be supported later
|
||||
|
@ -1635,34 +1623,64 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
|
|||
return std::make_pair(nullptr, 0);
|
||||
}
|
||||
|
||||
CNodePtr InsertAllGatherAfterCast(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
// skip Load moving down and assume it only has one node user
|
||||
CNodePtr res = cnode;
|
||||
if (IsSomePrimitive(res, LOAD)) {
|
||||
res = manager->node_users()[cnode].begin()->first->cast<CNodePtr>();
|
||||
}
|
||||
// return true only if cnode is Cast from fp32 to fp16
|
||||
if (!IsSomePrimitive(res, CAST)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto node_type = res->Type();
|
||||
MS_EXCEPTION_IF_NULL(node_type);
|
||||
if (!node_type->isa<mindspore::TensorType>()) {
|
||||
MS_LOG(EXCEPTION) << "Unknown type.";
|
||||
}
|
||||
auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
|
||||
MS_EXCEPTION_IF_NULL(input_element_type);
|
||||
auto type_id = input_element_type->type_id();
|
||||
|
||||
if (type_id != kNumberTypeFloat32) {
|
||||
return res;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
|
||||
const AnfNodePtr &node, const std::string &op_name) {
|
||||
const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
|
||||
MS_EXCEPTION_IF_NULL(res.first);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = res.first->cast<CNodePtr>();
|
||||
auto graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(cnode_prim);
|
||||
Operator op;
|
||||
CNodePtr allgather;
|
||||
auto param_name = node->cast<ParameterPtr>()->name();
|
||||
if (op_name == MINI_STEP_ALL_GATHER) {
|
||||
op = CreateMiniStepAllGatherOp(group);
|
||||
auto param_name = node->cast<ParameterPtr>()->name();
|
||||
if (cnode_prim->name() == CAST) {
|
||||
allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
|
||||
} else {
|
||||
InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
|
||||
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
}
|
||||
} else {
|
||||
op = CreateAllGatherOp(group);
|
||||
if (cnode_prim->name() == CAST) {
|
||||
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
} else {
|
||||
InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
}
|
||||
}
|
||||
CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
|
||||
if (!is_shared_param && cast_node) {
|
||||
allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
|
||||
MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
|
||||
} else {
|
||||
InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name, root);
|
||||
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name;
|
||||
}
|
||||
// add fusion flag
|
||||
AddCommOpFusionType(allgather, node);
|
||||
|
@ -1676,6 +1694,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
return;
|
||||
}
|
||||
FuncGraphManagerPtr manager = root->manager();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
std::string op_name;
|
||||
|
@ -1692,28 +1711,25 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
if (cnode->in_forward_flag()) {
|
||||
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(WARNING) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
|
||||
MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
|
||||
} else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
|
||||
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
|
||||
<< distribute_operator->inputs_tensor_info().size();
|
||||
}
|
||||
if (insert_flag) {
|
||||
// if there are multiple node users, they share one same allgather
|
||||
auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph());
|
||||
if (next_cnode.first) {
|
||||
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
|
||||
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
|
||||
<< GetPrimName(cnode);
|
||||
} else {
|
||||
// insert allgather operator between shard parameter and cnode
|
||||
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
|
||||
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
|
||||
MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
|
||||
<< GetPrimName(cnode);
|
||||
}
|
||||
} else {
|
||||
// insert allgather operator between shard parameter and cnode
|
||||
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
|
||||
MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
|
||||
<< GetPrimName(cnode);
|
||||
auto param_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
bool is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
|
||||
InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name, is_shared_param);
|
||||
insert_flag = true;
|
||||
}
|
||||
}
|
||||
|
@ -1749,6 +1765,31 @@ static std::string GetOptShardGroup(const AnfNodePtr ¶meter, TensorLayout *c
|
|||
return opt_shard_group;
|
||||
}
|
||||
|
||||
void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
FuncGraphManagerPtr manager = root->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto parameter_ptr = parameter->cast<ParameterPtr>();
|
||||
if (!parameter_ptr) {
|
||||
MS_LOG(INFO) << parameter->ToString() << " is not a parameter";
|
||||
return;
|
||||
}
|
||||
auto param_sub_set = manager->node_users()[parameter];
|
||||
int32_t users_count = 0;
|
||||
for (auto ¶m_pair : param_sub_set) {
|
||||
auto cnode = param_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->in_forward_flag()) users_count++;
|
||||
}
|
||||
if (users_count > 1) {
|
||||
auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
|
||||
tensor_layout->set_is_shared_param(true);
|
||||
MS_LOG(WARNING) << "There are multiple users for " << parameter->ToString()
|
||||
<< ". Mixed precision optimization is not valid here.";
|
||||
}
|
||||
}
|
||||
|
||||
// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
|
||||
std::string SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, int64_t> &res) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
|
@ -1801,6 +1842,7 @@ void CoverSliceShape(const FuncGraphPtr &root) {
|
|||
if (iter != g_RefMap.end()) {
|
||||
std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
|
||||
// find all forward nodes that use parameter in graphs and insert allgather if group is not empty
|
||||
SetSharedParameterFlag(root, parameter);
|
||||
ApplyParallelOptOnParam(root, parameter, group);
|
||||
continue;
|
||||
}
|
||||
|
@ -1810,6 +1852,7 @@ void CoverSliceShape(const FuncGraphPtr &root) {
|
|||
} else {
|
||||
std::string group = SetParallelShape(parameter, res);
|
||||
// find all forward nodes that use parameter in graphs and insert allgather if group is not empty
|
||||
SetSharedParameterFlag(root, parameter);
|
||||
ApplyParallelOptOnParam(root, parameter, group);
|
||||
MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
|
||||
}
|
||||
|
|
|
@ -116,6 +116,10 @@ class TensorLayout {
|
|||
|
||||
int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }
|
||||
|
||||
void set_is_shared_param(bool is_shared_param) { is_shared_param_ = is_shared_param; }
|
||||
|
||||
bool is_shared_param() { return is_shared_param_; }
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "TLayout";
|
||||
|
||||
|
@ -145,6 +149,7 @@ class TensorLayout {
|
|||
std::string opt_shard_mirror_group_ = ""; // for mirror ops
|
||||
int32_t opt_weight_shard_step_ = 0;
|
||||
int32_t opt_weight_shard_size_ = 0;
|
||||
bool is_shared_param_ = false;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "Send", "UpdateState", "Load"};
|
||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
|
||||
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
|
||||
// clang-format on
|
||||
|
||||
bool IsInParallelBlackList(const PrimitivePtr &prim) {
|
||||
|
@ -48,6 +49,15 @@ bool IsInAllGatherNodeList(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsInTrivialNodeList(const CNodePtr &cnode) {
|
||||
for (auto &value : TRIVIAL_NODE_LIST_) {
|
||||
if (IsPrimitiveCNode(cnode, value)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsParallelConsiderCNode(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr || cnode->size() == 0) {
|
||||
return false;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
namespace mindspore {
|
||||
bool IsInParallelBlackList(const PrimitivePtr &);
|
||||
bool IsInAllGatherNodeList(const CNodePtr &);
|
||||
bool IsInTrivialNodeList(const CNodePtr &);
|
||||
bool IsParallelConsiderCNode(const CNodePtr &);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_shared_param_and_mix_precision """
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.ops import operations as P, functional as F
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super(Net1, self).__init__()
|
||||
self.fc1 = P.MatMul().shard(strategy=strategy1)
|
||||
self.fc2 = P.MatMul().shard(strategy=strategy2)
|
||||
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
|
||||
self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2")
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.fc1(x, self.p1)
|
||||
x = self.fc2(x, self.p2)
|
||||
x = self.fc1(x, self.p1)
|
||||
return x - y
|
||||
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super(Net2, self).__init__()
|
||||
self.fc1 = P.MatMul().shard(strategy=strategy1)
|
||||
self.fc2 = P.MatMul().shard(strategy=strategy2)
|
||||
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
|
||||
self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2")
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.fc1(F.cast(x, mstype.float16), F.cast(self.p1, mstype.float16))
|
||||
x = self.fc2(x, F.cast(self.p2, mstype.float16))
|
||||
x = self.fc1(F.cast(x, mstype.float32), self.p1)
|
||||
return x - y
|
||||
|
||||
|
||||
def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None, enable_parallel_optimizer=False,
|
||||
gradient_fp32_sync=True):
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num,
|
||||
enable_parallel_optimizer=enable_parallel_optimizer,
|
||||
gradient_fp32_sync=gradient_fp32_sync)
|
||||
inputs = Tensor(np.ones([32, 48]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 64]).astype(np.float32))
|
||||
net = net(strategy1, strategy2)
|
||||
net = _VirtualDatasetCell(net)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
|
||||
train_network.set_auto_parallel()
|
||||
train_network.set_train()
|
||||
_executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
|
||||
context.reset_auto_parallel_context()
|
||||
return train_network
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_1():
|
||||
auto_parallel_compile_net("auto_parallel", 8, Net1)
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_2():
|
||||
# data parallel case
|
||||
auto_parallel_compile_net("semi_auto_parallel", 8, Net1, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_3():
|
||||
# parallel optimizer and mix precision case
|
||||
auto_parallel_compile_net("semi_auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_4():
|
||||
# parallel optimizer and mix precision case
|
||||
auto_parallel_compile_net("semi_auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)), True, False)
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_5():
|
||||
# test not fully use parallel optimizer with mix precision case
|
||||
context.set_auto_parallel_context(optimizer_weight_shard_size=2)
|
||||
auto_parallel_compile_net("semi_auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)), True)
|
Loading…
Reference in New Issue