fix code check

This commit is contained in:
huanghui 2021-05-17 10:32:37 +08:00
parent 85ba75565b
commit 53e32077c1
6 changed files with 84 additions and 68 deletions

View File

@ -48,8 +48,8 @@ PyPassManager::PyPassManager() {
res_ = std::make_shared<MatchResult>(); res_ = std::make_shared<MatchResult>();
} }
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool requires_grad, bool run_only_once) { bool requires_grad, bool run_only_once) {
PassGroupPtr cur_pg; PassGroupPtr cur_pg;
if (requires_grad) { if (requires_grad) {
cur_pg = GetPassGroup(Phase::PREAD); cur_pg = GetPassGroup(Phase::PREAD);
@ -65,7 +65,7 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
cur_pg->AddPass(new_pass); cur_pg->AddPass(new_pass);
} }
void PyPassManager::Unregiste(const std::string &pass_name) { void PyPassManager::Unregister(const std::string &pass_name) {
auto opt_pm = GetPassGroup(Phase::OPT); auto opt_pm = GetPassGroup(Phase::OPT);
if (!opt_pm->DeletePass(pass_name)) { if (!opt_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n"; MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
@ -101,8 +101,8 @@ REGISTER_PYBIND_DEFINE(
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT); (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_") (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); })) .def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass") .def("register", &PyPassManager::Register, "Register python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") .def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph") .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph"); .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");

View File

@ -52,9 +52,9 @@ class PyPassManager {
// Access the only global instance // Access the only global instance
static PyPassManagerPtr GetInstance(); static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default; virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad, void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
bool run_only_once); bool run_only_once);
void Unregiste(const std::string &pass_name); void Unregister(const std::string &pass_name);
void GenNewParameter(const PatternPtr &parameter); void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase); PassGroupPtr GetPassGroup(Phase phase);
MatchResultPtr GetMatchResult() { return res_; } MatchResultPtr GetMatchResult() { return res_; }

View File

@ -28,6 +28,7 @@ __all__ = [
"Imm" "Imm"
] ]
class OneOf(OneOf_): class OneOf(OneOf_):
r""" r"""
Express a pattern which allows a list of patterns. Express a pattern which allows a list of patterns.
@ -51,6 +52,7 @@ class OneOf(OneOf_):
else: else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_): class Prim(Prim_):
r""" r"""
Express a pattern of certain primitive type(s). Express a pattern of certain primitive type(s).
@ -95,6 +97,7 @@ class Prim(Prim_):
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
Prim_.__init__(self, self.types, self.name) Prim_.__init__(self, self.types, self.name)
class Call(Call_): class Call(Call_):
r""" r"""
Express a primitive CNode. Express a primitive CNode.
@ -124,6 +127,7 @@ class Call(Call_):
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
Call_.__init__(self, self.prim_pattern, self.inputs) Call_.__init__(self, self.prim_pattern, self.inputs)
class NoneOf(NoneOf_): class NoneOf(NoneOf_):
r""" r"""
Express a pattern which forbids a list of patterns. Express a pattern which forbids a list of patterns.
@ -134,7 +138,7 @@ class NoneOf(NoneOf_):
def __init__(self, patterns=None): def __init__(self, patterns=None):
r""" r"""
Args: Args:
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
should be one of the exposed Pattern instance. should be one of the exposed Pattern instance.
Raises: Raises:
@ -150,6 +154,7 @@ class NoneOf(NoneOf_):
else: else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_): class NewTensor(NewTensor_):
r""" r"""
New Tensor to be used in the target. New Tensor to be used in the target.
@ -157,7 +162,7 @@ class NewTensor(NewTensor_):
def __init__(self, input_tensor): def __init__(self, input_tensor):
r""" r"""
Args: Args:
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target.
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
@ -168,6 +173,7 @@ class NewTensor(NewTensor_):
else: else:
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}") raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_): class NewParameter(NewParameter_):
r""" r"""
New Parameter to be used in the target. New Parameter to be used in the target.
@ -175,10 +181,10 @@ class NewParameter(NewParameter_):
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False): def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
r""" r"""
Args: Args:
para_name(str): name for the new Parameter para_name(str): name for the new Parameter.
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
requires_grad(bool): True if the parameter requires gradient. Default: True requires_grad(bool): True if the parameter requires gradient. Default: True.
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False layerwise_parallel(bool): switch for layerwise parallel mode. Default: False.
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.

View File

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Reference for python pass registration.""" """Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ from .python_pass_register import register_pass, unregister_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt set_reopt
__all__ = [ __all__ = [
"registe_pass", "register_pass",
"unregiste_pass", "unregister_pass",
"gen_new_parameter", "gen_new_parameter",
"cancel_new_parameter", "cancel_new_parameter",
"set_renorm", "set_renorm",

View File

@ -17,18 +17,19 @@ from inspect import isfunction
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_ from mindspore._c_expression import PyPassManager_
__all__ = [ __all__ = [
"registe_pass", "register_pass",
"unregiste_pass", "unregister_pass",
"gen_new_parameter", "gen_new_parameter",
"cancel_new_parameter", "cancel_new_parameter",
"set_renorm", "set_renorm",
"set_reopt" "set_reopt"
] ]
class PyPassManager(PyPassManager_): class PyPassManager(PyPassManager_):
r""" r"""
Used to registe and unregiste python passes which can be used to alter graphs. Used to register and unregister python passes which can be used to alter graphs.
Args: Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
@ -46,7 +47,7 @@ class PyPassManager(PyPassManager_):
self.run_only_once_ = run_only_once self.run_only_once_ = run_only_once
PyPassManager_.__init__(self) PyPassManager_.__init__(self)
def registe(self, py_pass): def register(self, py_pass):
if not isfunction(py_pass): if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass() pattern, target = py_pass()
@ -55,19 +56,19 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern): if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_) super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregiste(self, py_pass): def unregister(self, py_pass):
if isinstance(py_pass, str): if isinstance(py_pass, str):
super().unregiste(py_pass) super().unregister(py_pass)
return return
if isfunction(py_pass): if isfunction(py_pass):
super().unregiste(py_pass.__name__) super().unregister(py_pass.__name__)
return return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass): def __call__(self, py_pass):
self.registe(py_pass) self.register(py_pass)
return py_pass return py_pass
def gen_new_parameter(self, pattern): def gen_new_parameter(self, pattern):
@ -85,36 +86,42 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt) super().set_reopt(do_reopt)
def registe_pass(requires_grad=True, run_only_once=False):
def register_pass(requires_grad=True, run_only_once=False):
""" """
Registe python pass to specified pipeline phase which would be used in compilation. Register python pass to specified pipeline phase which would be used in compilation.
Args: Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False. run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default:
False.
Returns: Returns:
This function should be used as a decorator, return the decoratorated pass function. This function should be used as a decorator, return the decoratorated pass function.
Examples: Examples:
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf >>> from mindspore.graph_utils.graph_pattern import Call, Any
>>> @registe_pass() >>> from mindspore.ops import operations as P
>>> @register_pass()
>>> def toy_pass(): >>> def toy_pass():
>>> pattern = IsPrimTypeOf("ReLU") >>> x = Any()
>>> target = IsPrimTypeOf("ReLU6") >>> pattern = Call(P.Softmax(), [x])
>>> target = Call(P.ReLU(), [x])
>>> return pattern, target >>> return pattern, target
""" """
return PyPassManager(requires_grad, run_only_once) return PyPassManager(requires_grad, run_only_once)
def unregiste_pass(py_pass):
def unregister_pass(py_pass):
""" """
Unregiste python pass. Unregister python pass.
Args: Args:
py_pass(Union(str, function)): target python pass to unregiste. py_pass(Union(str, function)): target python pass to unregister.
""" """
ppm = PyPassManager() ppm = PyPassManager()
ppm.unregiste(py_pass) ppm.unregister(py_pass)
def gen_new_parameter(pattern): def gen_new_parameter(pattern):
""" """
@ -123,8 +130,8 @@ def gen_new_parameter(pattern):
NOTE: NOTE:
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
gen_new_parameter, every pass match would build a new Parameter. gen_new_parameter, every pass match would build a new Parameter.
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would This would register a pass to add new parameter in the compilation pipeline, so later compilation would
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
cancel_new_parameter(pattern) cancel_new_parameter(pattern)
Args: Args:
@ -142,9 +149,10 @@ def gen_new_parameter(pattern):
ppm = PyPassManager() ppm = PyPassManager()
ppm.gen_new_parameter(pattern) ppm.gen_new_parameter(pattern)
def cancel_new_parameter(pattern): def cancel_new_parameter(pattern):
""" """
Use with gen_new_parameter to unregiste gen_new_parameter pass. Use with gen_new_parameter to unregister gen_new_parameter pass.
Args: Args:
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
@ -160,7 +168,8 @@ def cancel_new_parameter(pattern):
if not isinstance(pattern, NewParameter): if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
ppm = PyPassManager() ppm = PyPassManager()
ppm.unregiste(pattern.para_name) ppm.unregister(pattern.para_name)
def set_renorm(should_renorm): def set_renorm(should_renorm):
""" """
@ -176,6 +185,7 @@ def set_renorm(should_renorm):
ppm = PyPassManager() ppm = PyPassManager()
ppm.set_renorm(should_renorm) ppm.set_renorm(should_renorm)
def set_reopt(do_reopt): def set_reopt(do_reopt):
""" """
Set whether or not to do optimization after modified graph in python pass(es). Set whether or not to do optimization after modified graph in python pass(es).

View File

@ -20,7 +20,7 @@ from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants from mindspore.ops import _constants as Constants
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_ from mindspore._c_expression import generate_key, Executor_
@ -48,7 +48,7 @@ def test_softmax_relu():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def softmax_relu_pass(): def softmax_relu_pass():
x = Any() x = Any()
pattern = Call(P.Softmax(), [x]) pattern = Call(P.Softmax(), [x])
@ -56,7 +56,7 @@ def test_softmax_relu():
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass) unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -64,7 +64,7 @@ def test_prim():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def softmax_relu_pass(): def softmax_relu_pass():
x = Any() x = Any()
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()]) sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
@ -73,7 +73,7 @@ def test_prim():
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass) unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -87,7 +87,7 @@ def test_softmax_relu_sigmoid():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def softmax_relu_pass(): def softmax_relu_pass():
x = Any() x = Any()
softmax_pattern = Prim(P.Softmax()) softmax_pattern = Prim(P.Softmax())
@ -99,7 +99,7 @@ def test_softmax_relu_sigmoid():
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass) unregister_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr assert "ReLU" in transformed_repr
assert "Sigmoid" in transformed_repr assert "Sigmoid" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -112,7 +112,7 @@ def test_isin_pattern_0():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def softmax_relu_pass(): def softmax_relu_pass():
x = Any() x = Any()
softmax_pattern = Prim(P.Softmax()) softmax_pattern = Prim(P.Softmax())
@ -125,7 +125,7 @@ def test_isin_pattern_0():
target = Call(relu6_pattern, [x]) target = Call(relu6_pattern, [x])
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass) unregister_pass(softmax_relu_pass)
assert "ReLU6" in transformed_repr assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -136,7 +136,7 @@ def test_isin_pattern_1():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def softmax_neg_pass(): def softmax_neg_pass():
x = Any() x = Any()
softmax_pattern = Prim(P.Softmax()) softmax_pattern = Prim(P.Softmax())
@ -149,7 +149,7 @@ def test_isin_pattern_1():
target = Call(neg_ops, [pattern]) target = Call(neg_ops, [pattern])
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
unregiste_pass(softmax_neg_pass) unregister_pass(softmax_neg_pass)
assert "Neg" in transformed_repr assert "Neg" in transformed_repr
assert "Softmax" in transformed_repr assert "Softmax" in transformed_repr
@ -177,7 +177,7 @@ def test_isnot_pattern_0():
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
conv_bn_model = ConvBN() conv_bn_model = ConvBN()
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def single_bn_pass(): def single_bn_pass():
""" """
Sub a BN which does NOT take Conv as inputs to ReLU6. Sub a BN which does NOT take Conv as inputs to ReLU6.
@ -189,7 +189,7 @@ def test_isnot_pattern_0():
target = Call(P.ReLU6(), [pattern_0]) target = Call(P.ReLU6(), [pattern_0])
return pattern, target return pattern, target
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def bn_pass(): def bn_pass():
""" """
Sub a BN to Softmax. Sub a BN to Softmax.
@ -199,8 +199,8 @@ def test_isnot_pattern_0():
return pattern, target return pattern, target
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
unregiste_pass(single_bn_pass) unregister_pass(single_bn_pass)
unregiste_pass(bn_pass) unregister_pass(bn_pass)
assert "ReLU6" not in transformed_repr assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr assert "Softmax" in transformed_repr
set_renorm(True) set_renorm(True)
@ -213,7 +213,7 @@ def test_isnot_pattern_1():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(run_only_once=True) @register_pass(run_only_once=True)
def single_bn_pass(): def single_bn_pass():
""" """
Sub a BN which does NOT take MatMul as inputs to ReLU6. Sub a BN which does NOT take MatMul as inputs to ReLU6.
@ -227,7 +227,7 @@ def test_isnot_pattern_1():
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(single_bn_pass) unregister_pass(single_bn_pass)
assert "ReLU6" in transformed_repr assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -240,7 +240,7 @@ def test_newtensor_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass(): def softmax_addn_pass():
x = Any() x = Any()
pattern = Call(P.Softmax(), [x]) pattern = Call(P.Softmax(), [x])
@ -250,7 +250,7 @@ def test_newtensor_pattern():
target = Call(P.AddN(), [x, new_weight]) target = Call(P.AddN(), [x, new_weight])
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_addn_pass) unregister_pass(softmax_addn_pass)
assert "AddN" in transformed_repr assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
set_renorm(True) set_renorm(True)
@ -264,7 +264,7 @@ def test_newparameter_pattern():
set_renorm(False) set_renorm(False)
set_reopt(False) set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass(): def softmax_addn_pass():
x = Any() x = Any()
pattern = Call(P.Softmax(), [x]) pattern = Call(P.Softmax(), [x])
@ -277,7 +277,7 @@ def test_newparameter_pattern():
target = Call("MakeTuple", [target_0]) target = Call("MakeTuple", [target_0])
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_addn_pass) unregister_pass(softmax_addn_pass)
assert "MatMul" in transformed_repr assert "MatMul" in transformed_repr
assert "MakeTuple" in transformed_repr assert "MakeTuple" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
@ -291,7 +291,7 @@ def test_imm_target():
set_renorm(False) set_renorm(False)
set_reopt(False) set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def softmax_pass(): def softmax_pass():
x = Any() x = Any()
pattern = Call(P.Softmax(), [x]) pattern = Call(P.Softmax(), [x])
@ -300,7 +300,7 @@ def test_imm_target():
target = Call(Constants.kTupleGetItem, [target_0, imm]) target = Call(Constants.kTupleGetItem, [target_0, imm])
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_pass) unregister_pass(softmax_pass)
assert "MakeTuple" in transformed_repr assert "MakeTuple" in transformed_repr
assert Constants.kTupleGetItem in transformed_repr assert Constants.kTupleGetItem in transformed_repr
assert "Softmax" in transformed_repr assert "Softmax" in transformed_repr
@ -317,7 +317,7 @@ def test_gen_new_parameter():
set_renorm(False) set_renorm(False)
set_reopt(False) set_reopt(False)
gen_new_parameter(new_para) gen_new_parameter(new_para)
@registe_pass(requires_grad=False, run_only_once=True) @register_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass(): def softmax_make_tuple_pass():
x = Any() x = Any()
softmax = P.Softmax() softmax = P.Softmax()
@ -327,7 +327,7 @@ def test_gen_new_parameter():
return pattern, target return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
assert "Merlin" in transformed_repr assert "Merlin" in transformed_repr
unregiste_pass(softmax_make_tuple_pass) unregister_pass(softmax_make_tuple_pass)
cancel_new_parameter(new_para) cancel_new_parameter(new_para)
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
assert "Merlin" not in transformed_repr assert "Merlin" not in transformed_repr