From 4e49de536eef442652743fe1e982231d0c42472a Mon Sep 17 00:00:00 2001
From: xuanyue
Date: Wed, 29 Sep 2021 17:20:12 +0800
Subject: [PATCH] strengthen constant fold

---
 .../lite/tools/converter/anf_transform.cc     |  6 +-
 .../fusion/constant_folding_fusion.cc         | 83 +++++++++++++++----
 .../fusion/constant_folding_fusion.h          | 14 ++--
 3 files changed, 81 insertions(+), 22 deletions(-)

diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index f2e416bfe7d..b4013b16c0f 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -190,7 +190,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
   fusion_pm->AddPass(std::make_shared());
   fusion_pm->AddPass(std::make_shared());
   fusion_pm->AddPass(std::make_shared());
-  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
+  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
   fusion_pm->AddPass(std::make_shared());
   fusion_pm->AddPass(std::make_shared());
   if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
@@ -329,7 +329,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
   CHECK_NULL_RETURN(const_fold_pm);
   const_fold_pm->AddPass(std::make_shared(config->fmk, config->trainModel));
   if (!config->trainModel) {
-    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
+    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
   }
   const_fold_pm->AddPass(std::make_shared());
   const_fold_pm->AddPass(std::make_shared());
@@ -511,9 +511,9 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
   auto is_train = config->trainModel;
   std::unordered_map<std::string, opt::PassPtr> passes = {
     {"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
-    {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(config->fmk)},
     {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
     {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
+    {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train)},
     {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
     {"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
     {"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>()},
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
index c4954689f86..c06775efbea 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -39,6 +39,7 @@ using mindspore::lite::Tensor;
 namespace mindspore::opt {
 namespace {
 constexpr size_t INITIAL_SIZE = 1024;
+constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
 void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
   if (input_tensor != nullptr) {
     for (auto &i : *input_tensor) {
@@ -283,7 +284,11 @@ bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) {
     return false;
   }
   std::set<FuncGraphPtr> has_visited;
-  if (Process(func_graph, &has_visited) != lite::RET_OK) {
+  if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
+    MS_LOG(ERROR) << "do constant fold pass failed,";
+    return false;
+  }
+  if (HandleSpecialFold(func_graph) != lite::RET_OK) {
     MS_LOG(ERROR) << "do constant fold pass failed,";
     return false;
   }
@@ -306,7 +311,7 @@ bool ConstFoldPass::Init(const FuncGraphPtr &func_graph) {
   return true;
 }
 
-int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
+int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
   MS_ASSERT(func_graph != nullptr);
   if (has_visited->find(func_graph) != has_visited->end()) {
     return lite::RET_OK;
@@ -322,16 +327,15 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPt
     auto cnode = node->cast<CNodePtr>();
     for (size_t i = 0; i < cnode->size(); ++i) {
       if (IsValueNode<FuncGraph>(cnode->input(i))) {
-        is_control_flow_ = true;
         auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
         MS_ASSERT(sub_graph != nullptr);
-        if (Process(sub_graph, has_visited) != lite::RET_OK) {
+        if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
           MS_LOG(ERROR) << "do subgraph const-fold failed.";
           return lite::RET_ERROR;
         }
       }
     }
-    if (!CheckCanFusion(cnode)) {
+    if (!CheckCanCommonFold(cnode)) {
       continue;
     }
    if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
@@ -342,7 +346,7 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPt
   return lite::RET_OK;
 }
 
-bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
+bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const {
   if (IsSpecialType(cnode)) {
     return false;
   }
@@ -350,21 +354,72 @@ bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
     return false;
   }
   auto inputs = cnode->inputs();
-  bool is_all_const = std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
+  return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
     return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
            (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
   });
-  if (is_all_const) {
-    return true;
+}
+
+int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
+    return lite::RET_OK;
   }
-  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
-    if (is_control_flow_ || lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
+  if (node_infershape_ == nullptr) {
+    node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
+    MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
+  }
+  auto manager = Manage(func_graph);
+  MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNode>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (!CheckCanSpecialFold(cnode)) {
+      continue;
+    }
+    if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "do constant fold failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
+  MS_CHECK_TRUE_RET(cnode != nullptr, false);
+  for (size_t i = 0; i < cnode->size(); ++i) {
+    auto input_node = cnode->input(i);
+    MS_CHECK_TRUE_RET(input_node != nullptr, false);
+    if (IsValueNode<FuncGraph>(input_node)) {
+      return false;
+    }
+    if (!input_node->isa<CNode>()) {
+      continue;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
+    MS_CHECK_TRUE_RET(input_prim != nullptr, false);
+    bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
+                                     GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
+    if (is_link_with_control_flow) {
       return false;
     }
-    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    return prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
   }
-  return false;
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_CHECK_TRUE_RET(prim != nullptr, false);
+  prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
+  if (IsSpecialType(cnode)) {
+    return false;
+  }
+  MS_ASSERT(node_infershape_ != nullptr);
+  auto status = node_infershape_->InferShape(cnode);
+  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
+    return status == lite::RET_OK;
+  }
+  return CheckCanCommonFold(cnode);
 }
 
 int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
index ede5a0870be..cb9daf2c7ab 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h
@@ -24,25 +24,29 @@
 #include "include/api/context.h"
 #include "include/registry/parser_context.h"
 #include "src/inner_context.h"
+#include "tools/optimizer/graph/node_infershape.h"
 
 namespace mindspore {
 namespace opt {
 class ConstFoldPass : public Pass {
  public:
-  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs)
-      : Pass("ConstFoldPass"), fmk_type_(fmk_type) {}
+  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
   ~ConstFoldPass() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
  private:
   bool Init(const FuncGraphPtr &func_graph);
-  int Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
-  bool CheckCanFusion(const CNodePtr &cnode) const;
+  int HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
+  bool CheckCanCommonFold(const CNodePtr &cnode) const;
+  int HandleSpecialFold(const FuncGraphPtr &func_graph);
+  bool CheckCanSpecialFold(const CNodePtr &cnode) const;
   int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
-  bool is_control_flow_{false};
   converter::FmkType fmk_type_{converter::kFmkTypeMs};
+  bool train_flag_{false};
   std::shared_ptr<lite::InnerContext> context_{nullptr};
   std::shared_ptr<mindspore::Context> ms_context_{nullptr};
+  std::shared_ptr<NodeInferShape> node_infershape_{nullptr};
 };
 }  // namespace opt
 }  // namespace mindspore
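
Note for reviewers: the pass now folds in two phases. HandleCommonFold keeps the old behaviour: it walks each graph (recursing into subgraph ValueNodes for control flow) and folds any CNode whose inputs are all constant, i.e. a ValueNode or a Parameter with a default value. HandleSpecialFold only runs when the converter was given fixed graph input shapes (GetGraphInputTensorShapeMapSize() != 0); it re-runs per-node shape inference, clears the link_with_control_flow attribute on operators whose producers are all proven safe, and can then fold a Shape operator whose output is statically known even though its tensor input is not constant. The old pass-wide is_control_flow_ flag disabled Shape folding for the whole model as soon as any subgraph was seen; the per-node attribute only blocks nodes actually fed by control flow. The sketch below is a standalone toy model of that decision logic, not the MindSpore API; all type and function names are illustrative:

// Toy model of the two-phase fold decision (hypothetical types; not the
// MindSpore API). Build: g++ -std=c++14 fold_sketch.cc && ./a.out
#include <iostream>
#include <memory>
#include <vector>

enum class Kind { kConst, kParamWithDefault, kGraphInput, kShapeOp, kGenericOp, kSubgraph };

struct Node {
  Kind kind = Kind::kGenericOp;
  std::vector<std::shared_ptr<Node>> inputs;
  bool shape_infer_done = false;       // static shape inference succeeded
  bool link_with_control_flow = true;  // pessimistic until proven safe
};

// Phase 1 (cf. CheckCanCommonFold): fold only when every input is constant.
bool CanCommonFold(const Node &n) {
  if (n.inputs.empty()) return false;
  for (const auto &in : n.inputs) {
    if (in->kind != Kind::kConst && in->kind != Kind::kParamWithDefault) return false;
  }
  return true;
}

// Phase 2 (cf. CheckCanSpecialFold), visiting nodes in topological order:
// an op is safe once none of its producers is tied to control flow; a safe
// Shape op folds as soon as its output shape was inferred statically.
bool CanSpecialFold(Node *n) {
  for (const auto &in : n->inputs) {
    if (in->kind == Kind::kSubgraph) return false;  // direct control-flow input
    bool is_op = (in->kind == Kind::kShapeOp || in->kind == Kind::kGenericOp);
    if (is_op && in->link_with_control_flow) return false;  // producer not cleared
  }
  n->link_with_control_flow = false;  // all producers safe: clear this op too
  if (n->kind == Kind::kShapeOp) return n->shape_infer_done;
  return CanCommonFold(*n);
}

int main() {
  auto graph_input = std::make_shared<Node>();
  graph_input->kind = Kind::kGraphInput;

  Node shape;
  shape.kind = Kind::kShapeOp;
  shape.inputs = {graph_input};
  shape.shape_infer_done = true;  // converter was given a fixed input shape

  std::cout << CanCommonFold(shape) << "\n";    // 0: input is not a constant
  std::cout << CanSpecialFold(&shape) << "\n";  // 1: shape is statically known
  return 0;
}

In the real pass the "cleared" state is stored as the link_with_control_flow attribute on each node's primitive, so it propagates node by node along the topological order instead of living in a single pass-wide boolean.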