fix code check

This commit is contained in:
huanghui 2021-05-17 10:32:37 +08:00
parent 85ba75565b
commit 53e32077c1
6 changed files with 84 additions and 68 deletions

View File

@ -48,8 +48,8 @@ PyPassManager::PyPassManager() {
res_ = std::make_shared<MatchResult>();
}
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool requires_grad, bool run_only_once) {
void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool requires_grad, bool run_only_once) {
PassGroupPtr cur_pg;
if (requires_grad) {
cur_pg = GetPassGroup(Phase::PREAD);
@ -65,7 +65,7 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
cur_pg->AddPass(new_pass);
}
void PyPassManager::Unregiste(const std::string &pass_name) {
void PyPassManager::Unregister(const std::string &pass_name) {
auto opt_pm = GetPassGroup(Phase::OPT);
if (!opt_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
@ -101,8 +101,8 @@ REGISTER_PYBIND_DEFINE(
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
.def("register", &PyPassManager::Register, "Register python pass")
.def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");

View File

@ -52,9 +52,9 @@ class PyPassManager {
// Access the only global instance
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
bool run_only_once);
void Unregiste(const std::string &pass_name);
void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
bool run_only_once);
void Unregister(const std::string &pass_name);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);
MatchResultPtr GetMatchResult() { return res_; }

View File

@ -28,6 +28,7 @@ __all__ = [
"Imm"
]
class OneOf(OneOf_):
r"""
Express a pattern which allows a list of patterns.
@ -51,6 +52,7 @@ class OneOf(OneOf_):
else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_):
r"""
Express a pattern of certain primitive type(s).
@ -95,6 +97,7 @@ class Prim(Prim_):
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
Prim_.__init__(self, self.types, self.name)
class Call(Call_):
r"""
Express a primitive CNode.
@ -124,6 +127,7 @@ class Call(Call_):
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
Call_.__init__(self, self.prim_pattern, self.inputs)
class NoneOf(NoneOf_):
r"""
Express a pattern which forbids a list of patterns.
@ -134,7 +138,7 @@ class NoneOf(NoneOf_):
def __init__(self, patterns=None):
r"""
Args:
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
should be one of the exposed Pattern instance.
Raises:
@ -150,6 +154,7 @@ class NoneOf(NoneOf_):
else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_):
r"""
New Tensor to be used in the target.
@ -157,7 +162,7 @@ class NewTensor(NewTensor_):
def __init__(self, input_tensor):
r"""
Args:
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target.
Raises:
TypeError: raise type error for invalid argument.
@ -168,6 +173,7 @@ class NewTensor(NewTensor_):
else:
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
@ -175,10 +181,10 @@ class NewParameter(NewParameter_):
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
r"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: True
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
para_name(str): name for the new Parameter.
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
requires_grad(bool): True if the parameter requires gradient. Default: True.
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False.
Raises:
TypeError: raise type error for invalid argument.

View File

@ -13,12 +13,12 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
from .python_pass_register import register_pass, unregister_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt
__all__ = [
"registe_pass",
"unregiste_pass",
"register_pass",
"unregister_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",

View File

@ -17,18 +17,19 @@ from inspect import isfunction
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_
__all__ = [
"registe_pass",
"unregiste_pass",
"register_pass",
"unregister_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",
"set_reopt"
]
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.
Used to register and unregister python passes which can be used to alter graphs.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
@ -46,7 +47,7 @@ class PyPassManager(PyPassManager_):
self.run_only_once_ = run_only_once
PyPassManager_.__init__(self)
def registe(self, py_pass):
def register(self, py_pass):
if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
@ -55,19 +56,19 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregiste(self, py_pass):
def unregister(self, py_pass):
if isinstance(py_pass, str):
super().unregiste(py_pass)
super().unregister(py_pass)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__)
super().unregister(py_pass.__name__)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
self.registe(py_pass)
self.register(py_pass)
return py_pass
def gen_new_parameter(self, pattern):
@ -85,36 +86,42 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def registe_pass(requires_grad=True, run_only_once=False):
def register_pass(requires_grad=True, run_only_once=False):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Register python pass to specified pipeline phase which would be used in compilation.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default:
False.
Returns:
This function should be used as a decorator, return the decoratorated pass function.
Examples:
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
>>> @registe_pass()
>>> from mindspore.graph_utils.graph_pattern import Call, Any
>>> from mindspore.ops import operations as P
>>> @register_pass()
>>> def toy_pass():
>>> pattern = IsPrimTypeOf("ReLU")
>>> target = IsPrimTypeOf("ReLU6")
>>> x = Any()
>>> pattern = Call(P.Softmax(), [x])
>>> target = Call(P.ReLU(), [x])
>>> return pattern, target
"""
return PyPassManager(requires_grad, run_only_once)
def unregiste_pass(py_pass):
def unregister_pass(py_pass):
"""
Unregiste python pass.
Unregister python pass.
Args:
py_pass(Union(str, function)): target python pass to unregiste.
py_pass(Union(str, function)): target python pass to unregister.
"""
ppm = PyPassManager()
ppm.unregiste(py_pass)
ppm.unregister(py_pass)
def gen_new_parameter(pattern):
"""
@ -123,8 +130,8 @@ def gen_new_parameter(pattern):
NOTE:
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
gen_new_parameter, every pass match would build a new Parameter.
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
This would register a pass to add new parameter in the compilation pipeline, so later compilation would
ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
cancel_new_parameter(pattern)
Args:
@ -142,9 +149,10 @@ def gen_new_parameter(pattern):
ppm = PyPassManager()
ppm.gen_new_parameter(pattern)
def cancel_new_parameter(pattern):
"""
Use with gen_new_parameter to unregiste gen_new_parameter pass.
Use with gen_new_parameter to unregister gen_new_parameter pass.
Args:
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
@ -160,7 +168,8 @@ def cancel_new_parameter(pattern):
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)
ppm.unregister(pattern.para_name)
def set_renorm(should_renorm):
"""
@ -176,6 +185,7 @@ def set_renorm(should_renorm):
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
def set_reopt(do_reopt):
"""
Set whether or not to do optimization after modified graph in python pass(es).

View File

@ -20,7 +20,7 @@ from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
@ -48,7 +48,7 @@ def test_softmax_relu():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@ -56,7 +56,7 @@ def test_softmax_relu():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass)
unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
@ -64,7 +64,7 @@ def test_prim():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
@ -73,7 +73,7 @@ def test_prim():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass)
unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
@ -87,7 +87,7 @@ def test_softmax_relu_sigmoid():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
@ -99,7 +99,7 @@ def test_softmax_relu_sigmoid():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass)
unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Sigmoid" in transformed_repr
assert "Softmax" not in transformed_repr
@ -112,7 +112,7 @@ def test_isin_pattern_0():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
@ -125,7 +125,7 @@ def test_isin_pattern_0():
target = Call(relu6_pattern, [x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass)
unregister_pass(softmax_relu_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
@ -136,7 +136,7 @@ def test_isin_pattern_1():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def softmax_neg_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
@ -149,7 +149,7 @@ def test_isin_pattern_1():
target = Call(neg_ops, [pattern])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
unregiste_pass(softmax_neg_pass)
unregister_pass(softmax_neg_pass)
assert "Neg" in transformed_repr
assert "Softmax" in transformed_repr
@ -177,7 +177,7 @@ def test_isnot_pattern_0():
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
conv_bn_model = ConvBN()
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def single_bn_pass():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
@ -189,7 +189,7 @@ def test_isnot_pattern_0():
target = Call(P.ReLU6(), [pattern_0])
return pattern, target
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def bn_pass():
"""
Sub a BN to Softmax.
@ -199,8 +199,8 @@ def test_isnot_pattern_0():
return pattern, target
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
unregiste_pass(single_bn_pass)
unregiste_pass(bn_pass)
unregister_pass(single_bn_pass)
unregister_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
set_renorm(True)
@ -213,7 +213,7 @@ def test_isnot_pattern_1():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@register_pass(run_only_once=True)
def single_bn_pass():
"""
Sub a BN which does NOT take MatMul as inputs to ReLU6.
@ -227,7 +227,7 @@ def test_isnot_pattern_1():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(single_bn_pass)
unregister_pass(single_bn_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
@ -240,7 +240,7 @@ def test_newtensor_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@ -250,7 +250,7 @@ def test_newtensor_pattern():
target = Call(P.AddN(), [x, new_weight])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_addn_pass)
unregister_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
set_renorm(True)
@ -264,7 +264,7 @@ def test_newparameter_pattern():
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@ -277,7 +277,7 @@ def test_newparameter_pattern():
target = Call("MakeTuple", [target_0])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_addn_pass)
unregister_pass(softmax_addn_pass)
assert "MatMul" in transformed_repr
assert "MakeTuple" in transformed_repr
assert "Softmax" not in transformed_repr
@ -291,7 +291,7 @@ def test_imm_target():
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def softmax_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@ -300,7 +300,7 @@ def test_imm_target():
target = Call(Constants.kTupleGetItem, [target_0, imm])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_pass)
unregister_pass(softmax_pass)
assert "MakeTuple" in transformed_repr
assert Constants.kTupleGetItem in transformed_repr
assert "Softmax" in transformed_repr
@ -317,7 +317,7 @@ def test_gen_new_parameter():
set_renorm(False)
set_reopt(False)
gen_new_parameter(new_para)
@registe_pass(requires_grad=False, run_only_once=True)
@register_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass():
x = Any()
softmax = P.Softmax()
@ -327,7 +327,7 @@ def test_gen_new_parameter():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
assert "Merlin" in transformed_repr
unregiste_pass(softmax_make_tuple_pass)
unregister_pass(softmax_make_tuple_pass)
cancel_new_parameter(new_para)
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
assert "Merlin" not in transformed_repr