From 735a6aaa3cfa51ee0d5e8e2025e38da135711e72 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Tue, 27 Apr 2021 11:34:57 +0800 Subject: [PATCH] add conv+add+relu & add+reluv2 fusion --- .../backend/kernel_compiler/tbe/tbe_kernel_build.cc | 9 ++++++--- .../ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h | 2 +- .../ascend/buffer_fusion/conv_double_in_fusion_pass.cc | 3 ++- .../ascend/buffer_fusion/eltwise_fusion_pass.cc | 3 +++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index de8336ca219..4698a04447b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -950,7 +950,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_ } (*output_desc)[kJFormat] = format; // special node - if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) { + if ((fusion_data_type == kFusionAddN || fusion_data_type == kFusionAdd) && shape.size() == 5) { std::vector spec_shape = {}; spec_shape.emplace_back(shape[0]); spec_shape.emplace_back(shape[1]); @@ -995,7 +995,8 @@ void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr &reorder_layer, std::map *spec_data_input) { - if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) { + if ((op_name == kReluGradV2OpName || op_name == kAddNOpName || op_name == kTensorAddOpName) && + reorder_layer.empty()) { MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. "; return false; } @@ -1005,6 +1006,8 @@ bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, for (const auto &it : reorder_layer) { (*spec_data_input)[it] = kFusionAddN; } + } else if (op_name == kTensorAddOpName) { + (*spec_data_input)[reorder_layer[0]] = kFusionAdd; } return true; } @@ -1020,7 +1023,7 @@ bool TbeKernelBuild::GetInputLayers(const std::vector &in MS_EXCEPTION_IF_NULL(spec_data_input); auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { auto op_name = AnfAlgo::GetCNodeName(it); - return op_name == kConv2DBackpropInputOpName; + return (op_name == kConv2DBackpropInputOpName || op_name == kConv2DOpName); }); bool need_spec = (result != compute_nodes.end()); size_t input_size = 0; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h index 99b7504254a..26d3d220502 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -34,7 +34,7 @@ namespace kernel { // kernel operate type used for generate json class TbeKernelBuild { - enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 }; + enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2, kFusionAdd }; public: static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc index 3e6ca6663a6..679f674c24b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc @@ -66,7 +66,8 @@ void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE && + !AnfAlgo::CheckPrimitiveType(node, prim::kPrimReluV2)) { MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion); } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc index f486e8dc9ce..46755b29708 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc @@ -42,6 +42,9 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne MS_EXCEPTION_IF_NULL(input_cnode); eltwise_input = input_cnode->input(1); } + if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) { + (void)record.insert(eltwise_input); + } if (record.size() < MIN_ELTWISE_SIZE) { return; }