Warming up the Python pass by adding inline passes before it
parent 0118930c6b
commit 1bdb26f9e8
@@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
   auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
                                        grad_op_child_scope_prefix + prim->name());
   ScopeGuard scope_guard(scope);
-  py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
+  py::function fn;
+  if (prim->is_base()) {
+    fn = GetBpropFunction(prim->name());
+  } else {
+    fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
+    if (py::isinstance<py::none>(fn)) {
+      fn = GetBpropFunction(prim->name());
+    }
+  }
   if (!fn || py::isinstance<py::none>(fn)) {
     MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
     return nullptr;
@@ -35,8 +35,10 @@ namespace internal {
 const char PARAMETER_MODULE[] = "mindspore.common.parameter";
 const char PARAMETER_CLASS[] = "Parameter";
 const char SET_PARAM[] = "__setattr__";
-AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
-AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
+AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
+                                const FuncGraphPtr &top_graph);
+AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
+                       const MatchResultPtr &res);
 void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
                               bool requires_grad, bool layerwise_parallel);
 
@@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
   return std::make_shared<ValueNode>(input_tensor);
 }
 
-AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
+AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
+                                   const FuncGraphPtr &top_graph) {
   auto call_pattern = pattern->cast<CallPtr>();
   MS_EXCEPTION_IF_NULL(call_pattern);
   auto prim = call_pattern->prim_value();
@@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
   }
   auto prim_pattern = call_pattern->prim_pattern();
   MS_EXCEPTION_IF_NULL(prim_pattern);
-  return ProcessSinglePattern(prim_pattern, res, fg);
+  return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
 }
 
-AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
+AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
   auto new_para_pattern = pattern->cast<NewParameterPtr>();
   MS_EXCEPTION_IF_NULL(new_para_pattern);
   if (!new_para_pattern->built()) {
     static int parameter_id = 0;
     auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
-    auto para_node = std::make_shared<Parameter>(func_graph);
+    auto para_node = std::make_shared<Parameter>(top_graph);
     MS_EXCEPTION_IF_NULL(para_node);
     para_node->set_name(para_name);
     // Set function graph
-    para_node->set_func_graph(func_graph);
+    para_node->set_func_graph(top_graph);
     // Set Debug Info
     auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
     para_node->set_debug_info(debug_info);
@@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
     MS_EXCEPTION_IF_NULL(default_value);
     para_node->set_abstract(default_value->ToAbstract()->Broaden());
     res->add_entry(pattern, para_node);
-    func_graph->add_parameter(para_node);
+    top_graph->add_parameter(para_node);
     // Reflect back to Cell._params
     internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
                                        new_para_pattern->layerwise_parallel());
@@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
   return std::make_shared<ValueNode>(scalar_value_ptr);
 }
 
-AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
+AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
+                                const FuncGraphPtr &top_graph) {
   auto target_node = res->get_node(pattern);
   if (target_node != nullptr) {
     // If pattern is NewParameter, check whether it shouldn't last and is not built
@@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
   } else if (pattern->isa<NewTensor>()) {
     return BuildNewTensor(pattern, res);
   } else if (pattern->isa<Call>()) {
-    return BuildPrimitiveValueNode(pattern, res, func_graph);
+    return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
   } else if (pattern->isa<NewParameter>()) {
-    return BuildNewParameter(pattern, res, func_graph);
+    // Add new parameter to top graph instead of current graph
+    return BuildNewParameter(pattern, res, top_graph);
   } else if (pattern->isa<Imm>()) {
     return BuildImmNode(pattern, res);
   } else {
@@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
 }
 
 AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
-                                           const FuncGraphPtr &func_graph) {
+                                           const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
   if (pattern->isa<Call>()) {
-    return BuildPrimitiveValueNode(pattern, res, func_graph);
+    return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
   }
   return nullptr;
 }
 
-AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
+AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
+                       const MatchResultPtr &res) {
   auto target_inputs = pattern->inputs();
   if (target_inputs.size() == 0) {
-    auto new_node = ProcessSinglePattern(pattern, res, func_graph);
+    auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
     if (new_node != nullptr) {
       res->add_entry(pattern, new_node);
     }
@@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
   }
   // Build up the AnfNode in a recursive manner
   std::vector<AnfNodePtr> new_inputs;
-  auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
+  auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
   MS_EXCEPTION_IF_NULL(prim_value_node);
   new_inputs.push_back(prim_value_node);
   for (auto &iter : target_inputs) {
     if (iter == pattern) {
       MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
     }
-    auto input_node = BuildTarget(iter, func_graph, res);
+    auto input_node = BuildTarget(iter, func_graph, top_graph, res);
     if (input_node == nullptr) {
       MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
     }
@@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {
 
 } // namespace internal
 
-AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
+AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
+                           const MatchResultPtr &res) {
   auto match_res = src_pattern_->match(node);
   if (match_res != nullptr) {
     res->merge(match_res);
-    auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
+    auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
     internal::Reset(dst_pattern());
     return new_node;
   }
@@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
   }
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  manager->AddFuncGraph(func_graph);
-  auto graph_nodes_sorted = TopoSort(func_graph->output());
+  auto func_graphs = manager->func_graphs();
   bool changes = false;
 
-  // Traverse once
-  for (auto &node : graph_nodes_sorted) {
-    AnfNodePtr new_node = Run(func_graph, node, res);
-    if (new_node != nullptr && new_node != node) {
-      (void)manager->Replace(node, new_node);
-      changes = true;
+  for (auto &fg : func_graphs) {
+    manager->AddFuncGraph(fg);
+    auto graph_nodes_sorted = TopoSort(fg->output());
+    // Traverse once
+    for (auto &node : graph_nodes_sorted) {
+      AnfNodePtr new_node = Run(fg, func_graph, node, res);
+      if (new_node != nullptr && new_node != node) {
+        MS_LOG(WARNING) << "Matched";
+        (void)manager->Replace(node, new_node);
+        changes = true;
+      }
     }
   }
   return changes;
@@ -39,7 +39,8 @@ class PythonPass {
   ~PythonPass() = default;
   bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
   std::string name() const { return name_; }
-  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
+  AnfNodePtr Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
+                 const MatchResultPtr &res);
   PatternPtr src_pattern() { return src_pattern_; }
   PatternPtr dst_pattern() { return dst_pattern_; }
 
@@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
 }
 
 PyPassManager::PyPassManager() {
-  phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
-  phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
+  phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
+  phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
   res_ = std::make_shared<MatchResult>();
 }
 
 void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
-                            bool run_only_once) {
-  // NOTE: remove phase option to avoid unnecessary confusion.
-  auto cur_pg = GetPassGroup(Phase::OPT);
+                            bool requires_grad, bool run_only_once) {
+  PassGroupPtr cur_pg;
+  if (requires_grad) {
+    cur_pg = GetPassGroup(Phase::PREAD);
+  } else {
+    cur_pg = GetPassGroup(Phase::OPT);
+  }
   MS_EXCEPTION_IF_NULL(cur_pg);
   cur_pg->SetRunOnlyOnce(run_only_once);
   MS_EXCEPTION_IF_NULL(pattern);
@@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
 }
 
 void PyPassManager::Unregiste(const std::string &pass_name) {
-  // NOTE: remove phase option to avoid unnecessary confusion.
-  auto cur_pm = GetPassGroup(Phase::OPT);
-  MS_EXCEPTION_IF_NULL(cur_pm);
-  if (!cur_pm->DeletePass(pass_name)) {
-    MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
+  auto opt_pm = GetPassGroup(Phase::OPT);
+  if (!opt_pm->DeletePass(pass_name)) {
+    MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
+  }
+  auto pre_ad_pm = GetPassGroup(Phase::PREAD);
+  if (!pre_ad_pm->DeletePass(pass_name)) {
+    MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
   }
 }
 
@@ -92,7 +98,7 @@ void PyPassManager::ClearRes() {
 
 REGISTER_PYBIND_DEFINE(
   PyPassManager_, ([](const py::module *m) {
-    (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
+    (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
     (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
       .def(py::init([]() { return PyPassManager::GetInstance(); }))
      .def("registe", &PyPassManager::Registe, "Registe python pass")
@@ -38,7 +38,7 @@ namespace python_pass {
 class PyPassManager;
 using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
 
-enum Phase { RESOLVE, OPT };
+enum Phase { PREAD, OPT };
 
 class PyPassManager {
  protected:
@@ -52,8 +52,8 @@ class PyPassManager {
   // Access the only global instance
   static PyPassManagerPtr GetInstance();
   virtual ~PyPassManager() = default;
-  void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
-               bool run_only_once = false);
+  void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
+               bool run_only_once);
   void Unregiste(const std::string &pass_name);
   void GenNewParameter(const PatternPtr &parameter);
   PassGroupPtr GetPassGroup(Phase phase);
@@ -288,6 +288,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
   return true;
 }
 
+bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
+
 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
 
 bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
@@ -460,7 +462,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
   return ppm->GetPassGroup(phase)->Run(res->func_graph());
 }
 
-bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
+bool PreAdActionPyStub(const ResourcePtr &res) {
+  if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
+    MS_LOG(DEBUG) << "No Match.";
+  }
+  return true;
+}
 
 bool OptActionVmPyStub(const ResourcePtr &res) {
   if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
@@ -516,12 +523,14 @@ static std::vector<ActionItem> CommonPipeline() {
   if (!multi_graphs) {
     actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
   }
-  // Add resolve-stage python pass stub
-  actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
 
   actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
   // Evaluate type and shape, and specialize
   actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
+  // Do data structure simplifications and inline
+  actions.emplace_back(std::make_pair("inline", OptInlineAction));
+  // Add pre-ad, post-inline python pass stub
+  actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
 
   return actions;
 }
@@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   return map_a;
 }
 
+OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
+  auto opt_a = GetOptPassesA(irpass);
+  OptPassGroupMap a1_a2({opt_a[0], opt_a[1]});
+  return a1_a2;
+}
+
 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig c_1 = opt::OptPassConfig({
     // Safe inlining,
@@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts =
 void InitOpt(const ResourcePtr &res) {
   if (g_pass_opts.size() == 0) {
     opt::irpass::OptimizeIRPassLib irpass;
+    g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass));
     g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
     g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
     g_pass_opts["opt_after_cconv"] =
@@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
   return true;
 }
 
+bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); }
 bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
 bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
 bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
@@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
                                          {"cconv", CconvPass},
                                          {"transform_top", TransformTopGraphPass},
                                          {"transform_graph", OptPassTransformGraphGroup}};
+
+std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}};
 } // namespace pipeline
 } // namespace mindspore
@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
 
 extern std::vector<PassItem> kGePasses;
 extern std::vector<PassItem> kVmPasses;
+extern std::vector<PassItem> kInlinePasses;
 extern std::vector<PassItem> kPynativePasses;
 
 bool CconvPass(const ResourcePtr &res);
@@ -13,14 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """Reference for python pass registration."""
-from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
-    set_reopt
+from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
+    _set_reopt
 
 __all__ = [
     "registe_pass",
     "unregiste_pass",
     "gen_new_parameter",
    "cancel_new_parameter",
-    "set_renorm",
-    "set_reopt"
+    "_set_renorm",
+    "_set_reopt"
 ]
@@ -23,22 +23,26 @@ __all__ = [
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "set_renorm",
-    "set_reopt"
+    "_set_renorm",
+    "_set_reopt"
 ]
 class PyPassManager(PyPassManager_):
     r"""
     Used to registe and unregiste python passes which can be used to alter graphs.
 
     Args:
+        requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
         run_only_once (bool): Specify whether or not to run pass only once. Default: False.
 
     Raises:
        TypeError: If argument has invalid type.
     """
-    def __init__(self, run_only_once=False):
+    def __init__(self, requires_grad=True, run_only_once=False):
+        if not isinstance(requires_grad, bool):
+            raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
         if not isinstance(run_only_once, bool):
             raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
+        self.requires_grad = requires_grad
         self.run_only_once_ = run_only_once
         PyPassManager_.__init__(self)
 
@@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_):
             raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
         if not isinstance(target, Pattern):
             raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
-        super().registe(pass_name, pattern, target, self.run_only_once_)
+        super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
 
     def unregiste(self, py_pass):
         if isinstance(py_pass, str):
@@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_):
             raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
         super().set_reopt(do_reopt)
 
-def registe_pass(run_only_once=False):
+def registe_pass(requires_grad=True, run_only_once=False):
     """
     Registe python pass to specified pipeline phase which would be used in compilation.
 
     Args:
+        requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
         run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
 
     Returns:
@@ -99,7 +104,7 @@ def registe_pass(run_only_once=False):
         >>> target = IsPrimTypeOf("ReLU6")
         >>> return pattern, target
     """
-    return PyPassManager(run_only_once)
+    return PyPassManager(requires_grad, run_only_once)
 
 def unregiste_pass(py_pass):
     """
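A minimal usage sketch of the updated decorator (the pass name here is illustrative; the pattern helpers and the requires_grad flag mirror the updated tests further down — per the Registe change above, requires_grad=True routes the pass to the Pre_AD_PassGroup, requires_grad=False to the After_OPT_PassGroup):

from mindspore.ops import operations as P
from mindspore.graph_utils.graph_pattern import Any, Call
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass

# Register a rewrite that runs before auto-differentiation, applied only once.
@registe_pass(requires_grad=True, run_only_once=True)
def softmax_relu6_pass():
    x = Any()
    pattern = Call(P.Softmax(), [x])  # match Softmax(x)
    target = Call(P.ReLU6(), [x])     # rewrite it to ReLU6(x)
    return pattern, target

# ...compile and run the cell, then remove the pass when it is no longer needed.
unregiste_pass(softmax_relu6_pass)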
@@ -157,7 +162,7 @@ def cancel_new_parameter(pattern):
     ppm = PyPassManager()
     ppm.unregiste(pattern.para_name)
 
-def set_renorm(should_renorm):
+def _set_renorm(should_renorm):
     """
     Set whether or not to do renormalization after modified graph in python pass(es).
 
@@ -171,7 +176,7 @@ def set_renorm(should_renorm):
     ppm = PyPassManager()
     ppm.set_renorm(should_renorm)
 
-def set_reopt(do_reopt):
+def _set_reopt(do_reopt):
     """
     Set whether or not to do optimization after modified graph in python pass(es).
 
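The renorm/reopt toggles are now module-private helpers; a short, illustrative sketch of how the updated tests below wrap a pass registration with them:

from mindspore.graph_utils.python_pass import _set_renorm, _set_reopt

# Skip renormalization and re-optimization while the experimental pass is active,
# then turn renormalization back on, as the updated tests do.
_set_renorm(False)
_set_reopt(False)
# ...registe_pass(...), compile the cell, inspect the transformed IR...
_set_renorm(True)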
@@ -19,8 +19,8 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
-from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
-    cancel_new_parameter, set_reopt
+from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
+    cancel_new_parameter, _set_reopt
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
 from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
     Test IsNot pattern which expresses the IsNot semantics.
     Case: IsNot pass failed to match
     """
-    set_renorm(False)
-    set_reopt(False)
+    _set_renorm(False)
+    _set_reopt(False)
     class ConvBN(nn.Cell):
         def __init__(self):
             super(ConvBN, self).__init__()
@@ -176,7 +176,7 @@ def test_isnot_pattern_0():
     inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
     conv_bn_model = ConvBN()
 
-    @registe_pass(run_only_once=True)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def single_bn_pass():
         """
         Sub a BN which does NOT take Conv as inputs to ReLU6.
@@ -188,7 +188,7 @@ def test_isnot_pattern_0():
         target = Call(P.ReLU6(), [pattern_0])
         return pattern, target
 
-    @registe_pass(run_only_once=True)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def bn_pass():
         """
         Sub a BN to Softmax.
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
     unregiste_pass(bn_pass)
     assert "ReLU6" not in transformed_repr
     assert "Softmax" in transformed_repr
-    set_renorm(True)
+    _set_renorm(True)
 
 def test_isnot_pattern_1():
     """
@@ -234,12 +234,12 @@ def test_newtensor_pattern():
     """
     Test NewTensor pattern in the target
     """
-    set_renorm(False)
-    set_reopt(False)
+    _set_renorm(False)
+    _set_reopt(False)
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    @registe_pass(run_only_once=True)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_addn_pass():
         x = Any()
         pattern = Call(P.Softmax(), [x])
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
     unregiste_pass(softmax_addn_pass)
     assert "AddN" in transformed_repr
     assert "Softmax" not in transformed_repr
-    set_renorm(True)
+    _set_renorm(True)
 
 def test_newparameter_pattern():
     """
@@ -261,9 +261,9 @@ def test_newparameter_pattern():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    set_renorm(False)
-    set_reopt(False)
-    @registe_pass(run_only_once=True)
+    _set_renorm(False)
+    _set_reopt(False)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_addn_pass():
         x = Any()
         pattern = Call(P.Softmax(), [x])
@@ -288,9 +288,9 @@ def test_imm_target():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    set_renorm(False)
-    set_reopt(False)
-    @registe_pass(run_only_once=True)
+    _set_renorm(False)
+    _set_reopt(False)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_pass():
         x = Any()
         pattern = Call(P.Softmax(), [x])
@@ -313,10 +313,10 @@ def test_gen_new_parameter():
 
     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
     new_para = NewParameter("Merlin", default_tensor)
-    set_renorm(False)
-    set_reopt(False)
+    _set_renorm(False)
+    _set_reopt(False)
     gen_new_parameter(new_para)
-    @registe_pass(run_only_once=True)
+    @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_make_tuple_pass():
         x = Any()
         softmax = P.Softmax()