Merge pull request !41602 from lyqlola/fix
This commit is contained in:
i-robot 2022-09-16 06:55:42 +00:00 committed by Gitee
commit 04bd77f793
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 10 additions and 5 deletions

View File

@ -1916,6 +1916,11 @@ void AscendStreamAssign::InsertEventHcomDependHcomAtSameGroup(
}
void AscendStreamAssign::InsertEventForCallCommSubGraph(const NotNull<KernelGraphPtr> &graph_ptr) const {
if (comm_sub_graph_stream_.empty()) {
MS_LOG(INFO) << "No comm sub graph, skip.";
return;
}
std::map<uint32_t, std::string> comm_sub_graph_id_to_group = {}; // key: label id, value: hcom group
const auto &cnode_list = graph_ptr->execution_order();
for (const auto &n : cnode_list) {
@ -1948,7 +1953,7 @@ void AscendStreamAssign::InsertEventForCallCommSubGraph(const NotNull<KernelGrap
// insert event
std::map<uint32_t, std::string>::const_iterator label_iter = comm_sub_graph_id_to_group.find(label_id);
if (label_iter == comm_sub_graph_id_to_group.cend()) {
MS_LOG(WARNING) << "Cannot find comm group for sub comm graph label id " << label_id;
MS_LOG(INFO) << "Cannot find comm group for sub comm graph label id " << label_id;
new_order.push_back(n);
continue;
}

View File

@ -261,6 +261,10 @@ void AscendGraphOptimization::RootGraphExecutorValidate(const NotNull<KernelGrap
void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
if (memo_.find(graph) != memo_.end()) {
return;
}
(void)memo_.insert(graph);
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -270,10 +274,6 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
DumpIR(file_name, graph, true, kTopStack);
}
#endif
if (memo_.find(graph) != memo_.end()) {
return;
}
(void)memo_.insert(graph);
MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
SetOperatorInfo(graph);
MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();