From de843b45b6975850d9b13259c885095d2b5e6307 Mon Sep 17 00:00:00 2001 From: huanghui Date: Fri, 26 Feb 2021 09:55:28 +0800 Subject: [PATCH] add circle check in ub fusion --- .../bnupdate_eltwise_eltwise_fusion_pass.cc | 4 --- .../ascend/buffer_fusion/ub_pattern_fusion.cc | 26 +++++++++++++++++++ .../ccsrc/backend/optimizer/common/helper.cc | 21 --------------- .../backend/optimizer/common/pass_manager.cc | 7 ++--- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index da885a4591d..03bb2cfd944 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -63,10 +63,6 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod auto bnupdate = getitem->input(1); MS_EXCEPTION_IF_NULL(bnupdate); if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE && - IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) { - return; - } std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); for (auto out_getitem : manager->node_users()[bnupdate]) { MS_EXCEPTION_IF_NULL(out_getitem.first); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index 6457098627f..122d5a31094 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -27,6 +27,7 @@ #include "base/core_ops.h" #include "runtime/device/kernel_info.h" #include "utils/ms_context.h" +#include "backend/optimizer/common/helper.h" namespace mindspore { namespace opt { @@ -353,6 +354,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) { + bool has_circle = false; + for (const auto &inp : fusion_info.inputs_list) { + if (!inp->isa() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) { + continue; + } + + if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { + has_circle = true; + break; + } + } + + if (has_circle) { + buffer_fusion_infos->erase(fusion_id); + } + } +} } // namespace void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, @@ -361,6 +384,9 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); + // Remove the fusion infos which will produce a circle if do fusion + RemoveCircle(*kernel_graph, buffer_fusion_infos); + for (auto &buffer_fusion_info : *buffer_fusion_infos) { buffer_fusion_info.second.kernel_build_info = CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 5c95bd2b76b..5414240d481 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -49,23 +49,6 @@ std::vector Convert2Long(const std::vector &v) { bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector &nodes) { MS_EXCEPTION_IF_NULL(node); - std::vector node_list = TopoSort(graph.get_return()); - std::map> control_depend_map; - for (auto &nd : node_list) { - MS_EXCEPTION_IF_NULL(nd); - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { - auto control_depend = nd->cast(); - auto prior_node = control_depend->input(kControlDependPriorIndex); - auto behind_node = control_depend->input(kControlDependBehindIndex); - auto it = control_depend_map.find(behind_node); - if (it == control_depend_map.end()) { - control_depend_map[behind_node] = std::set{prior_node}; - } else { - it->second.insert(prior_node); - } - } - } - FuncGraphManagerPtr manager = graph.manager(); MS_EXCEPTION_IF_NULL(manager); @@ -88,10 +71,6 @@ bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector< auto inputs = cnode->inputs(); (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); } - auto it = control_depend_map.find(nd); - if (it != control_depend_map.end()) { - (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); - } } return false; } diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc index 3548555c190..d654f06d21e 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc @@ -61,12 +61,13 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; #else (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + // time unit: us + uint64_t cost = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); cost += static_cast(end_time.tv_usec - start_time.tv_usec); MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; #endif - if (save_graphs) { + static const auto enable_dump = (common::GetEnv("ENV_NO_DUMP_BE_PASS_IR") != "1"); + if (save_graphs && enable_dump) { std::ostringstream oss; oss << "verbose_ir_files" << "/";