From 64d9d7632185df4983269ab28ff98525718a6604 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Thu, 24 Dec 2020 21:13:05 +0800 Subject: [PATCH] Add grad epilogue passes for PyNative mode. --- mindspore/ccsrc/pipeline/jit/pass.cc | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 817d3634708..78d9c399efb 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); + // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases(). OptPassGroupMap map_a({{"a_1", a_1}, {"a_2", a_2}, {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, @@ -272,6 +273,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { return map; } +OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) { + auto opt_a = GetOptPassesA(irpass); + auto a3 = opt_a[opt_a.size() - 1]; + OptPassGroupMap map({ + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSEPass(false))}, + {a3}, + }); + return map; +} + OptPassGroupMap GetInferenceOptPreparePhases() { opt::irpass::InferenceOptPrepareLib irpass; auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); @@ -303,6 +315,8 @@ void InitOpt(const ResourcePtr &res) { Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); + g_pass_opts["opt_grad_epilogue"] = + Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -351,6 +365,8 @@ bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepar bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } +bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); } + bool AddControlDependPass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -469,7 +485,8 @@ std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStru {"opt_prepare", PrepareGroup}, {"cconv", CconvPass}}; -std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, +std::vector kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup}, + {"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}, {"transform_top", TransformTopGraphPass},