!8557 run cast before allgather in parallel optimzier

From: @gong_zi_yan
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-16 09:55:44 +08:00 committed by Gitee
commit 2bf165c0b4
2 changed files with 50 additions and 5 deletions

View File

@ -55,6 +55,10 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// it will be one item in map with key: C, and value: (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
static void HandleNoUsedParameter(const FuncGraphPtr &root);
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
const std::string &instance_name);
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
const std::string &opt_shard_group);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) {
@ -125,6 +129,30 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
MS_LOG(INFO) << "Insert " << instance_name << " success";
}
// Replace pre_node with pre_node->op
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
const std::string &instance_name) {
// insert new node before the node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ScopePtr scope = pre_node->scope();
MS_EXCEPTION_IF_NULL(scope);
std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name);
CNodePtr new_node = func_graph->NewCNode(node_input);
MS_EXCEPTION_IF_NULL(new_node);
if (instance_name.find(SPLIT_SENS) == std::string::npos) {
new_node->set_in_forward_flag(true); // mark forward flag
}
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->Replace(pre_node, new_node);
MS_LOG(INFO) << "Insert " << instance_name << " success";
return new_node;
}
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
@ -1380,18 +1408,26 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
auto cnode = res.first->cast<CNodePtr>();
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(cnode_prim);
CNodePtr allgather;
if (cnode_prim->name() == CAST) {
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
} else {
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
allgather = cnode->input(res.second)->cast<CNodePtr>();
}
// add fusion flag
auto allgather = cnode->input(res.second)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(allgather);
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
auto attrs = prim->attrs();
// enable fusion flag later when it's supported in backend
attrs["fusion"] = MakeValue(0);
attrs["fusion"] = MakeValue<int64_t>(0);
prim->SetAttrs(attrs);
}
void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
const std::string &opt_shard_group) {
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
const std::string &opt_shard_group) {
if (opt_shard_group.empty()) {
return;
}

View File

@ -119,3 +119,12 @@ def test_neg_repeat_calc2():
strategy2 = ((4, 4),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)
def test_parallel_optimizer_with_mix_precision():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
enable_parallel_optimizer=True)
strategy1 = ((8, 1), (8, 1))
strategy2 = ((8, 1),)
net = Net(_w1, strategy1, strategy2)
compile_net(net)