!16109 fix ub_fusion graph circle problem
From: @yuchaojie Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @zhoufeng54
This commit is contained in:
commit
bec093c456
|
@ -362,24 +362,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) {
|
||||||
|
bool has_circle = false;
|
||||||
|
for (auto &inp : fusion_info.inputs_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(inp);
|
||||||
|
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
|
||||||
|
has_circle = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return has_circle;
|
||||||
|
}
|
||||||
|
|
||||||
void RemoveCircle(const session::KernelGraph &kernel_graph,
|
void RemoveCircle(const session::KernelGraph &kernel_graph,
|
||||||
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
||||||
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
||||||
std::vector<int64_t> fusion_ids;
|
std::vector<int64_t> fusion_ids;
|
||||||
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
|
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) {
|
||||||
bool has_circle = false;
|
bool has_circle = CheckCircle(kernel_graph, fusion_info);
|
||||||
for (auto &inp : fusion_info.inputs_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(inp);
|
|
||||||
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) {
|
|
||||||
has_circle = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_circle) {
|
if (has_circle) {
|
||||||
fusion_ids.emplace_back(fusion_id);
|
fusion_ids.emplace_back(fusion_id);
|
||||||
}
|
}
|
||||||
|
@ -440,7 +444,11 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
|
||||||
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
|
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
|
if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) {
|
||||||
|
MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion.";
|
||||||
|
} else {
|
||||||
|
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "End Buffer Fusion";
|
MS_LOG(DEBUG) << "End Buffer Fusion";
|
||||||
return change;
|
return change;
|
||||||
|
|
Loading…
Reference in New Issue