forked from mindspore-Ecosystem/mindspore
!29886 Get recomputed cell outputs after opt a1
Merge pull request !29886 from YuJianfeng/recompute
This commit is contained in:
commit
8daae1481c
|
@ -64,21 +64,62 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
|||
void GetRealOutputNodes(const AnfNodePtr &output, mindspore::HashSet<CNodePtr> *real_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_EXCEPTION_IF_NULL(real_outputs);
|
||||
if (!output->isa<CNode>()) {
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
if (output_cnode == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(output_cnode, prim::kPrimDepend) || IsPrimitiveCNode(output_cnode, prim::kPrimTupleGetItem)) {
|
||||
auto input0 = output_cnode->input(0);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||
GetRealOutputNodes(output_cnode->input(kRealInputIndexInDepend), real_outputs);
|
||||
} else if (IsPrimitiveCNode(output_cnode, prim::kPrimMakeTuple)) {
|
||||
} else if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
auto &inputs = output_cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
GetRealOutputNodes(output_cnode->input(i), real_outputs);
|
||||
}
|
||||
} else if (IsValueNode<FuncGraph>(input0)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(input0);
|
||||
GetRealOutputNodes(fg->output(), real_outputs);
|
||||
} else if (input0->isa<CNode>()) {
|
||||
auto abs = input0->abstract();
|
||||
if (abs == nullptr || !abs->isa<abstract::AbstractFunction>()) {
|
||||
return;
|
||||
}
|
||||
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
|
||||
if (abs_func->isa<abstract::AbstractFuncUnion>()) {
|
||||
auto visit_fn = [this, &real_outputs](const abstract::AbstractFuncAtomPtr &poss) {
|
||||
auto abs_fg = GetAbstractFuncGraph(poss);
|
||||
if (abs_fg != nullptr) {
|
||||
GetRealOutputNodes(abs_fg->output(), real_outputs);
|
||||
}
|
||||
};
|
||||
abs_func->Visit(visit_fn);
|
||||
return;
|
||||
}
|
||||
auto fg = GetAbstractFuncGraph(abs_func);
|
||||
if (fg != nullptr) {
|
||||
GetRealOutputNodes(fg->output(), real_outputs);
|
||||
}
|
||||
} else {
|
||||
real_outputs->insert(output_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr GetAbstractFuncGraph(const abstract::AbstractFunctionPtr &abs) {
|
||||
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto abstract_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
return abstract_func_graph->func_graph();
|
||||
}
|
||||
if (abs->isa<abstract::PartialAbstractClosure>()) {
|
||||
auto abstract_partial_func = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||
auto abstract_fn = abstract_partial_func->fn();
|
||||
if (abstract_fn != nullptr && abstract_fn->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto abstract_func_graph = abstract_fn->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
return abstract_func_graph->func_graph();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -349,10 +349,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
|
||||
opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
|
||||
opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
|
||||
opt::OptPassConfig recompute_prepare = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
|
||||
|
||||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||
OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
|
||||
{"a_1", a_1},
|
||||
{"recompute_prepare", recompute_prepare},
|
||||
{"updatestate_depend_eliminate", updatestate_depend_eliminate},
|
||||
{"updatestate_assign_eliminate", updatestate_assign_eliminate},
|
||||
{"updatestate_loads_eliminate", updatestate_loads_eliminate},
|
||||
|
@ -536,12 +538,6 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
|
||||
OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
|
||||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
|
||||
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
|
||||
return map;
|
||||
|
@ -564,8 +560,6 @@ void InitOpt(const ResourcePtr &res) {
|
|||
g_pass_opts["opt_grad_epilogue"] =
|
||||
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
|
||||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||
g_pass_opts["opt_before_recompute"] =
|
||||
Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
|
||||
g_pass_opts["opt_after_recompute"] =
|
||||
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
||||
}
|
||||
|
@ -605,7 +599,6 @@ bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "
|
|||
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
||||
bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
|
||||
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
||||
|
||||
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
||||
|
@ -775,7 +768,6 @@ bool EnvironConversionPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_before_recompute", OptBeforeRecomputeGroup},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"clean_after_opta", CleanAfterOptAPass},
|
||||
{"opt_b", OptPassBGroup},
|
||||
|
|
Loading…
Reference in New Issue