!29886 Get recomputed cell outputs after opt a1

Merge pull request !29886 from YuJianfeng/recompute
This commit is contained in:
i-robot 2022-02-22 09:26:06 +00:00 committed by Gitee
commit 8daae1481c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 47 additions and 14 deletions

View File

@ -64,21 +64,62 @@ class SetCellOutputNoRecompute : public AnfVisitor {
void GetRealOutputNodes(const AnfNodePtr &output, mindspore::HashSet<CNodePtr> *real_outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(real_outputs);
if (!output->isa<CNode>()) {
auto output_cnode = output->cast<CNodePtr>();
if (output_cnode == nullptr) {
return;
}
auto output_cnode = output->cast<CNodePtr>();
if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
auto input0 = output_cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimTupleGetItem)) {
GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
} else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
} else if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
auto &inputs = output_cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
GetRealOutputNodes(output_cnode->input(i), real_outputs);
}
} else if (IsValueNode<FuncGraph>(input0)) {
auto fg = GetValueNode<FuncGraphPtr>(input0);
GetRealOutputNodes(fg->output(), real_outputs);
} else if (input0->isa<CNode>()) {
auto abs = input0->abstract();
if (abs == nullptr || !abs->isa<abstract::AbstractFunction>()) {
return;
}
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
if (abs_func->isa<abstract::AbstractFuncUnion>()) {
auto visit_fn = [this, &real_outputs](const abstract::AbstractFuncAtomPtr &poss) {
auto abs_fg = GetAbstractFuncGraph(poss);
if (abs_fg != nullptr) {
GetRealOutputNodes(abs_fg->output(), real_outputs);
}
};
abs_func->Visit(visit_fn);
return;
}
auto fg = GetAbstractFuncGraph(abs_func);
if (fg != nullptr) {
GetRealOutputNodes(fg->output(), real_outputs);
}
} else {
real_outputs->insert(output_cnode);
}
}
FuncGraphPtr GetAbstractFuncGraph(const abstract::AbstractFunctionPtr &abs) {
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
auto abstract_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
return abstract_func_graph->func_graph();
}
if (abs->isa<abstract::PartialAbstractClosure>()) {
auto abstract_partial_func = abs->cast<abstract::PartialAbstractClosurePtr>();
auto abstract_fn = abstract_partial_func->fn();
if (abstract_fn != nullptr && abstract_fn->isa<abstract::FuncGraphAbstractClosure>()) {
auto abstract_func_graph = abstract_fn->cast<abstract::FuncGraphAbstractClosurePtr>();
return abstract_func_graph->func_graph();
}
}
return nullptr;
}
};
} // namespace irpass
} // namespace opt

View File

@ -349,10 +349,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
opt::OptPassConfig recompute_prepare = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
{"a_1", a_1},
{"recompute_prepare", recompute_prepare},
{"updatestate_depend_eliminate", updatestate_depend_eliminate},
{"updatestate_assign_eliminate", updatestate_assign_eliminate},
{"updatestate_loads_eliminate", updatestate_loads_eliminate},
@ -536,12 +538,6 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
return map;
}
OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
return map;
}
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
return map;
@ -564,8 +560,6 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_grad_epilogue"] =
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_before_recompute"] =
Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
}
@ -605,7 +599,6 @@ bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
@ -775,7 +768,6 @@ bool EnvironConversionPass(const ResourcePtr &res) {
}
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_before_recompute", OptBeforeRecomputeGroup},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup},