From 891d7f3b3065b968f8aca68f9e8d08e511f58eca Mon Sep 17 00:00:00 2001
From: wangyanling
Date: Mon, 10 Jan 2022 17:33:17 +0800
Subject: [PATCH] add fullconnected fusion with add

---
 .../lite/tools/converter/anf_transform.cc     |   2 +
 .../lite/tools/optimizer/common/gllo_utils.cc |  21 ++
 .../lite/tools/optimizer/common/gllo_utils.h  |   2 +
 .../fusion/fullconnected_add_fusion.cc        | 203 ++++++++++++++++++
 .../fusion/fullconnected_add_fusion.h         |  42 ++++
 .../optimizer/fusion/fullconnected_fusion.cc  |   1 -
 .../optimizer/fusion/matmul_add_fusion.cc     |  21 +-
 7 files changed, 271 insertions(+), 21 deletions(-)
 create mode 100644 mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.cc
 create mode 100644 mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.h

diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 8c479139405..fb4e38209a2 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -58,6 +58,7 @@
 #include "tools/optimizer/fusion/scale_activation_fusion.h"
 #include "tools/optimizer/fusion/scale_scale_fusion.h"
 #include "tools/optimizer/fusion/fullconnected_fusion.h"
+#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
 #include "tools/optimizer/fusion/add_concat_activation_fusion.h"
 #include "tools/optimizer/fusion/matmul_activation_fusion.h"
 #include "tools/optimizer/fusion/activation_fusion.h"
@@ -219,6 +220,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
   fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
   fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
   fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
+  fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
   fusion_pm->AddPass(std::make_shared<opt::AddConcatActivationFusion>());
   fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>());
   optimizer->AddPassManager(fusion_pm);
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 273585031fd..442729ae4c3 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -1126,5 +1126,26 @@ int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, boo
   *infer_succ = infer_infos[item_index];
   return RET_OK;
 }
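+// Returns true iff |cnode| has exactly two operand inputs and one of them is a
+// CNode of |primitive_type|; that input's position is returned through |index|.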
+bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type) {
+  MS_CHECK_TRUE_RET(cnode != nullptr, false);
+  MS_CHECK_TRUE_RET(index != nullptr, false);
+  if (cnode->size() != kInputSizeThree) {
+    return false;
+  }
+  size_t dst_index = 0;
+  for (size_t i = 1; i < cnode->size(); ++i) {
+    if (CheckPrimitiveType(cnode->input(i), primitive_type)) {
+      dst_index = i;
+      break;
+    }
+  }
+  if (dst_index == 0) {
+    return false;
+  }
+  *index = dst_index;
+  return true;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index a520e82a74f..fc39fc12b10 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -141,6 +141,8 @@ std::pair GetRealCertainVarInput(const CNodePtr &cnode, size_t in
 
 int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
 
+bool CheckAndGetCnodeIndex(const CNodePtr &cnode, size_t *index, const PrimitivePtr &primitive_type);
+
 template <const PrimitivePtr *prim = nullptr>
 inline bool IsSpecifiedNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
diff --git a/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.cc
new file mode 100644
index 00000000000..646f0d7deee
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.cc
@@ -0,0 +1,203 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
+#include <memory>
+#include <vector>
+#include "ops/fusion/add_fusion.h"
+#include "ops/fusion/full_connection.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "nnacl/op_base.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &fc_cnode, int index) {
+  auto add_primc = GetValueNode<PrimitivePtr>(add_cnode->input(0));
+  MS_CHECK_TRUE_RET(add_primc != nullptr, false);
+  if (IsQuantParameterNode(add_primc)) {
+    MS_LOG(INFO) << add_cnode->fullname_with_scope() << " is quant node";
+    return false;
+  }
+
+  auto add_param_node = add_cnode->input(kInputSizeThree - index);
+  if (!utils::isa<ValueNode>(add_param_node) &&
+      (!utils::isa<Parameter>(add_param_node) || !add_param_node->cast<ParameterPtr>()->default_param())) {
+    return false;
+  }
+  auto abstract = add_param_node->abstract();
+  MS_CHECK_TRUE_RET(abstract != nullptr, false);
+  std::vector<int64_t> bias_shape;
+  if (FetchShapeFromAbstract(abstract, &bias_shape) != lite::RET_OK) {
+    MS_LOG(ERROR) << "Fetch shape from abstract failed.";
+    return false;
+  }
+  if (bias_shape.size() > DIMENSION_1D) {
+    MS_LOG(INFO) << "only support bias with shape size of 1.";
+    return false;
+  }
+
+  if (fc_cnode->size() > kInputSizeThree) {
+    auto fc_bias_node = fc_cnode->input(kInputIndexThree);
+    if (!utils::isa<ValueNode>(fc_bias_node) &&
+        (!utils::isa<Parameter>(fc_bias_node) || !fc_bias_node->cast<ParameterPtr>()->default_param())) {
+      MS_LOG(INFO) << fc_cnode->fullname_with_scope() << "'s bias is not parameter";
+      return false;
+    }
+  }
+  auto fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
+  MS_CHECK_TRUE_RET(fc_primc != nullptr, false);
+  if (fc_primc->GetAttr(ops::kActivationType) != nullptr &&
+      fc_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
+    MS_LOG(INFO) << fc_cnode->fullname_with_scope() << " has activation attr";
+    return false;
+  }
+  if (IsQuantParameterNode(fc_primc)) {
+    MS_LOG(INFO) << fc_cnode->fullname_with_scope() << " is quant node";
+    return false;
+  }
+
+  return true;
+}
+
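+// Folds the constant operand of the Add into the existing FullConnection bias
+// in place. Both tensors must be float32 with identical 1-D shapes; any other
+// case returns an error so the caller skips the fusion.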
+int CalNewCnodeBias(const AnfNodePtr &add_weight_node, const CNodePtr &fc_cnode) {
+  MS_CHECK_TRUE_RET(add_weight_node != nullptr, RET_ERROR);
+  MS_CHECK_TRUE_RET(fc_cnode != nullptr, RET_ERROR);
+  auto fc_bias_node = fc_cnode->input(kInputIndexThree);
+  MS_CHECK_TRUE_RET(fc_bias_node != nullptr, RET_ERROR);
+  std::shared_ptr<tensor::Tensor> fc_bias_tensor = GetTensorInfo(fc_bias_node);
+  MS_CHECK_TRUE_RET(fc_bias_tensor != nullptr, RET_ERROR);
+  if (fc_bias_tensor->data_type() != kNumberTypeFloat32) {
+    MS_LOG(INFO) << "only support float32 data type";
+    return RET_ERROR;
+  }
+  std::vector<int64_t> fc_bias_shape = fc_bias_tensor->shape();
+  auto fc_bias_data = reinterpret_cast<float *>(fc_bias_tensor->data_c());
+  MS_CHECK_TRUE_RET(fc_bias_data != nullptr, RET_ERROR);
+
+  std::shared_ptr<tensor::Tensor> add_weight_tensor = GetTensorInfo(add_weight_node);
+  MS_CHECK_TRUE_RET(add_weight_tensor != nullptr, RET_ERROR);
+  if (add_weight_tensor->data_type() != kNumberTypeFloat32) {
+    MS_LOG(INFO) << "only support float32 data type";
+    return RET_ERROR;
+  }
+  std::vector<int64_t> add_weight_shape = add_weight_tensor->shape();
+  MS_CHECK_TRUE_RET(fc_bias_shape == add_weight_shape, RET_ERROR);
+  auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data_c());
+  MS_CHECK_TRUE_RET(add_weight_data != nullptr, RET_ERROR);
+
+  for (int64_t i = 0; i < fc_bias_shape[0]; ++i) {
+    fc_bias_data[i] += add_weight_data[i];
+  }
+  return RET_OK;
+}
+}  // namespace
+
+VectorRef FullconnectedAddFusion::DefineFcAddFusionPattern() const {
+  auto is_fc1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
+  MS_CHECK_TRUE_RET(is_fc1 != nullptr, {});
+  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
+  MS_CHECK_TRUE_RET(is_add != nullptr, {});
+  auto is_seq_var = std::make_shared<SeqVar>();
+  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
+  return VectorRef({is_add, is_fc1, is_seq_var});
+}
+
+VectorRef FullconnectedAddFusion::DefineFcBiasAddPattern() const {
+  auto is_fc1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
+  MS_CHECK_TRUE_RET(is_fc1 != nullptr, {});
+  auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
+  MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
+  auto is_seq_var = std::make_shared<SeqVar>();
+  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
+  return VectorRef({is_bias_add, is_fc1, is_seq_var});
+}
+
+std::unordered_map<std::string, VectorRef> FullconnectedAddFusion::DefinePatterns() const {
+  std::unordered_map<std::string, VectorRef> patterns;
+  patterns["FcAddFusionPatternName"] = DefineFcAddFusionPattern();
+  patterns["FcBiasAddPatternName"] = DefineFcBiasAddPattern();
+  return patterns;
+}
+
+AnfNodePtr FullconnectedAddFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
+                                           const AnfNodePtr &node, const EquivPtr &equiv) const {
+  if (func_graph == nullptr || node == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return nullptr;
+  }
+
+  auto add_cnode = node->cast<CNodePtr>();
+  MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr);
+  if (IsMarkedTrainOp(add_cnode)) {
+    return nullptr;
+  }
+  if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
+    return nullptr;
+  }
+
+  size_t index = 0;
+  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimFullConnection)) {
+    return nullptr;
+  }
+  auto fc_cnode = add_cnode->input(index)->cast<CNodePtr>();
+  MS_ASSERT(fc_cnode != nullptr);
+  if (IsMarkedTrainOp(fc_cnode)) {
+    return nullptr;
+  }
+
+  if (IsMultiOutputTensors(func_graph, fc_cnode)) {
+    return nullptr;
+  }
+
+  if (!IsPrimitiveProper(add_cnode, fc_cnode, index)) {
+    return nullptr;
+  }
+
+  auto manager = func_graph->manager();
+  auto add_param_node = add_cnode->input(kInputSizeThree - index);
+  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
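+  // Two cases: a FullConnection without a bias takes the Add's constant
+  // directly as its new bias input; a FullConnection that already has a
+  // constant bias gets the Add's constant folded into it element-wise.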
+  if (fc_cnode->size() == kInputSizeThree) {
+    manager->AddEdge(fc_cnode, add_param_node);
+  } else if (fc_cnode->size() == kInputSizeFour) {
+    if (CalNewCnodeBias(add_param_node, fc_cnode) != RET_OK) {
+      MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fuse with " << fc_cnode->fullname_with_scope();
+      return nullptr;
+    }
+  }
+
+  if (CheckPrimitiveType(node, prim::kPrimAddFusion)) {
+    auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
+    MS_CHECK_TRUE_RET(add_primc != nullptr, nullptr);
+    if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
+        add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
+      auto fc_primc = GetValueNode<std::shared_ptr<ops::FullConnection>>(fc_cnode->input(0));
+      MS_CHECK_TRUE_RET(fc_primc != nullptr, nullptr);
+      fc_primc->set_activation_type(add_primc->get_activation_type());
+    }
+  }
+  fc_cnode->set_fullname_with_scope(node->fullname_with_scope());
+  (void)manager->Replace(node, fc_cnode);
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.h b/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.h
new file mode 100644
index 00000000000..3bc2b88da98
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/fullconnected_add_fusion.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
+
+#include <string>
+#include <unordered_map>
+#include "backend/optimizer/common/optimizer.h"
+#include "tools/optimizer/common/multiple_pattern_process_pass.h"
+
+namespace mindspore::opt {
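+// Fuses FullConnection + Add/BiasAdd into a single FullConnection node by
+// attaching the added constant as the bias, or folding it into an existing one.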
+class FullconnectedAddFusion : public MultiplePatternProcessPass {
+ public:
+  explicit FullconnectedAddFusion(const std::string &name = "FullconnectedAddFusion", bool multigraph = true)
+      : MultiplePatternProcessPass(name, multigraph) {}
+  ~FullconnectedAddFusion() override = default;
+
+ private:
+  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
+  VectorRef DefineFcAddFusionPattern() const;
+  VectorRef DefineFcBiasAddPattern() const;
+  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &,
+                     const EquivPtr &) const override;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FULLCONNECTED_ADD_FUSION_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/fullconnected_fusion.cc b/mindspore/lite/tools/optimizer/fusion/fullconnected_fusion.cc
index 4bba976c44d..241b2275f51 100644
--- a/mindspore/lite/tools/optimizer/fusion/fullconnected_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/fullconnected_fusion.cc
@@ -19,7 +19,6 @@
 #include <memory>
 #include "tools/common/tensor_util.h"
 #include "ops/fusion/full_connection.h"
-#include "ops/fusion/conv2d_fusion.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/converter/quant_param_holder.h"
 #include "nnacl/op_base.h"
diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
index d97e9da98b0..ca1e22408ee 100644
--- a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
@@ -25,25 +25,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
-  MS_ASSERT(cnode != nullptr && index != nullptr);
-  if (cnode->size() != kInputSizeThree) {
-    return false;
-  }
-  size_t matmul_index = 0;
-  for (size_t i = 1; i < cnode->size(); ++i) {
-    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMulFusion)) {
-      matmul_index = i;
-      break;
-    }
-  }
-  if (matmul_index == 0) {
-    return false;
-  }
-  *index = matmul_index;
-  return true;
-}
-
 bool IsPrimitiveProper(const CNodePtr &add_cnode, const CNodePtr &matmul_cnode, int index) {
   auto add_primc = GetValueNode<PrimitivePtr>(add_cnode->input(0));
   MS_CHECK_TRUE_RET(add_primc != nullptr, false);
@@ -169,7 +150,7 @@ AnfNodePtr MatMulAddFusion::Process(const std::string &pattern_name, const FuncG
   }
 
   size_t index = 0;
-  if (!CheckAndGetMatMulIndex(add_cnode, &index)) {
+  if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) {
     return nullptr;
   }
   auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();