forked from mindspore-Ecosystem/mindspore
!16466 Fix code check of python pass
From: @irmo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
9d423638b3
|
@ -48,8 +48,8 @@ PyPassManager::PyPassManager() {
|
|||
res_ = std::make_shared<MatchResult>();
|
||||
}
|
||||
|
||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
bool requires_grad, bool run_only_once) {
|
||||
void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
bool requires_grad, bool run_only_once) {
|
||||
PassGroupPtr cur_pg;
|
||||
if (requires_grad) {
|
||||
cur_pg = GetPassGroup(Phase::PREAD);
|
||||
|
@ -65,7 +65,7 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
|
|||
cur_pg->AddPass(new_pass);
|
||||
}
|
||||
|
||||
void PyPassManager::Unregiste(const std::string &pass_name) {
|
||||
void PyPassManager::Unregister(const std::string &pass_name) {
|
||||
auto opt_pm = GetPassGroup(Phase::OPT);
|
||||
if (!opt_pm->DeletePass(pass_name)) {
|
||||
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
|
||||
|
@ -101,8 +101,8 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
|
||||
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
|
||||
.def(py::init([]() { return PyPassManager::GetInstance(); }))
|
||||
.def("registe", &PyPassManager::Registe, "Registe python pass")
|
||||
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
|
||||
.def("register", &PyPassManager::Register, "Register python pass")
|
||||
.def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
|
||||
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
|
||||
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
|
||||
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
|
||||
|
|
|
@ -52,9 +52,9 @@ class PyPassManager {
|
|||
// Access the only global instance
|
||||
static PyPassManagerPtr GetInstance();
|
||||
virtual ~PyPassManager() = default;
|
||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
|
||||
bool run_only_once);
|
||||
void Unregiste(const std::string &pass_name);
|
||||
void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
|
||||
bool run_only_once);
|
||||
void Unregister(const std::string &pass_name);
|
||||
void GenNewParameter(const PatternPtr ¶meter);
|
||||
PassGroupPtr GetPassGroup(Phase phase);
|
||||
MatchResultPtr GetMatchResult() { return res_; }
|
||||
|
|
|
@ -28,6 +28,7 @@ __all__ = [
|
|||
"Imm"
|
||||
]
|
||||
|
||||
|
||||
class OneOf(OneOf_):
|
||||
r"""
|
||||
Express a pattern which allows a list of patterns.
|
||||
|
@ -51,6 +52,7 @@ class OneOf(OneOf_):
|
|||
else:
|
||||
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
|
||||
class Prim(Prim_):
|
||||
r"""
|
||||
Express a pattern of certain primitive type(s).
|
||||
|
@ -95,6 +97,7 @@ class Prim(Prim_):
|
|||
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
||||
Prim_.__init__(self, self.types, self.name)
|
||||
|
||||
|
||||
class Call(Call_):
|
||||
r"""
|
||||
Express a primitive CNode.
|
||||
|
@ -124,6 +127,7 @@ class Call(Call_):
|
|||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||
Call_.__init__(self, self.prim_pattern, self.inputs)
|
||||
|
||||
|
||||
class NoneOf(NoneOf_):
|
||||
r"""
|
||||
Express a pattern which forbids a list of patterns.
|
||||
|
@ -134,7 +138,7 @@ class NoneOf(NoneOf_):
|
|||
def __init__(self, patterns=None):
|
||||
r"""
|
||||
Args:
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
|
||||
should be one of the exposed Pattern instance.
|
||||
|
||||
Raises:
|
||||
|
@ -150,6 +154,7 @@ class NoneOf(NoneOf_):
|
|||
else:
|
||||
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
|
||||
class NewTensor(NewTensor_):
|
||||
r"""
|
||||
New Tensor to be used in the target.
|
||||
|
@ -157,7 +162,7 @@ class NewTensor(NewTensor_):
|
|||
def __init__(self, input_tensor):
|
||||
r"""
|
||||
Args:
|
||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
@ -168,6 +173,7 @@ class NewTensor(NewTensor_):
|
|||
else:
|
||||
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
||||
|
||||
|
||||
class NewParameter(NewParameter_):
|
||||
r"""
|
||||
New Parameter to be used in the target.
|
||||
|
@ -175,10 +181,10 @@ class NewParameter(NewParameter_):
|
|||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
|
||||
r"""
|
||||
Args:
|
||||
para_name(str): name for the new Parameter
|
||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
||||
requires_grad(bool): True if the parameter requires gradient. Default: True
|
||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
||||
para_name(str): name for the new Parameter.
|
||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
|
||||
requires_grad(bool): True if the parameter requires gradient. Default: True.
|
||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Reference for python pass registration."""
|
||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
|
||||
from .python_pass_register import register_pass, unregister_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
|
||||
set_reopt
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"register_pass",
|
||||
"unregister_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm",
|
||||
|
|
|
@ -17,18 +17,19 @@ from inspect import isfunction
|
|||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||
from mindspore._c_expression import PyPassManager_
|
||||
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"register_pass",
|
||||
"unregister_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm",
|
||||
"set_reopt"
|
||||
]
|
||||
|
||||
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
Used to register and unregister python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
|
||||
|
@ -46,7 +47,7 @@ class PyPassManager(PyPassManager_):
|
|||
self.run_only_once_ = run_only_once
|
||||
PyPassManager_.__init__(self)
|
||||
|
||||
def registe(self, py_pass):
|
||||
def register(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
|
@ -55,19 +56,19 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
|
||||
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
|
||||
|
||||
def unregiste(self, py_pass):
|
||||
def unregister(self, py_pass):
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass)
|
||||
super().unregister(py_pass)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__)
|
||||
super().unregister(py_pass.__name__)
|
||||
return
|
||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
def __call__(self, py_pass):
|
||||
self.registe(py_pass)
|
||||
self.register(py_pass)
|
||||
return py_pass
|
||||
|
||||
def gen_new_parameter(self, pattern):
|
||||
|
@ -85,36 +86,42 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
|
||||
super().set_reopt(do_reopt)
|
||||
|
||||
def registe_pass(requires_grad=True, run_only_once=False):
|
||||
|
||||
def register_pass(requires_grad=True, run_only_once=False):
|
||||
"""
|
||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||
Register python pass to specified pipeline phase which would be used in compilation.
|
||||
|
||||
Args:
|
||||
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
|
||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
||||
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True.
|
||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default:
|
||||
False.
|
||||
|
||||
Returns:
|
||||
This function should be used as a decorator, return the decoratorated pass function.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
|
||||
>>> @registe_pass()
|
||||
>>> from mindspore.graph_utils.graph_pattern import Call, Any
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> @register_pass()
|
||||
>>> def toy_pass():
|
||||
>>> pattern = IsPrimTypeOf("ReLU")
|
||||
>>> target = IsPrimTypeOf("ReLU6")
|
||||
>>> x = Any()
|
||||
>>> pattern = Call(P.Softmax(), [x])
|
||||
>>> target = Call(P.ReLU(), [x])
|
||||
>>> return pattern, target
|
||||
"""
|
||||
return PyPassManager(requires_grad, run_only_once)
|
||||
|
||||
def unregiste_pass(py_pass):
|
||||
|
||||
def unregister_pass(py_pass):
|
||||
"""
|
||||
Unregiste python pass.
|
||||
Unregister python pass.
|
||||
|
||||
Args:
|
||||
py_pass(Union(str, function)): target python pass to unregiste.
|
||||
py_pass(Union(str, function)): target python pass to unregister.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(py_pass)
|
||||
ppm.unregister(py_pass)
|
||||
|
||||
|
||||
def gen_new_parameter(pattern):
|
||||
"""
|
||||
|
@ -123,8 +130,8 @@ def gen_new_parameter(pattern):
|
|||
NOTE:
|
||||
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
|
||||
gen_new_parameter, every pass match would build a new Parameter.
|
||||
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
|
||||
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
|
||||
This would register a pass to add new parameter in the compilation pipeline, so later compilation would
|
||||
ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
|
||||
cancel_new_parameter(pattern)
|
||||
|
||||
Args:
|
||||
|
@ -142,9 +149,10 @@ def gen_new_parameter(pattern):
|
|||
ppm = PyPassManager()
|
||||
ppm.gen_new_parameter(pattern)
|
||||
|
||||
|
||||
def cancel_new_parameter(pattern):
|
||||
"""
|
||||
Use with gen_new_parameter to unregiste gen_new_parameter pass.
|
||||
Use with gen_new_parameter to unregister gen_new_parameter pass.
|
||||
|
||||
Args:
|
||||
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
|
||||
|
@ -160,7 +168,8 @@ def cancel_new_parameter(pattern):
|
|||
if not isinstance(pattern, NewParameter):
|
||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(pattern.para_name)
|
||||
ppm.unregister(pattern.para_name)
|
||||
|
||||
|
||||
def set_renorm(should_renorm):
|
||||
"""
|
||||
|
@ -176,6 +185,7 @@ def set_renorm(should_renorm):
|
|||
ppm = PyPassManager()
|
||||
ppm.set_renorm(should_renorm)
|
||||
|
||||
|
||||
def set_reopt(do_reopt):
|
||||
"""
|
||||
Set whether or not to do optimization after modified graph in python pass(es).
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore import context
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import _constants as Constants
|
||||
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
||||
from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
|
||||
cancel_new_parameter, set_reopt
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
|
@ -48,7 +48,7 @@ def test_softmax_relu():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
@ -56,7 +56,7 @@ def test_softmax_relu():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
unregister_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
|
@ -64,7 +64,7 @@ def test_prim():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
|
||||
|
@ -73,7 +73,7 @@ def test_prim():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
unregister_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
|
@ -87,7 +87,7 @@ def test_softmax_relu_sigmoid():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
|
@ -99,7 +99,7 @@ def test_softmax_relu_sigmoid():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
unregister_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Sigmoid" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
@ -112,7 +112,7 @@ def test_isin_pattern_0():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
|
@ -125,7 +125,7 @@ def test_isin_pattern_0():
|
|||
target = Call(relu6_pattern, [x])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
unregister_pass(softmax_relu_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
|
@ -136,7 +136,7 @@ def test_isin_pattern_1():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def softmax_neg_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
|
@ -149,7 +149,7 @@ def test_isin_pattern_1():
|
|||
target = Call(neg_ops, [pattern])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||
unregiste_pass(softmax_neg_pass)
|
||||
unregister_pass(softmax_neg_pass)
|
||||
assert "Neg" in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
||||
|
@ -177,7 +177,7 @@ def test_isnot_pattern_0():
|
|||
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
|
||||
conv_bn_model = ConvBN()
|
||||
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def single_bn_pass():
|
||||
"""
|
||||
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
||||
|
@ -189,7 +189,7 @@ def test_isnot_pattern_0():
|
|||
target = Call(P.ReLU6(), [pattern_0])
|
||||
return pattern, target
|
||||
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def bn_pass():
|
||||
"""
|
||||
Sub a BN to Softmax.
|
||||
|
@ -199,8 +199,8 @@ def test_isnot_pattern_0():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(single_bn_pass)
|
||||
unregiste_pass(bn_pass)
|
||||
unregister_pass(single_bn_pass)
|
||||
unregister_pass(bn_pass)
|
||||
assert "ReLU6" not in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
set_renorm(True)
|
||||
|
@ -213,7 +213,7 @@ def test_isnot_pattern_1():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
@register_pass(run_only_once=True)
|
||||
def single_bn_pass():
|
||||
"""
|
||||
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
||||
|
@ -227,7 +227,7 @@ def test_isnot_pattern_1():
|
|||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(single_bn_pass)
|
||||
unregister_pass(single_bn_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
|
@ -240,7 +240,7 @@ def test_newtensor_pattern():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
@ -250,7 +250,7 @@ def test_newtensor_pattern():
|
|||
target = Call(P.AddN(), [x, new_weight])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
unregister_pass(softmax_addn_pass)
|
||||
assert "AddN" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
set_renorm(True)
|
||||
|
@ -264,7 +264,7 @@ def test_newparameter_pattern():
|
|||
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
@ -277,7 +277,7 @@ def test_newparameter_pattern():
|
|||
target = Call("MakeTuple", [target_0])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
unregister_pass(softmax_addn_pass)
|
||||
assert "MatMul" in transformed_repr
|
||||
assert "MakeTuple" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
@ -291,7 +291,7 @@ def test_imm_target():
|
|||
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def softmax_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
@ -300,7 +300,7 @@ def test_imm_target():
|
|||
target = Call(Constants.kTupleGetItem, [target_0, imm])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(softmax_pass)
|
||||
unregister_pass(softmax_pass)
|
||||
assert "MakeTuple" in transformed_repr
|
||||
assert Constants.kTupleGetItem in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
@ -317,7 +317,7 @@ def test_gen_new_parameter():
|
|||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
gen_new_parameter(new_para)
|
||||
@registe_pass(requires_grad=False, run_only_once=True)
|
||||
@register_pass(requires_grad=False, run_only_once=True)
|
||||
def softmax_make_tuple_pass():
|
||||
x = Any()
|
||||
softmax = P.Softmax()
|
||||
|
@ -327,7 +327,7 @@ def test_gen_new_parameter():
|
|||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
assert "Merlin" in transformed_repr
|
||||
unregiste_pass(softmax_make_tuple_pass)
|
||||
unregister_pass(softmax_make_tuple_pass)
|
||||
cancel_new_parameter(new_para)
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
assert "Merlin" not in transformed_repr
|
||||
|
|
Loading…
Reference in New Issue