!19391 add auto-monad-elim interface

Merge pull request !19391 from huangbingjian/monad_elim
This commit is contained in:
i-robot 2021-07-06 04:04:31 +00:00 committed by Gitee
commit 816b416122
2 changed files with 23 additions and 0 deletions

View File

@ -27,6 +27,7 @@
#include "pipeline/jit/resource.h"
#include "pipeline/jit/validator.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse_pass.h"
#include "frontend/optimizer/clean.h"
@ -45,6 +46,7 @@
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/util.h"
#include "ps/ps_context.h"
@ -643,6 +645,26 @@ bool PynativeOptPass(const ResourcePtr &res) {
return true;
}
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(func_graph->manager());
auto res = std::make_shared<pipeline::Resource>();
res->set_func_graph(func_graph);
res->set_manager(func_graph->manager());
// opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
opt::SubstitutionPtr updatestate_eliminater = opt::MakeSubstitution(
std::make_shared<opt::irpass::UpdatestateEliminater>(), "updatestate_eliminater", prim::kPrimUpdateState);
opt::OptPassGroupMap elim_map({
{"updatestate_eliminate", opt::OptPassConfig({updatestate_eliminater})},
{"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
});
auto auto_monad_elim_opt = opt::Optimizer::MakeOptimizer("auto_monad_elim", res, elim_map);
(void)auto_monad_elim_opt->step(func_graph, false);
return true;
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_before_recompute", OptBeforeRecomputeGroup},
{"opt_a", OptPassAGroup},

View File

@ -46,6 +46,7 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res);