!16109 fix ub_fusion graph circle problem

From: @yuchaojie
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-05-31 19:43:35 +08:00 committed by Gitee
commit bec093c456
1 changed files with 22 additions and 14 deletions

View File

@ -362,24 +362,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
}
}
bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) {
bool has_circle = false;
for (auto &inp : fusion_info.inputs_list) {
MS_EXCEPTION_IF_NULL(inp);
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
continue;
}
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
has_circle = true;
break;
}
}
return has_circle;
}
void RemoveCircle(const session::KernelGraph &kernel_graph,
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
std::vector<int64_t> fusion_ids;
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
bool has_circle = false;
for (auto &inp : fusion_info.inputs_list) {
MS_EXCEPTION_IF_NULL(inp);
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
continue;
}
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
has_circle = true;
break;
}
}
bool has_circle = CheckCircle(kernel_graph, fusion_info);
if (has_circle) {
fusion_ids.emplace_back(fusion_id);
}
@ -440,7 +444,11 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
continue;
}
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) {
MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion.";
} else {
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
}
}
MS_LOG(DEBUG) << "End Buffer Fusion";
return change;