forked from mindspore-Ecosystem/mindspore
!16466 Fix code check of python pass
From: @irmo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
9d423638b3
|
@ -48,8 +48,8 @@ PyPassManager::PyPassManager() {
|
||||||
res_ = std::make_shared<MatchResult>();
|
res_ = std::make_shared<MatchResult>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
void PyPassManager::Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||||
bool requires_grad, bool run_only_once) {
|
bool requires_grad, bool run_only_once) {
|
||||||
PassGroupPtr cur_pg;
|
PassGroupPtr cur_pg;
|
||||||
if (requires_grad) {
|
if (requires_grad) {
|
||||||
cur_pg = GetPassGroup(Phase::PREAD);
|
cur_pg = GetPassGroup(Phase::PREAD);
|
||||||
|
@ -65,7 +65,7 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
|
||||||
cur_pg->AddPass(new_pass);
|
cur_pg->AddPass(new_pass);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyPassManager::Unregiste(const std::string &pass_name) {
|
void PyPassManager::Unregister(const std::string &pass_name) {
|
||||||
auto opt_pm = GetPassGroup(Phase::OPT);
|
auto opt_pm = GetPassGroup(Phase::OPT);
|
||||||
if (!opt_pm->DeletePass(pass_name)) {
|
if (!opt_pm->DeletePass(pass_name)) {
|
||||||
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
|
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
|
||||||
|
@ -101,8 +101,8 @@ REGISTER_PYBIND_DEFINE(
|
||||||
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
|
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
|
||||||
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
|
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
|
||||||
.def(py::init([]() { return PyPassManager::GetInstance(); }))
|
.def(py::init([]() { return PyPassManager::GetInstance(); }))
|
||||||
.def("registe", &PyPassManager::Registe, "Registe python pass")
|
.def("register", &PyPassManager::Register, "Register python pass")
|
||||||
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
|
.def("unregister", &PyPassManager::Unregister, "Unregister Python Pass")
|
||||||
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
|
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
|
||||||
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
|
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
|
||||||
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
|
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
|
||||||
|
|
|
@ -52,9 +52,9 @@ class PyPassManager {
|
||||||
// Access the only global instance
|
// Access the only global instance
|
||||||
static PyPassManagerPtr GetInstance();
|
static PyPassManagerPtr GetInstance();
|
||||||
virtual ~PyPassManager() = default;
|
virtual ~PyPassManager() = default;
|
||||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
|
void Register(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
|
||||||
bool run_only_once);
|
bool run_only_once);
|
||||||
void Unregiste(const std::string &pass_name);
|
void Unregister(const std::string &pass_name);
|
||||||
void GenNewParameter(const PatternPtr ¶meter);
|
void GenNewParameter(const PatternPtr ¶meter);
|
||||||
PassGroupPtr GetPassGroup(Phase phase);
|
PassGroupPtr GetPassGroup(Phase phase);
|
||||||
MatchResultPtr GetMatchResult() { return res_; }
|
MatchResultPtr GetMatchResult() { return res_; }
|
||||||
|
|
|
@ -28,6 +28,7 @@ __all__ = [
|
||||||
"Imm"
|
"Imm"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class OneOf(OneOf_):
|
class OneOf(OneOf_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern which allows a list of patterns.
|
Express a pattern which allows a list of patterns.
|
||||||
|
@ -51,6 +52,7 @@ class OneOf(OneOf_):
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
||||||
|
|
||||||
|
|
||||||
class Prim(Prim_):
|
class Prim(Prim_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern of certain primitive type(s).
|
Express a pattern of certain primitive type(s).
|
||||||
|
@ -95,6 +97,7 @@ class Prim(Prim_):
|
||||||
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
||||||
Prim_.__init__(self, self.types, self.name)
|
Prim_.__init__(self, self.types, self.name)
|
||||||
|
|
||||||
|
|
||||||
class Call(Call_):
|
class Call(Call_):
|
||||||
r"""
|
r"""
|
||||||
Express a primitive CNode.
|
Express a primitive CNode.
|
||||||
|
@ -124,6 +127,7 @@ class Call(Call_):
|
||||||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||||
Call_.__init__(self, self.prim_pattern, self.inputs)
|
Call_.__init__(self, self.prim_pattern, self.inputs)
|
||||||
|
|
||||||
|
|
||||||
class NoneOf(NoneOf_):
|
class NoneOf(NoneOf_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern which forbids a list of patterns.
|
Express a pattern which forbids a list of patterns.
|
||||||
|
@ -134,7 +138,7 @@ class NoneOf(NoneOf_):
|
||||||
def __init__(self, patterns=None):
|
def __init__(self, patterns=None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
|
||||||
should be one of the exposed Pattern instance.
|
should be one of the exposed Pattern instance.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -150,6 +154,7 @@ class NoneOf(NoneOf_):
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
||||||
|
|
||||||
|
|
||||||
class NewTensor(NewTensor_):
|
class NewTensor(NewTensor_):
|
||||||
r"""
|
r"""
|
||||||
New Tensor to be used in the target.
|
New Tensor to be used in the target.
|
||||||
|
@ -157,7 +162,7 @@ class NewTensor(NewTensor_):
|
||||||
def __init__(self, input_tensor):
|
def __init__(self, input_tensor):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
|
@ -168,6 +173,7 @@ class NewTensor(NewTensor_):
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
||||||
|
|
||||||
|
|
||||||
class NewParameter(NewParameter_):
|
class NewParameter(NewParameter_):
|
||||||
r"""
|
r"""
|
||||||
New Parameter to be used in the target.
|
New Parameter to be used in the target.
|
||||||
|
@ -175,10 +181,10 @@ class NewParameter(NewParameter_):
|
||||||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
|
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
para_name(str): name for the new Parameter
|
para_name(str): name for the new Parameter.
|
||||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter.
|
||||||
requires_grad(bool): True if the parameter requires gradient. Default: True
|
requires_grad(bool): True if the parameter requires gradient. Default: True.
|
||||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
|
|
|
@ -13,12 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Reference for python pass registration."""
|
"""Reference for python pass registration."""
|
||||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
|
from .python_pass_register import register_pass, unregister_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
|
||||||
set_reopt
|
set_reopt
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"registe_pass",
|
"register_pass",
|
||||||
"unregiste_pass",
|
"unregister_pass",
|
||||||
"gen_new_parameter",
|
"gen_new_parameter",
|
||||||
"cancel_new_parameter",
|
"cancel_new_parameter",
|
||||||
"set_renorm",
|
"set_renorm",
|
||||||
|
|
|
@ -17,18 +17,19 @@ from inspect import isfunction
|
||||||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||||
from mindspore._c_expression import PyPassManager_
|
from mindspore._c_expression import PyPassManager_
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"registe_pass",
|
"register_pass",
|
||||||
"unregiste_pass",
|
"unregister_pass",
|
||||||
"gen_new_parameter",
|
"gen_new_parameter",
|
||||||
"cancel_new_parameter",
|
"cancel_new_parameter",
|
||||||
"set_renorm",
|
"set_renorm",
|
||||||
"set_reopt"
|
"set_reopt"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PyPassManager(PyPassManager_):
|
class PyPassManager(PyPassManager_):
|
||||||
r"""
|
r"""
|
||||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
Used to register and unregister python passes which can be used to alter graphs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
|
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
|
||||||
|
@ -46,7 +47,7 @@ class PyPassManager(PyPassManager_):
|
||||||
self.run_only_once_ = run_only_once
|
self.run_only_once_ = run_only_once
|
||||||
PyPassManager_.__init__(self)
|
PyPassManager_.__init__(self)
|
||||||
|
|
||||||
def registe(self, py_pass):
|
def register(self, py_pass):
|
||||||
if not isfunction(py_pass):
|
if not isfunction(py_pass):
|
||||||
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
|
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
|
||||||
pattern, target = py_pass()
|
pattern, target = py_pass()
|
||||||
|
@ -55,19 +56,19 @@ class PyPassManager(PyPassManager_):
|
||||||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||||
if not isinstance(target, Pattern):
|
if not isinstance(target, Pattern):
|
||||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||||
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
|
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
|
||||||
|
|
||||||
def unregiste(self, py_pass):
|
def unregister(self, py_pass):
|
||||||
if isinstance(py_pass, str):
|
if isinstance(py_pass, str):
|
||||||
super().unregiste(py_pass)
|
super().unregister(py_pass)
|
||||||
return
|
return
|
||||||
if isfunction(py_pass):
|
if isfunction(py_pass):
|
||||||
super().unregiste(py_pass.__name__)
|
super().unregister(py_pass.__name__)
|
||||||
return
|
return
|
||||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||||
|
|
||||||
def __call__(self, py_pass):
|
def __call__(self, py_pass):
|
||||||
self.registe(py_pass)
|
self.register(py_pass)
|
||||||
return py_pass
|
return py_pass
|
||||||
|
|
||||||
def gen_new_parameter(self, pattern):
|
def gen_new_parameter(self, pattern):
|
||||||
|
@ -85,36 +86,42 @@ class PyPassManager(PyPassManager_):
|
||||||
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
|
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
|
||||||
super().set_reopt(do_reopt)
|
super().set_reopt(do_reopt)
|
||||||
|
|
||||||
def registe_pass(requires_grad=True, run_only_once=False):
|
|
||||||
|
def register_pass(requires_grad=True, run_only_once=False):
|
||||||
"""
|
"""
|
||||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
Register python pass to specified pipeline phase which would be used in compilation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
|
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True.
|
||||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default:
|
||||||
|
False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
This function should be used as a decorator, return the decoratorated pass function.
|
This function should be used as a decorator, return the decoratorated pass function.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
|
>>> from mindspore.graph_utils.graph_pattern import Call, Any
|
||||||
>>> @registe_pass()
|
>>> from mindspore.ops import operations as P
|
||||||
|
>>> @register_pass()
|
||||||
>>> def toy_pass():
|
>>> def toy_pass():
|
||||||
>>> pattern = IsPrimTypeOf("ReLU")
|
>>> x = Any()
|
||||||
>>> target = IsPrimTypeOf("ReLU6")
|
>>> pattern = Call(P.Softmax(), [x])
|
||||||
|
>>> target = Call(P.ReLU(), [x])
|
||||||
>>> return pattern, target
|
>>> return pattern, target
|
||||||
"""
|
"""
|
||||||
return PyPassManager(requires_grad, run_only_once)
|
return PyPassManager(requires_grad, run_only_once)
|
||||||
|
|
||||||
def unregiste_pass(py_pass):
|
|
||||||
|
def unregister_pass(py_pass):
|
||||||
"""
|
"""
|
||||||
Unregiste python pass.
|
Unregister python pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
py_pass(Union(str, function)): target python pass to unregiste.
|
py_pass(Union(str, function)): target python pass to unregister.
|
||||||
"""
|
"""
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.unregiste(py_pass)
|
ppm.unregister(py_pass)
|
||||||
|
|
||||||
|
|
||||||
def gen_new_parameter(pattern):
|
def gen_new_parameter(pattern):
|
||||||
"""
|
"""
|
||||||
|
@ -123,8 +130,8 @@ def gen_new_parameter(pattern):
|
||||||
NOTE:
|
NOTE:
|
||||||
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
|
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
|
||||||
gen_new_parameter, every pass match would build a new Parameter.
|
gen_new_parameter, every pass match would build a new Parameter.
|
||||||
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
|
This would register a pass to add new parameter in the compilation pipeline, so later compilation would
|
||||||
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
|
ALSO add this parameter unless the pass is unregistered. To unregister this pass, call
|
||||||
cancel_new_parameter(pattern)
|
cancel_new_parameter(pattern)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -142,9 +149,10 @@ def gen_new_parameter(pattern):
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.gen_new_parameter(pattern)
|
ppm.gen_new_parameter(pattern)
|
||||||
|
|
||||||
|
|
||||||
def cancel_new_parameter(pattern):
|
def cancel_new_parameter(pattern):
|
||||||
"""
|
"""
|
||||||
Use with gen_new_parameter to unregiste gen_new_parameter pass.
|
Use with gen_new_parameter to unregister gen_new_parameter pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
|
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
|
||||||
|
@ -160,7 +168,8 @@ def cancel_new_parameter(pattern):
|
||||||
if not isinstance(pattern, NewParameter):
|
if not isinstance(pattern, NewParameter):
|
||||||
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.unregiste(pattern.para_name)
|
ppm.unregister(pattern.para_name)
|
||||||
|
|
||||||
|
|
||||||
def set_renorm(should_renorm):
|
def set_renorm(should_renorm):
|
||||||
"""
|
"""
|
||||||
|
@ -176,6 +185,7 @@ def set_renorm(should_renorm):
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.set_renorm(should_renorm)
|
ppm.set_renorm(should_renorm)
|
||||||
|
|
||||||
|
|
||||||
def set_reopt(do_reopt):
|
def set_reopt(do_reopt):
|
||||||
"""
|
"""
|
||||||
Set whether or not to do optimization after modified graph in python pass(es).
|
Set whether or not to do optimization after modified graph in python pass(es).
|
||||||
|
|
|
@ -20,7 +20,7 @@ from mindspore import context
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import _constants as Constants
|
from mindspore.ops import _constants as Constants
|
||||||
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
from mindspore.graph_utils.python_pass import register_pass, unregister_pass, set_renorm, gen_new_parameter,\
|
||||||
cancel_new_parameter, set_reopt
|
cancel_new_parameter, set_reopt
|
||||||
from mindspore.common.api import _generate_pip_args
|
from mindspore.common.api import _generate_pip_args
|
||||||
from mindspore._c_expression import generate_key, Executor_
|
from mindspore._c_expression import generate_key, Executor_
|
||||||
|
@ -48,7 +48,7 @@ def test_softmax_relu():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
pattern = Call(P.Softmax(), [x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
|
@ -56,7 +56,7 @@ def test_softmax_relu():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_relu_pass)
|
unregister_pass(softmax_relu_pass)
|
||||||
assert "ReLU" in transformed_repr
|
assert "ReLU" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ def test_prim():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
|
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
|
||||||
|
@ -73,7 +73,7 @@ def test_prim():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||||
unregiste_pass(softmax_relu_pass)
|
unregister_pass(softmax_relu_pass)
|
||||||
assert "ReLU" in transformed_repr
|
assert "ReLU" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ def test_softmax_relu_sigmoid():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
softmax_pattern = Prim(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
|
@ -99,7 +99,7 @@ def test_softmax_relu_sigmoid():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||||
unregiste_pass(softmax_relu_pass)
|
unregister_pass(softmax_relu_pass)
|
||||||
assert "ReLU" in transformed_repr
|
assert "ReLU" in transformed_repr
|
||||||
assert "Sigmoid" in transformed_repr
|
assert "Sigmoid" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
@ -112,7 +112,7 @@ def test_isin_pattern_0():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
softmax_pattern = Prim(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
|
@ -125,7 +125,7 @@ def test_isin_pattern_0():
|
||||||
target = Call(relu6_pattern, [x])
|
target = Call(relu6_pattern, [x])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_relu_pass)
|
unregister_pass(softmax_relu_pass)
|
||||||
assert "ReLU6" in transformed_repr
|
assert "ReLU6" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ def test_isin_pattern_1():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def softmax_neg_pass():
|
def softmax_neg_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
softmax_pattern = Prim(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
|
@ -149,7 +149,7 @@ def test_isin_pattern_1():
|
||||||
target = Call(neg_ops, [pattern])
|
target = Call(neg_ops, [pattern])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||||
unregiste_pass(softmax_neg_pass)
|
unregister_pass(softmax_neg_pass)
|
||||||
assert "Neg" in transformed_repr
|
assert "Neg" in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ def test_isnot_pattern_0():
|
||||||
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
|
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
|
||||||
conv_bn_model = ConvBN()
|
conv_bn_model = ConvBN()
|
||||||
|
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def single_bn_pass():
|
def single_bn_pass():
|
||||||
"""
|
"""
|
||||||
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
||||||
|
@ -189,7 +189,7 @@ def test_isnot_pattern_0():
|
||||||
target = Call(P.ReLU6(), [pattern_0])
|
target = Call(P.ReLU6(), [pattern_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def bn_pass():
|
def bn_pass():
|
||||||
"""
|
"""
|
||||||
Sub a BN to Softmax.
|
Sub a BN to Softmax.
|
||||||
|
@ -199,8 +199,8 @@ def test_isnot_pattern_0():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||||
unregiste_pass(single_bn_pass)
|
unregister_pass(single_bn_pass)
|
||||||
unregiste_pass(bn_pass)
|
unregister_pass(bn_pass)
|
||||||
assert "ReLU6" not in transformed_repr
|
assert "ReLU6" not in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
set_renorm(True)
|
set_renorm(True)
|
||||||
|
@ -213,7 +213,7 @@ def test_isnot_pattern_1():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@register_pass(run_only_once=True)
|
||||||
def single_bn_pass():
|
def single_bn_pass():
|
||||||
"""
|
"""
|
||||||
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
||||||
|
@ -227,7 +227,7 @@ def test_isnot_pattern_1():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
unregiste_pass(single_bn_pass)
|
unregister_pass(single_bn_pass)
|
||||||
assert "ReLU6" in transformed_repr
|
assert "ReLU6" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
|
||||||
|
@ -240,7 +240,7 @@ def test_newtensor_pattern():
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_addn_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
pattern = Call(P.Softmax(), [x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
|
@ -250,7 +250,7 @@ def test_newtensor_pattern():
|
||||||
target = Call(P.AddN(), [x, new_weight])
|
target = Call(P.AddN(), [x, new_weight])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_addn_pass)
|
unregister_pass(softmax_addn_pass)
|
||||||
assert "AddN" in transformed_repr
|
assert "AddN" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
set_renorm(True)
|
set_renorm(True)
|
||||||
|
@ -264,7 +264,7 @@ def test_newparameter_pattern():
|
||||||
|
|
||||||
set_renorm(False)
|
set_renorm(False)
|
||||||
set_reopt(False)
|
set_reopt(False)
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_addn_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
pattern = Call(P.Softmax(), [x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
|
@ -277,7 +277,7 @@ def test_newparameter_pattern():
|
||||||
target = Call("MakeTuple", [target_0])
|
target = Call("MakeTuple", [target_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
unregiste_pass(softmax_addn_pass)
|
unregister_pass(softmax_addn_pass)
|
||||||
assert "MatMul" in transformed_repr
|
assert "MatMul" in transformed_repr
|
||||||
assert "MakeTuple" in transformed_repr
|
assert "MakeTuple" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
@ -291,7 +291,7 @@ def test_imm_target():
|
||||||
|
|
||||||
set_renorm(False)
|
set_renorm(False)
|
||||||
set_reopt(False)
|
set_reopt(False)
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def softmax_pass():
|
def softmax_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
pattern = Call(P.Softmax(), [x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
|
@ -300,7 +300,7 @@ def test_imm_target():
|
||||||
target = Call(Constants.kTupleGetItem, [target_0, imm])
|
target = Call(Constants.kTupleGetItem, [target_0, imm])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
unregiste_pass(softmax_pass)
|
unregister_pass(softmax_pass)
|
||||||
assert "MakeTuple" in transformed_repr
|
assert "MakeTuple" in transformed_repr
|
||||||
assert Constants.kTupleGetItem in transformed_repr
|
assert Constants.kTupleGetItem in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
|
@ -317,7 +317,7 @@ def test_gen_new_parameter():
|
||||||
set_renorm(False)
|
set_renorm(False)
|
||||||
set_reopt(False)
|
set_reopt(False)
|
||||||
gen_new_parameter(new_para)
|
gen_new_parameter(new_para)
|
||||||
@registe_pass(requires_grad=False, run_only_once=True)
|
@register_pass(requires_grad=False, run_only_once=True)
|
||||||
def softmax_make_tuple_pass():
|
def softmax_make_tuple_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
softmax = P.Softmax()
|
softmax = P.Softmax()
|
||||||
|
@ -327,7 +327,7 @@ def test_gen_new_parameter():
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
assert "Merlin" in transformed_repr
|
assert "Merlin" in transformed_repr
|
||||||
unregiste_pass(softmax_make_tuple_pass)
|
unregister_pass(softmax_make_tuple_pass)
|
||||||
cancel_new_parameter(new_para)
|
cancel_new_parameter(new_para)
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
assert "Merlin" not in transformed_repr
|
assert "Merlin" not in transformed_repr
|
||||||
|
|
Loading…
Reference in New Issue