!24425 [lite]strengthen constant fold

Merge pull request !24425 from 徐安越/master
Authored by i-robot on 2021-10-11 08:59:34 +00:00; committed by Gitee
commit 157751c8ce
3 changed files with 81 additions and 22 deletions


@@ -190,7 +190,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
   fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
   fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
   fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
+  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
   fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
   fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
   if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
@@ -329,7 +329,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
   CHECK_NULL_RETURN(const_fold_pm);
   const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
   if (!config->trainModel) {
-    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
+    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
   }
   const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
   const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
@@ -511,9 +511,9 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
   auto is_train = config->trainModel;
   std::unordered_map<std::string, opt::PassPtr> passes = {
     {"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
-    {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(config->fmk)},
     {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
     {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
+    {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train)},
     {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
     {"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
     {"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>()},


@@ -39,6 +39,7 @@ using mindspore::lite::Tensor;
 namespace mindspore::opt {
 namespace {
 constexpr size_t INITIAL_SIZE = 1024;
+constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
 void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
   if (input_tensor != nullptr) {
     for (auto &i : *input_tensor) {
@@ -283,7 +284,11 @@ bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) {
     return false;
   }
   std::set<FuncGraphPtr> has_visited;
-  if (Process(func_graph, &has_visited) != lite::RET_OK) {
+  if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
     MS_LOG(ERROR) << "do constant fold pass failed,";
     return false;
   }
+  if (HandleSpecialFold(func_graph) != lite::RET_OK) {
+    MS_LOG(ERROR) << "do constant fold pass failed,";
+    return false;
+  }
@@ -306,7 +311,7 @@ bool ConstFoldPass::Init(const FuncGraphPtr &func_graph) {
   return true;
 }
 
-int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
+int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
   MS_ASSERT(func_graph != nullptr);
   if (has_visited->find(func_graph) != has_visited->end()) {
     return lite::RET_OK;
@@ -322,16 +327,15 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr
     auto cnode = node->cast<CNodePtr>();
     for (size_t i = 0; i < cnode->size(); ++i) {
       if (IsValueNode<FuncGraph>(cnode->input(i))) {
-        is_control_flow_ = true;
         auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
         MS_ASSERT(sub_graph != nullptr);
-        if (Process(sub_graph, has_visited) != lite::RET_OK) {
+        if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
           MS_LOG(ERROR) << "do subgraph const-fold failed.";
           return lite::RET_ERROR;
         }
       }
     }
-    if (!CheckCanFusion(cnode)) {
+    if (!CheckCanCommonFold(cnode)) {
       continue;
     }
     if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
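
The renamed HandleCommonFold keeps the recursive shape of the old Process: before folding a node it descends into every FuncGraph referenced by that node's inputs, and the has_visited set ensures a subgraph shared by several call sites is processed exactly once. A standalone sketch of that traversal, assuming a hypothetical FuncGraphLike in place of the real FuncGraph:

#include <iostream>
#include <set>
#include <vector>

struct FuncGraphLike {
  int id;
  std::vector<FuncGraphLike *> subgraphs;  // graphs referenced by this graph's nodes
};

// Depth-first fold with memoization, mirroring HandleCommonFold's visited-set recursion.
int FoldOnce(FuncGraphLike *g, std::set<FuncGraphLike *> *visited) {
  if (!visited->insert(g).second) {
    return 0;  // already folded via another call site
  }
  for (auto *sub : g->subgraphs) {
    if (FoldOnce(sub, visited) != 0) {
      return 1;
    }
  }
  std::cout << "fold graph " << g->id << '\n';  // fold this graph's own constant nodes
  return 0;
}

int main() {
  FuncGraphLike shared{2, {}};
  FuncGraphLike a{1, {&shared}};
  FuncGraphLike root{0, {&a, &shared}};  // 'shared' is referenced twice but folded once
  std::set<FuncGraphLike *> visited;
  return FoldOnce(&root, &visited);
}
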
@@ -342,7 +346,7 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr
   return lite::RET_OK;
 }
 
-bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
+bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const {
   MS_CHECK_TRUE_RET(cnode != nullptr, false);
   if (IsSpecialType(cnode)) {
     return false;
@@ -351,21 +355,72 @@ bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
     return false;
   }
   auto inputs = cnode->inputs();
-  bool is_all_const = std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
+  return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
     return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
            (node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
   });
-  if (is_all_const) {
-    return true;
-  }
-  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
-    if (is_control_flow_ || lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
-      return false;
-    }
-    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    return prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
-  }
-  return false;
 }
 
+int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
+    return lite::RET_OK;
+  }
+  if (node_infershape_ == nullptr) {
+    node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
+    MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
+  }
+  auto manager = Manage(func_graph);
+  MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNode>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (!CheckCanSpecialFold(cnode)) {
+      continue;
+    }
+    if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
+      MS_LOG(ERROR) << "do constant fold failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
+bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
+  MS_CHECK_TRUE_RET(cnode != nullptr, false);
+  for (size_t i = 0; i < cnode->size(); ++i) {
+    auto input_node = cnode->input(i);
+    MS_CHECK_TRUE_RET(input_node != nullptr, false);
+    if (IsValueNode<FuncGraph>(input_node)) {
+      return false;
+    }
+    if (!input_node->isa<CNode>()) {
+      continue;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
+    MS_CHECK_TRUE_RET(input_prim != nullptr, false);
+    bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
+                                     GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
+    if (is_link_with_control_flow) {
+      return false;
+    }
+  }
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_CHECK_TRUE_RET(prim != nullptr, false);
+  prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
+  if (IsSpecialType(cnode)) {
+    return false;
+  }
+  MS_ASSERT(node_infershape_ != nullptr);
+  auto status = node_infershape_->InferShape(cnode);
+  if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
+    return status == lite::RET_OK;
+  }
+  return CheckCanCommonFold(cnode);
+}
+
 int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
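
CheckCanSpecialFold above encodes a forward propagation over the topological walk in HandleSpecialFold: a cnode qualifies only if none of its inputs is a subgraph value and every cnode input has already been stamped kIsLinkWithControlFlow == false; a qualifying node then stamps itself false, so the "provably free of control flow" property flows from producers to consumers. A standalone sketch of the same idea, with Node and FreeOfControlFlow as hypothetical stand-ins; as in the nullptr check above, an unproven producer counts as linked:

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  bool is_control_flow = false;  // models an input that is a subgraph value node
  std::vector<Node *> inputs;    // producers, visited earlier in topological order
};

// Returns true and records the proof only when the node and all producers are clean.
bool FreeOfControlFlow(const Node &n, std::unordered_map<const Node *, bool> *mark) {
  if (n.is_control_flow) {
    return false;
  }
  for (const Node *in : n.inputs) {
    auto it = mark->find(in);
    if (it == mark->end() || !it->second) {
      return false;  // producer unproven or tainted: treat as linked to control flow
    }
  }
  (*mark)[&n] = true;  // the stamp consumers check, like kIsLinkWithControlFlow = false
  return true;
}

int main() {
  Node a{"a"};
  Node b{"b", /*is_control_flow=*/true};
  Node c{"c", false, {&a}};
  Node d{"d", false, {&b}};  // tainted transitively through b
  std::unordered_map<const Node *, bool> mark;
  for (const Node *n : {&a, &b, &c, &d}) {
    std::printf("%s special-foldable: %d\n", n->name.c_str(), FreeOfControlFlow(*n, &mark));
  }
}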


@@ -24,25 +24,29 @@
 #include "include/api/context.h"
 #include "include/registry/parser_context.h"
 #include "src/inner_context.h"
+#include "tools/optimizer/graph/node_infershape.h"
 
 namespace mindspore {
 namespace opt {
 class ConstFoldPass : public Pass {
  public:
-  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs)
-      : Pass("ConstFoldPass"), fmk_type_(fmk_type) {}
+  explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
+      : Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
   ~ConstFoldPass() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
 
  private:
   bool Init(const FuncGraphPtr &func_graph);
-  int Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
-  bool CheckCanFusion(const CNodePtr &cnode) const;
+  int HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
+  bool CheckCanCommonFold(const CNodePtr &cnode) const;
+  int HandleSpecialFold(const FuncGraphPtr &func_graph);
+  bool CheckCanSpecialFold(const CNodePtr &cnode) const;
   int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
-  bool is_control_flow_{false};
   converter::FmkType fmk_type_{converter::kFmkTypeMs};
+  bool train_flag_{false};
   std::shared_ptr<lite::InnerContext> context_{nullptr};
   std::shared_ptr<mindspore::Context> ms_context_{nullptr};
+  std::shared_ptr<NodeInferShape> node_infershape_{nullptr};
 };
 } // namespace opt
 } // namespace mindspore