forked from mindspore-Ecosystem/mindspore
!8557 run cast before allgather in parallel optimizer
From: @gong_zi_yan Reviewed-by: Signed-off-by:
This commit is contained in:
commit
2bf165c0b4
|
@ -55,6 +55,10 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
|||
// it will be one item in map with key: C, and value: (B, i)
|
||||
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
|
||||
static void HandleNoUsedParameter(const FuncGraphPtr &root);
|
||||
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
|
||||
const std::string &instance_name);
|
||||
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
||||
const std::string &opt_shard_group);
|
||||
|
||||
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
||||
if (new_node_input.empty()) {
|
||||
|
@ -125,6 +129,30 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
|
|||
MS_LOG(INFO) << "Insert " << instance_name << " success";
|
||||
}
|
||||
|
||||
// Replace pre_node with pre_node->op
|
||||
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
|
||||
const std::string &instance_name) {
|
||||
// insert new node before the node
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
ScopePtr scope = pre_node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name);
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
|
||||
new_node->set_in_forward_flag(true); // mark forward flag
|
||||
}
|
||||
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
|
||||
new_node_prim->set_instance_name(instance_name);
|
||||
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
|
||||
new_node->set_scope(scope);
|
||||
node_input[0]->set_scope(scope);
|
||||
manager->Replace(pre_node, new_node);
|
||||
MS_LOG(INFO) << "Insert " << instance_name << " success";
|
||||
return new_node;
|
||||
}
|
||||
|
||||
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsValueNode<Primitive>(node->input(0))) {
|
||||
|
@ -1380,17 +1408,25 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
|
|||
auto cnode = res.first->cast<CNodePtr>();
|
||||
auto graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(cnode_prim);
|
||||
CNodePtr allgather;
|
||||
if (cnode_prim->name() == CAST) {
|
||||
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
} else {
|
||||
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
}
|
||||
// add fusion flag
|
||||
auto allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(allgather);
|
||||
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
|
||||
auto attrs = prim->attrs();
|
||||
// enable fusion flag later when it's supported in backend
|
||||
attrs["fusion"] = MakeValue(0);
|
||||
attrs["fusion"] = MakeValue<int64_t>(0);
|
||||
prim->SetAttrs(attrs);
|
||||
}
|
||||
|
||||
void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
||||
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
||||
const std::string &opt_shard_group) {
|
||||
if (opt_shard_group.empty()) {
|
||||
return;
|
||||
|
|
|
@ -119,3 +119,12 @@ def test_neg_repeat_calc2():
|
|||
strategy2 = ((4, 4),)
|
||||
net = Net(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_parallel_optimizer_with_mix_precision():
    """Compile Net in semi-auto-parallel mode on 16 devices with the parallel optimizer enabled.

    The test passes if compile_net finishes without raising.
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel",
        device_num=16,
        global_rank=0,
        enable_parallel_optimizer=True,
    )
    # Shard strategies handed to Net as its second and third arguments.
    first_strategy = ((8, 1), (8, 1))
    second_strategy = ((8, 1),)
    compile_net(Net(_w1, first_strategy, second_strategy))
|
||||
|
|
Loading…
Reference in New Issue