forked from OSSInnovation/mindspore
!5586 Boost python pass compile and train performance
Merge pull request !5586 from BowenK/performance
This commit is contained in:
commit
1a4d3e351e
|
@ -59,6 +59,7 @@ class Pattern : public Base {
|
|||
string unique_name() const { return unique_name_; }
|
||||
vector<PatternPtr> inputs() { return inputs_; }
|
||||
virtual void reset() {}
|
||||
static void reset_gid() { g_id_ = 0; }
|
||||
|
||||
protected:
|
||||
static int g_id_;
|
||||
|
@ -213,7 +214,6 @@ class NewParameter : public Pattern {
|
|||
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
|
||||
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
|
||||
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
|
||||
// clone input tensor
|
||||
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
|
||||
built_ = false;
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ class MatchResult {
|
|||
MatchResult() {}
|
||||
~MatchResult() = default;
|
||||
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
|
||||
PatternNodeMap _result() { return match_result_; }
|
||||
PatternNodeMap &_result() { return match_result_; }
|
||||
AnfNodePtr get_node(const PatternPtr &pattern);
|
||||
void merge(const MatchResultPtr &other_result);
|
||||
void clear() { match_result_.clear(); }
|
||||
|
|
|
@ -27,8 +27,6 @@
|
|||
#include "pipeline/jit/resource.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "utils/info.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/draw.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
|
|||
void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input,
|
||||
bool requires_grad, bool layerwise_parallel);
|
||||
|
||||
std::string GetNodeRepr(AnfNodePtr node) {
|
||||
if (node != nullptr) {
|
||||
if (node->isa<CNode>()) {
|
||||
std::string repr = "(";
|
||||
auto const &inputs = node->cast<CNodePtr>()->inputs();
|
||||
for (auto &input : inputs) {
|
||||
repr += " ";
|
||||
repr += GetNodeRepr(input);
|
||||
repr += " ";
|
||||
}
|
||||
repr += ")";
|
||||
return repr;
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
return "[Parameter]" + node->ToString();
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
return "[Value]" + GetValueNode(node)->ToString();
|
||||
}
|
||||
return node->ToString();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool IsTraversable(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
|
@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
|
|||
return new_node;
|
||||
}
|
||||
|
||||
void DrawNode(string name, AnfNodePtr node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
auto new_func_graph = std::make_shared<FuncGraph>();
|
||||
new_func_graph->set_output(node, true);
|
||||
if (save_graphs) {
|
||||
auto ir_dump_path = save_graphs_path + "/" + name + ".ir";
|
||||
auto dot_dump_path = save_graphs_path + "/" + name + ".dot";
|
||||
DumpIR(ir_dump_path, new_func_graph);
|
||||
draw::Draw(dot_dump_path, new_func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input,
|
||||
bool requires_grad, bool layerwise_parallel) {
|
||||
// 1. Get current cell object
|
||||
|
@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor
|
|||
if (py::isinstance<py::none>(top_cell)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
|
||||
}
|
||||
// 2. New a Parameter object with the above-specified args
|
||||
// 2. Clone default_input tensor
|
||||
auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(),
|
||||
default_input->data_c(), (size_t)default_input->Size());
|
||||
// 3. New a Parameter object with the above-specified args
|
||||
py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
|
||||
py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel);
|
||||
// 3. Add the new python Parameter object to Cell's _params atttributes
|
||||
py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel);
|
||||
// 4. Add the new python Parameter object to Cell's _params atttributes
|
||||
top_cell.attr(SET_PARAM)(param_name, new_parameter);
|
||||
// 4. Set default_param for param_node
|
||||
// 5. Set default_param for param_node
|
||||
ValuePtr param_value = nullptr;
|
||||
bool converted = parse::ConvertData(new_parameter, ¶m_value, false);
|
||||
if (!converted) {
|
||||
|
@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) {
|
|||
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
|
||||
auto match_res = src_pattern_->match(node);
|
||||
if (match_res != nullptr) {
|
||||
MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node);
|
||||
res->merge(match_res);
|
||||
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
|
||||
internal::Reset(dst_pattern());
|
||||
MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
|
||||
return new_node;
|
||||
}
|
||||
internal::Reset(src_pattern());
|
||||
|
@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
|
|||
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
|
||||
}
|
||||
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
|
||||
MS_LOG(DEBUG) << "Adding New parameter : " + para_name;
|
||||
auto para_node = std::make_shared<Parameter>(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(para_node);
|
||||
para_node->set_name(para_name);
|
||||
|
@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
|
|||
// Reflect back to Cell._params
|
||||
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
|
||||
new_para_pattern->layerwise_parallel());
|
||||
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
|
||||
MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
|
||||
return true;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
|
|||
for (auto &node : graph_nodes_sorted) {
|
||||
AnfNodePtr new_node = Run(func_graph, node, res);
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
internal::DrawNode(dst_pattern_->unique_name(), new_node);
|
||||
(void)manager->Replace(node, new_node);
|
||||
changes = true;
|
||||
}
|
||||
|
|
|
@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE(
|
|||
.def("registe", &PyPassManager::Registe, "Registe python pass")
|
||||
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
|
||||
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
|
||||
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
|
||||
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
|
||||
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
|
||||
}));
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
|
|
|
@ -60,13 +60,19 @@ class PyPassManager {
|
|||
MatchResultPtr GetMatchResult() { return res_; }
|
||||
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
|
||||
bool ShouldRenorm() { return should_renorm_; }
|
||||
void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; }
|
||||
bool ShouldReOpt() { return should_reopt_; }
|
||||
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
|
||||
pipeline::ResourcePtr GetResource() { return resource_; }
|
||||
void ClearRes();
|
||||
void ClearPipelineRes() { resource_ = nullptr; }
|
||||
void ClearPipelineRes() {
|
||||
resource_ = nullptr;
|
||||
Pattern::reset_gid();
|
||||
}
|
||||
|
||||
private:
|
||||
bool should_renorm_ = true;
|
||||
bool should_reopt_ = true;
|
||||
MatchResultPtr res_;
|
||||
pipeline::ResourcePtr resource_;
|
||||
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
|
||||
|
|
|
@ -451,35 +451,55 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
|
|||
|
||||
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
|
||||
|
||||
void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
|
||||
bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
|
||||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
auto ppm = opt::python_pass::PyPassManager::GetInstance();
|
||||
ppm->SetResource(res);
|
||||
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
|
||||
MS_LOG(DEBUG) << "No match.\n";
|
||||
} else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
MS_LOG(DEBUG) << "Entered PyStub Renorm";
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
res->set_args_spec(args_spec);
|
||||
}
|
||||
return ppm->GetPassGroup(phase)->Run(res->func_graph());
|
||||
}
|
||||
|
||||
bool ResolveActionPyStub(const ResourcePtr &res) {
|
||||
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
|
||||
bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
|
||||
|
||||
bool OptActionVmPyStub(const ResourcePtr &res) {
|
||||
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
res->set_args_spec(args_spec);
|
||||
}
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
|
||||
return VmOptimizeAction(res);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool OptActionPyStub(const ResourcePtr &res) {
|
||||
ActionPyStub(res, opt::python_pass::Phase::OPT);
|
||||
bool OptActionGePyStub(const ResourcePtr &res) {
|
||||
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
res->set_args_spec(args_spec);
|
||||
}
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
|
||||
return GeOptimizeAction(res);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -510,7 +530,7 @@ std::vector<ActionItem> GePipeline() {
|
|||
// optimize
|
||||
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
|
||||
// Add opt-stage python pass stub
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
|
||||
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
return actions;
|
||||
|
@ -523,7 +543,7 @@ std::vector<ActionItem> VmPipeline() {
|
|||
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||
|
||||
// Add opt-stage python pass stub
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
|
||||
actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
|
||||
|
||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
|
|
|
@ -12,13 +12,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Top-level reference to python pass."""
|
||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
|
||||
"""Reference for python pass registration."""
|
||||
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
|
||||
set_reopt
|
||||
|
||||
__all__ = [
|
||||
"registe_pass",
|
||||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
"set_renorm",
|
||||
"set_reopt"
|
||||
]
|
||||
|
|
|
@ -23,7 +23,8 @@ __all__ = [
|
|||
"unregiste_pass",
|
||||
"gen_new_parameter",
|
||||
"cancel_new_parameter",
|
||||
"set_renorm"
|
||||
"set_renorm",
|
||||
"set_reopt"
|
||||
]
|
||||
class PyPassManager(PyPassManager_):
|
||||
r"""
|
||||
|
@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
||||
super().set_renorm(should_renorm)
|
||||
|
||||
def set_reopt(self, do_reopt):
|
||||
if not isinstance(do_reopt, bool):
|
||||
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
|
||||
super().set_reopt(do_reopt)
|
||||
|
||||
def registe_pass(run_only_once=False):
|
||||
"""
|
||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||
|
@ -164,3 +170,17 @@ def set_renorm(should_renorm):
|
|||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.set_renorm(should_renorm)
|
||||
|
||||
def set_reopt(do_reopt):
|
||||
"""
|
||||
Set whether or not to do optimization after modified graph in python pass(es).
|
||||
|
||||
Args:
|
||||
do_reopt(bool): whether or not to do optimization after modified graph in python pass(es).
|
||||
|
||||
NOTE:
|
||||
This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
|
||||
renormalization may BREAK the network.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.set_reopt(do_reopt)
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore import context
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
|
||||
cancel_new_parameter
|
||||
cancel_new_parameter, set_reopt
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
|
||||
|
@ -50,8 +50,8 @@ def test_softmax_relu():
|
|||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
target = Call(P.ReLU(), inputs=[x])
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
target = Call(P.ReLU(), [x])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
|
@ -59,6 +59,23 @@ def test_softmax_relu():
|
|||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_prim():
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
|
||||
pattern = Call(sigmoid_softmax_pattern, [x])
|
||||
target = Call(P.ReLU(), [x])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_softmax_relu_sigmoid():
|
||||
"""
|
||||
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
|
||||
|
@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid():
|
|||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
pattern = Call(softmax_pattern, inputs=[x])
|
||||
pattern = Call(softmax_pattern, [x])
|
||||
sigmoid_pattern = Prim(P.Sigmoid())
|
||||
call_sigmoid = Call(sigmoid_pattern, [x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
target = Call(relu_pattern, inputs=[call_sigmoid])
|
||||
target = Call(relu_pattern, [call_sigmoid])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
|
@ -98,13 +115,13 @@ def test_isin_pattern_0():
|
|||
def softmax_relu_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||
call_softmax = Call(softmax_pattern, [x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
call_relu = Call(relu_pattern, inputs=[x])
|
||||
call_relu = Call(relu_pattern, [x])
|
||||
|
||||
pattern = OneOf([call_softmax, call_relu])
|
||||
relu6_pattern = Prim(P.ReLU6())
|
||||
target = Call(relu6_pattern, inputs=[x])
|
||||
target = Call(relu6_pattern, [x])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
|
@ -122,13 +139,13 @@ def test_isin_pattern_1():
|
|||
def softmax_neg_pass():
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||
call_softmax = Call(softmax_pattern, [x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
call_relu = Call(relu_pattern, inputs=[x])
|
||||
call_relu = Call(relu_pattern, [x])
|
||||
|
||||
pattern = OneOf([call_softmax, call_relu])
|
||||
neg_ops = Prim(P.Neg())
|
||||
target = Call(neg_ops, inputs=[pattern])
|
||||
target = Call(neg_ops, [pattern])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||
unregiste_pass(softmax_neg_pass)
|
||||
|
@ -141,6 +158,7 @@ def test_isnot_pattern_0():
|
|||
Case: IsNot pass failed to match
|
||||
"""
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
class ConvBN(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ConvBN, self).__init__()
|
||||
|
@ -166,8 +184,8 @@ def test_isnot_pattern_0():
|
|||
conv2d_prim = Prim("Conv2D")
|
||||
conv2d = Call(conv2d_prim)
|
||||
pattern_0 = NoneOf(conv2d)
|
||||
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
|
||||
target = Call(P.ReLU6(), inputs=[pattern_0])
|
||||
pattern = Call(P.BatchNorm(), [pattern_0])
|
||||
target = Call(P.ReLU6(), [pattern_0])
|
||||
return pattern, target
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
|
@ -202,9 +220,9 @@ def test_isnot_pattern_1():
|
|||
matmul = Prim("MatMul")
|
||||
pattern_0 = NoneOf(matmul)
|
||||
softmax = P.Softmax()
|
||||
pattern = Call(softmax, inputs=[pattern_0])
|
||||
pattern = Call(softmax, [pattern_0])
|
||||
relu6 = P.ReLU6()
|
||||
target = Call(relu6, inputs=[pattern_0])
|
||||
target = Call(relu6, [pattern_0])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
|
@ -217,17 +235,18 @@ def test_newtensor_pattern():
|
|||
Test NewTensor pattern in the target
|
||||
"""
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
||||
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
||||
new_weight = NewTensor(weight_tensor)
|
||||
target = Call(P.AddN(), inputs=[x, new_weight])
|
||||
target = Call(P.AddN(), [x, new_weight])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
|
@ -242,17 +261,19 @@ def test_newparameter_pattern():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
|
||||
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para_0 = NewParameter("Merlin", default_tensor0)
|
||||
new_para_1 = NewParameter("Arthur", default_tensor1)
|
||||
target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
|
||||
target = Call("make_tuple", inputs=[target_0])
|
||||
target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
|
||||
target = Call("make_tuple", [target_0])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
|
@ -267,13 +288,15 @@ def test_imm_target():
|
|||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
pattern = Call(P.Softmax(), [x])
|
||||
imm = Imm(0)
|
||||
target_0 = Call("make_tuple", inputs=[pattern])
|
||||
target = Call("tuple_getitem", inputs=[target_0, imm])
|
||||
target_0 = Call("make_tuple", [pattern])
|
||||
target = Call("tuple_getitem", [target_0, imm])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
unregiste_pass(softmax_pass)
|
||||
|
@ -290,14 +313,16 @@ def test_gen_new_parameter():
|
|||
|
||||
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para = NewParameter("Merlin", default_tensor)
|
||||
set_renorm(False)
|
||||
set_reopt(False)
|
||||
gen_new_parameter(new_para)
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_make_tuple_pass():
|
||||
x = Any()
|
||||
softmax = P.Softmax()
|
||||
pattern = Call(softmax, inputs=[x])
|
||||
pattern = Call(softmax, [x])
|
||||
|
||||
target = Call("make_tuple", inputs=[pattern, new_para])
|
||||
target = Call("make_tuple", [pattern, new_para])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
assert "Merlin" in transformed_repr
|
||||
|
|
Loading…
Reference in New Issue