!16109 fix ub_fusion graph circle problem
From: @yuchaojie Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @zhoufeng54
This commit is contained in:
commit
bec093c456
|
@ -362,24 +362,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
|
|||
}
|
||||
}
|
||||
|
||||
bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) {
|
||||
bool has_circle = false;
|
||||
for (auto &inp : fusion_info.inputs_list) {
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
|
||||
has_circle = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return has_circle;
|
||||
}
|
||||
|
||||
void RemoveCircle(const session::KernelGraph &kernel_graph,
|
||||
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
||||
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
||||
std::vector<int64_t> fusion_ids;
|
||||
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
|
||||
bool has_circle = false;
|
||||
for (auto &inp : fusion_info.inputs_list) {
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
|
||||
has_circle = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool has_circle = CheckCircle(kernel_graph, fusion_info);
|
||||
if (has_circle) {
|
||||
fusion_ids.emplace_back(fusion_id);
|
||||
}
|
||||
|
@ -440,7 +444,11 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
|
|||
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
|
||||
continue;
|
||||
}
|
||||
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
|
||||
if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) {
|
||||
MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion.";
|
||||
} else {
|
||||
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "End Buffer Fusion";
|
||||
return change;
|
||||
|
|
Loading…
Reference in New Issue