forked from mindspore-Ecosystem/mindspore
!24425 [lite]strengthen constant fold
Merge pull request !24425 from 徐安越/master
This commit is contained in:
commit
157751c8ce
|
@ -190,7 +190,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
|
||||
if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
|
||||
|
@ -329,7 +329,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
|
|||
CHECK_NULL_RETURN(const_fold_pm);
|
||||
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
|
||||
if (!config->trainModel) {
|
||||
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
|
||||
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
|
||||
}
|
||||
const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
|
||||
const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
|
||||
|
@ -511,9 +511,9 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
|
|||
auto is_train = config->trainModel;
|
||||
std::unordered_map<std::string, opt::PassPtr> passes = {
|
||||
{"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
|
||||
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(config->fmk)},
|
||||
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
|
||||
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
|
||||
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train)},
|
||||
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
|
||||
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
|
||||
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>()},
|
||||
|
|
|
@ -39,6 +39,7 @@ using mindspore::lite::Tensor;
|
|||
namespace mindspore::opt {
|
||||
namespace {
|
||||
constexpr size_t INITIAL_SIZE = 1024;
|
||||
constexpr auto kIsLinkWithControlFlow = "link_with_control_flow";
|
||||
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
|
||||
if (input_tensor != nullptr) {
|
||||
for (auto &i : *input_tensor) {
|
||||
|
@ -283,7 +284,11 @@ bool ConstFoldPass::Run(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
std::set<FuncGraphPtr> has_visited;
|
||||
if (Process(func_graph, &has_visited) != lite::RET_OK) {
|
||||
if (HandleCommonFold(func_graph, &has_visited) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "do constant fold pass failed,";
|
||||
return false;
|
||||
}
|
||||
if (HandleSpecialFold(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "do constant fold pass failed,";
|
||||
return false;
|
||||
}
|
||||
|
@ -306,7 +311,7 @@ bool ConstFoldPass::Init(const FuncGraphPtr &func_graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
|
||||
int ConstFoldPass::HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (has_visited->find(func_graph) != has_visited->end()) {
|
||||
return lite::RET_OK;
|
||||
|
@ -322,16 +327,15 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < cnode->size(); ++i) {
|
||||
if (IsValueNode<FuncGraph>(cnode->input(i))) {
|
||||
is_control_flow_ = true;
|
||||
auto sub_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
if (Process(sub_graph, has_visited) != lite::RET_OK) {
|
||||
if (HandleCommonFold(sub_graph, has_visited) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "do subgraph const-fold failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!CheckCanFusion(cnode)) {
|
||||
if (!CheckCanCommonFold(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
|
||||
|
@ -342,7 +346,7 @@ int ConstFoldPass::Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
|
||||
bool ConstFoldPass::CheckCanCommonFold(const CNodePtr &cnode) const {
|
||||
MS_CHECK_TRUE_RET(cnode != nullptr, false);
|
||||
if (IsSpecialType(cnode)) {
|
||||
return false;
|
||||
|
@ -351,21 +355,72 @@ bool ConstFoldPass::CheckCanFusion(const CNodePtr &cnode) const {
|
|||
return false;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
bool is_all_const = std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
|
||||
return std::all_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) {
|
||||
return (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) ||
|
||||
(node->isa<Parameter>() && node->cast<ParameterPtr>()->has_default());
|
||||
});
|
||||
if (is_all_const) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int ConstFoldPass::HandleSpecialFold(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
if (lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
|
||||
if (is_control_flow_ || lite::ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() == 0) {
|
||||
if (node_infershape_ == nullptr) {
|
||||
node_infershape_ = std::make_shared<NodeInferShape>(fmk_type_, train_flag_);
|
||||
MS_CHECK_TRUE_RET(node_infershape_ != nullptr, lite::RET_ERROR);
|
||||
}
|
||||
auto manager = Manage(func_graph);
|
||||
MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!CheckCanSpecialFold(cnode)) {
|
||||
continue;
|
||||
}
|
||||
if (DoConstantFold(func_graph, cnode) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "do constant fold failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
|
||||
MS_CHECK_TRUE_RET(cnode != nullptr, false);
|
||||
for (size_t i = 0; i < cnode->size(); ++i) {
|
||||
auto input_node = cnode->input(i);
|
||||
MS_CHECK_TRUE_RET(input_node != nullptr, false);
|
||||
if (IsValueNode<FuncGraph>(input_node)) {
|
||||
return false;
|
||||
}
|
||||
if (!input_node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = input_node->cast<CNodePtr>();
|
||||
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
MS_CHECK_TRUE_RET(input_prim != nullptr, false);
|
||||
bool is_link_with_control_flow = input_prim->GetAttr(kIsLinkWithControlFlow) == nullptr ||
|
||||
GetValue<bool>(input_prim->GetAttr(kIsLinkWithControlFlow));
|
||||
if (is_link_with_control_flow) {
|
||||
return false;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
return prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
|
||||
}
|
||||
return false;
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, false);
|
||||
prim->AddAttr(kIsLinkWithControlFlow, MakeValue(false));
|
||||
if (IsSpecialType(cnode)) {
|
||||
return false;
|
||||
}
|
||||
MS_ASSERT(node_infershape_ != nullptr);
|
||||
auto status = node_infershape_->InferShape(cnode);
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimShape)) {
|
||||
return status == lite::RET_OK;
|
||||
}
|
||||
return CheckCanCommonFold(cnode);
|
||||
}
|
||||
|
||||
int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
|
||||
|
|
|
@ -24,25 +24,29 @@
|
|||
#include "include/api/context.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "src/inner_context.h"
|
||||
#include "tools/optimizer/graph/node_infershape.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConstFoldPass : public Pass {
|
||||
public:
|
||||
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs)
|
||||
: Pass("ConstFoldPass"), fmk_type_(fmk_type) {}
|
||||
explicit ConstFoldPass(converter::FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: Pass("ConstFoldPass"), fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||
~ConstFoldPass() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool Init(const FuncGraphPtr &func_graph);
|
||||
int Process(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
|
||||
bool CheckCanFusion(const CNodePtr &cnode) const;
|
||||
int HandleCommonFold(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *has_visited);
|
||||
bool CheckCanCommonFold(const CNodePtr &cnode) const;
|
||||
int HandleSpecialFold(const FuncGraphPtr &func_graph);
|
||||
bool CheckCanSpecialFold(const CNodePtr &cnode) const;
|
||||
int DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
|
||||
bool is_control_flow_{false};
|
||||
converter::FmkType fmk_type_{converter::kFmkTypeMs};
|
||||
bool train_flag_{false};
|
||||
std::shared_ptr<lite::InnerContext> context_{nullptr};
|
||||
std::shared_ptr<mindspore::Context> ms_context_{nullptr};
|
||||
std::shared_ptr<NodeInferShape> node_infershape_{nullptr};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue