From 3f67e891db2be7aed42a94f4ae98a06c236067c6 Mon Sep 17 00:00:00 2001
From: liyiqi
Date: Fri, 3 Mar 2023 16:46:39 +0800
Subject: [PATCH] make different models use different comm groups

---
 .../device/ascend/optimizer/ascend_comm_op_reuse.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc
index 4558dd3ba77..3967f620062 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc
@@ -40,7 +40,7 @@ std::string VecToString(const std::vector<T> &vec) {
   return res;
 }
 
-std::string GenCommOpKey(const CNodePtr &node) {
+std::string GenCommOpKey(const CNodePtr &node, const KernelGraphPtr &root_graph) {
   std::string op_key;
   MS_EXCEPTION_IF_NULL(node);
   auto comm_prim = GetCNodePrimitive(node);
@@ -68,6 +68,8 @@ std::string GenCommOpKey(const CNodePtr &node) {
   if (comm_prim->HasAttr(kAttrRecvRankIds)) {
     op_key += "_" + VecToString(GetValue<std::vector<int64_t>>(comm_prim->GetAttr(kAttrRecvRankIds)));
   }
+  // model identifier, aka. root_graph_id
+  op_key += "_" + std::to_string(root_graph->root_graph_id());
   MS_LOG(INFO) << node->DebugString() << " key " << op_key;
   return op_key;
 }
@@ -198,7 +200,7 @@ void AscendCommOpReuse::AnalyseCommOpReuse() {
     if (!IsReusable(comm_op)) {
       continue;
     }
-    reuse_map[GenCommOpKey(comm_op)].push_back(comm_op);
+    reuse_map[GenCommOpKey(comm_op, root_graph_)].push_back(comm_op);
   }
 
   for (const auto &[key, comm_op_set] : reuse_map) {
@@ -255,7 +257,7 @@ KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) {
   MS_EXCEPTION_IF_NULL(new_comm_op);
   new_comm_op->set_abstract(comm_op->abstract());
 
-  std::string group_name = GenCommOpKey(comm_op);
+  std::string group_name = GenCommOpKey(comm_op, root_graph_);
   auto rank_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(comm_op, kAttrRankList);
   if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) {
     MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list);