!6002 Reduce inline passes traversal by skipping inline action if no pre-ad python pass exists
Merge pull request !6002 from BowenK/pre_ad
commit 37e3b6082f
@@ -49,6 +49,7 @@ class PassGroup {
   bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
   std::string name() const { return name_; }
   void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
+  size_t size() { return passes_.size(); }
 
  private:
   const std::string name_;
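(Note: the new size() accessor lets the compile pipeline ask whether a pass group is empty; the OptInlineAction change in the next hunk uses it to detect whether any pre-AD python pass is registered.)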
@@ -301,7 +301,12 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
   return true;
 }
 
-bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
+bool OptInlineAction(const ResourcePtr &res) {
+  if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
+    return OptimizeAction(res, kInlinePasses);
+  }
+  return true;
+}
 
 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
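For context, a minimal Python-side sketch of what makes the PREAD group non-empty. The decorator and pattern API follow the test file further down in this diff; the pass body itself is illustrative, and the assumption that registered passes land in the pre-AD group reflects this commit's intent rather than documented behaviour:

    from mindspore.ops import operations as P
    from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass
    from mindspore.graph_utils.graph_pattern import Any, Call

    # Once a python pass is registered, GetPassGroup(Phase::PREAD)->size() != 0
    # holds on the C++ side, so OptInlineAction runs kInlinePasses as before.
    # With no registered passes, the inline action is now skipped entirely.
    @registe_pass(requires_grad=False, run_only_once=True)  # illustrative pass
    def softmax_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])
        target = Call(P.ReLU(), [x])
        return pattern, target

    unregiste_pass(softmax_relu_pass)  # deregistering empties the group again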
@@ -13,14 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """Reference for python pass registration."""
-from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
-    _set_reopt
+from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
+    set_reopt
 
 __all__ = [
     "registe_pass",
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "_set_renorm",
-    "_set_reopt"
+    "set_renorm",
+    "set_reopt"
 ]
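The leading underscore is dropped, making both switches part of the public API. A short usage sketch with the renamed functions, mirroring how the tests below drive them (the comments are assumptions about intent, not documented behaviour):

    from mindspore.graph_utils.python_pass import set_renorm, set_reopt

    # Disable renormalization and re-optimization after python-pass graph
    # rewrites, e.g. to inspect the transformed IR of a pattern in isolation.
    set_renorm(False)
    set_reopt(False)
    # ... register passes, compile, check the transformed graph ...
    set_renorm(True)  # the tests below restore renorm when they finish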
@@ -23,8 +23,8 @@ __all__ = [
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "_set_renorm",
-    "_set_reopt"
+    "set_renorm",
+    "set_reopt"
 ]
 class PyPassManager(PyPassManager_):
     r"""
@@ -162,7 +162,7 @@ def cancel_new_parameter(pattern):
     ppm = PyPassManager()
     ppm.unregiste(pattern.para_name)
 
-def _set_renorm(should_renorm):
+def set_renorm(should_renorm):
     """
     Set whether or not to do renormalization after modified graph in python pass(es).
 
@@ -176,7 +176,7 @@ def _set_renorm(should_renorm):
     ppm = PyPassManager()
     ppm.set_renorm(should_renorm)
 
-def _set_reopt(do_reopt):
+def set_reopt(do_reopt):
     """
     Set whether or not to do optimization after modified graph in python pass(es).
 
@@ -19,8 +19,8 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
-from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
-    cancel_new_parameter, _set_reopt
+from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
+    cancel_new_parameter, set_reopt
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
 from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
     Test IsNot pattern which expresses the IsNot semantics.
     Case: IsNot pass failed to match
     """
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     class ConvBN(nn.Cell):
         def __init__(self):
             super(ConvBN, self).__init__()
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
     unregiste_pass(bn_pass)
     assert "ReLU6" not in transformed_repr
     assert "Softmax" in transformed_repr
-    _set_renorm(True)
+    set_renorm(True)
 
 def test_isnot_pattern_1():
     """
@@ -234,8 +234,8 @@ def test_newtensor_pattern():
     """
     Test NewTensor pattern in the target
     """
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
     unregiste_pass(softmax_addn_pass)
     assert "AddN" in transformed_repr
     assert "Softmax" not in transformed_repr
-    _set_renorm(True)
+    set_renorm(True)
 
 def test_newparameter_pattern():
     """
@@ -261,8 +261,8 @@ def test_newparameter_pattern():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_addn_pass():
         x = Any()
@@ -288,8 +288,8 @@ def test_imm_target():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_pass():
         x = Any()
@@ -313,8 +313,8 @@ def test_gen_new_parameter():
 
     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
     new_para = NewParameter("Merlin", default_tensor)
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     gen_new_parameter(new_para)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_make_tuple_pass():
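For reference, a hedged end-to-end sketch of the NewParameter flow this last hunk exercises. Names and values come from the test above; the elided pass body is illustrative:

    import numpy as np
    import mindspore
    from mindspore.common.tensor import Tensor
    from mindspore.graph_utils.python_pass import (gen_new_parameter, cancel_new_parameter,
                                                   set_renorm, set_reopt)
    from mindspore.graph_utils.graph_pattern import NewParameter

    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor)
    set_renorm(False)
    set_reopt(False)
    gen_new_parameter(new_para)  # expose the parameter to subsequent passes
    # ... register a pass whose target consumes new_para, then compile ...
    cancel_new_parameter(new_para)  # clean up so later compilations are unaffected
    set_renorm(True)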