From 06908165100304f642ce5b9c92aba5fe91837a91 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Sat, 8 May 2021 15:00:20 +0800 Subject: [PATCH] fix ub_fusion graph circle problem --- .../ascend/buffer_fusion/ub_pattern_fusion.cc | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index f0b1e19d630..fad2df47b44 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -362,24 +362,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vectorisa() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) { + continue; + } + + if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { + has_circle = true; + break; + } + } + return has_circle; +} + void RemoveCircle(const session::KernelGraph &kernel_graph, std::unordered_map *buffer_fusion_infos) { MS_EXCEPTION_IF_NULL(buffer_fusion_infos); std::vector fusion_ids; for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) { - bool has_circle = false; - for (auto &inp : fusion_info.inputs_list) { - MS_EXCEPTION_IF_NULL(inp); - if (!inp->isa() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) { - continue; - } - - if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { - has_circle = true; - break; - } - } - + bool has_circle = CheckCircle(kernel_graph, fusion_info); if (has_circle) { fusion_ids.emplace_back(fusion_id); } @@ -440,7 +444,11 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; continue; } - change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); + if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) { + MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion."; + } else { + change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); + } } MS_LOG(DEBUG) << "End Buffer Fusion"; return change;