!49695 fix comm subgraph bug

Merge pull request !49695 from lyqlola/comm
This commit is contained in:
i-robot 2023-03-06 03:03:26 +00:00 committed by Gitee
commit 2bd43a73c9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 5 additions and 3 deletions

View File

@@ -40,7 +40,7 @@ std::string VecToString(const std::vector<T> &vec) {
return res;
}
std::string GenCommOpKey(const CNodePtr &node) {
std::string GenCommOpKey(const CNodePtr &node, const KernelGraphPtr &root_graph) {
std::string op_key;
MS_EXCEPTION_IF_NULL(node);
auto comm_prim = GetCNodePrimitive(node);
@@ -68,6 +68,8 @@ std::string GenCommOpKey(const CNodePtr &node) {
if (comm_prim->HasAttr(kAttrRecvRankIds)) {
op_key += "_" + VecToString(GetValue<std::vector<int64_t>>(comm_prim->GetAttr(kAttrRecvRankIds)));
}
// model identifier, aka. root_graph_id
op_key += "_" + std::to_string(root_graph->root_graph_id());
MS_LOG(INFO) << node->DebugString() << " key " << op_key;
return op_key;
}
@@ -198,7 +200,7 @@ void AscendCommOpReuse::AnalyseCommOpReuse() {
if (!IsReusable(comm_op)) {
continue;
}
reuse_map[GenCommOpKey(comm_op)].push_back(comm_op);
reuse_map[GenCommOpKey(comm_op, root_graph_)].push_back(comm_op);
}
for (const auto &[key, comm_op_set] : reuse_map) {
@@ -255,7 +257,7 @@ KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) {
MS_EXCEPTION_IF_NULL(new_comm_op);
new_comm_op->set_abstract(comm_op->abstract());
std::string group_name = GenCommOpKey(comm_op);
std::string group_name = GenCommOpKey(comm_op, root_graph_);
auto rank_list = common::AnfAlgo::GetNodeAttr<std::vector<unsigned int>>(comm_op, kAttrRankList);
if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) {
MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list);