forked from mindspore-Ecosystem/mindspore
!49695 fix comm subgraph bug
Merge pull request !49695 from lyqlola/comm
This commit is contained in:
commit
2bd43a73c9
|
@@ -40,7 +40,7 @@ std::string VecToString(const std::vector<T> &vec) {
|
|||
return res;
|
||||
}
|
||||
|
||||
std::string GenCommOpKey(const CNodePtr &node) {
|
||||
std::string GenCommOpKey(const CNodePtr &node, const KernelGraphPtr &root_graph) {
|
||||
std::string op_key;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto comm_prim = GetCNodePrimitive(node);
|
||||
|
@@ -68,6 +68,8 @@ std::string GenCommOpKey(const CNodePtr &node) {
|
|||
if (comm_prim->HasAttr(kAttrRecvRankIds)) {
|
||||
op_key += "_" + VecToString(GetValue<std::vector<int64_t>>(comm_prim->GetAttr(kAttrRecvRankIds)));
|
||||
}
|
||||
// model identifier, aka. root_graph_id
|
||||
op_key += "_" + std::to_string(root_graph->root_graph_id());
|
||||
MS_LOG(INFO) << node->DebugString() << " key " << op_key;
|
||||
return op_key;
|
||||
}
|
||||
|
@@ -198,7 +200,7 @@ void AscendCommOpReuse::AnalyseCommOpReuse() {
|
|||
if (!IsReusable(comm_op)) {
|
||||
continue;
|
||||
}
|
||||
reuse_map[GenCommOpKey(comm_op)].push_back(comm_op);
|
||||
reuse_map[GenCommOpKey(comm_op, root_graph_)].push_back(comm_op);
|
||||
}
|
||||
|
||||
for (const auto &[key, comm_op_set] : reuse_map) {
|
||||
|
@@ -255,7 +257,7 @@ KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) {
|
|||
MS_EXCEPTION_IF_NULL(new_comm_op);
|
||||
new_comm_op->set_abstract(comm_op->abstract());
|
||||
|
||||
std::string group_name = GenCommOpKey(comm_op);
|
||||
std::string group_name = GenCommOpKey(comm_op, root_graph_);
|
||||
auto rank_list = common::AnfAlgo::GetNodeAttr<std::vector<unsigned int>>(comm_op, kAttrRankList);
|
||||
if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) {
|
||||
MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list);
|
||||
|
|
Loading…
Reference in New Issue