diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.h b/mindspore/ccsrc/frontend/optimizer/pattern.h index c0bdc6cda2..473a4c0520 100644 --- a/mindspore/ccsrc/frontend/optimizer/pattern.h +++ b/mindspore/ccsrc/frontend/optimizer/pattern.h @@ -59,6 +59,7 @@ class Pattern : public Base { string unique_name() const { return unique_name_; } vector inputs() { return inputs_; } virtual void reset() {} + static void reset_gid() { g_id_ = 0; } protected: static int g_id_; @@ -213,7 +214,6 @@ class NewParameter : public Pattern { explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; - // clone input tensor default_tensor_ = std::make_shared(*default_tensor.get()); built_ = false; } @@ -257,7 +257,7 @@ class MatchResult { MatchResult() {} ~MatchResult() = default; void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } - PatternNodeMap _result() { return match_result_; } + PatternNodeMap &_result() { return match_result_; } AnfNodePtr get_node(const PatternPtr &pattern); void merge(const MatchResultPtr &other_result); void clear() { match_result_.clear(); } diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index 95a90b44d6..877da5f5a3 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -27,8 +27,6 @@ #include "pipeline/jit/resource.h" #include "frontend/optimizer/py_pass_manager.h" #include "utils/info.h" -#include "debug/anf_ir_dump.h" -#include "debug/draw.h" namespace mindspore { namespace opt { @@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, bool requires_grad, bool layerwise_parallel); -std::string GetNodeRepr(AnfNodePtr node) { - if (node != nullptr) { - if (node->isa()) { - std::string repr = "("; - auto const &inputs = node->cast()->inputs(); - for (auto &input : inputs) { - repr += " "; - repr += GetNodeRepr(input); - repr += " "; - } - repr += ")"; - return repr; - } - if (node->isa()) { - return "[Parameter]" + node->ToString(); - } else if (node->isa()) { - return "[Value]" + GetValueNode(node)->ToString(); - } - return node->ToString(); - } - return ""; -} - bool IsTraversable(const AnfNodePtr &node) { if (node == nullptr) { return false; @@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph return new_node; } -void DrawNode(string name, AnfNodePtr node) { - auto context_ptr = MsContext::GetInstance(); - bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); - auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - auto new_func_graph = std::make_shared(); - new_func_graph->set_output(node, true); - if (save_graphs) { - auto ir_dump_path = save_graphs_path + "/" + name + ".ir"; - auto dot_dump_path = save_graphs_path + "/" + name + ".dot"; - DumpIR(ir_dump_path, new_func_graph); - draw::Draw(dot_dump_path, new_func_graph); - } -} - void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, bool requires_grad, bool layerwise_parallel) { // 1. Get current cell object @@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor if (py::isinstance(top_cell)) { MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; } - // 2. New a Parameter object with the above-specified args + // 2. Clone default_input tensor + auto default_tensor = std::make_shared(default_input->data_type(), default_input->shape_c(), + default_input->data_c(), (size_t)default_input->Size()); + // 3. New a Parameter object with the above-specified args py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); - py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel); - // 3. Add the new python Parameter object to Cell's _params atttributes + py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel); + // 4. Add the new python Parameter object to Cell's _params atttributes top_cell.attr(SET_PARAM)(param_name, new_parameter); - // 4. Set default_param for param_node + // 5. Set default_param for param_node ValuePtr param_value = nullptr; bool converted = parse::ConvertData(new_parameter, ¶m_value, false); if (!converted) { @@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) { AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { auto match_res = src_pattern_->match(node); if (match_res != nullptr) { - MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node); res->merge(match_res); auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); internal::Reset(dst_pattern()); - MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; return new_node; } internal::Reset(src_pattern()); @@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; } auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); - MS_LOG(DEBUG) << "Adding New parameter : " + para_name; auto para_node = std::make_shared(func_graph); MS_EXCEPTION_IF_NULL(para_node); para_node->set_name(para_name); @@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) // Reflect back to Cell._params internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), new_para_pattern->layerwise_parallel()); - MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); + MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); return true; } FuncGraphManagerPtr manager = func_graph->manager(); @@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) for (auto &node : graph_nodes_sorted) { AnfNodePtr new_node = Run(func_graph, node, res); if (new_node != nullptr && new_node != node) { - internal::DrawNode(dst_pattern_->unique_name(), new_node); (void)manager->Replace(node, new_node); changes = true; } diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index 9c8ef59f02..4540d5bbca 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE( .def("registe", &PyPassManager::Registe, "Registe python pass") .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") - .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph"); + .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph") + .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph"); })); } // namespace python_pass } // namespace opt diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h index c5f35800cc..c892d46855 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -60,13 +60,19 @@ class PyPassManager { MatchResultPtr GetMatchResult() { return res_; } void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } bool ShouldRenorm() { return should_renorm_; } + void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; } + bool ShouldReOpt() { return should_reopt_; } void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } pipeline::ResourcePtr GetResource() { return resource_; } void ClearRes(); - void ClearPipelineRes() { resource_ = nullptr; } + void ClearPipelineRes() { + resource_ = nullptr; + Pattern::reset_gid(); + } private: bool should_renorm_ = true; + bool should_reopt_ = true; MatchResultPtr res_; pipeline::ResourcePtr resource_; static std::unordered_map phase_to_group_; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 7470367f62..9b1c893851 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -451,35 +451,55 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } -void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { +bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { MS_EXCEPTION_IF_NULL(res->manager()); MS_EXCEPTION_IF_NULL(res->func_graph()); auto ppm = opt::python_pass::PyPassManager::GetInstance(); ppm->SetResource(res); - if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { - MS_LOG(DEBUG) << "No match.\n"; - } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { - MS_LOG(DEBUG) << "Entered PyStub Renorm"; - // Renomalize - MS_EXCEPTION_IF_NULL(res->func_graph()); - FuncGraphPtr func_graph = res->func_graph(); - abstract::AbstractBasePtrList args_spec; - auto parameters = func_graph->parameters(); - (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); - FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); - res->set_func_graph(new_fg); - res->set_args_spec(args_spec); - } + return ppm->GetPassGroup(phase)->Run(res->func_graph()); } -bool ResolveActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::RESOLVE); +bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); } + +bool OptActionVmPyStub(const ResourcePtr &res) { + if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { + if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { + // Renomalize + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + res->set_args_spec(args_spec); + } + if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { + return VmOptimizeAction(res); + } + } return true; } -bool OptActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::OPT); +bool OptActionGePyStub(const ResourcePtr &res) { + if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { + if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { + // Renomalize + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + res->set_args_spec(args_spec); + } + if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) { + return GeOptimizeAction(res); + } + } return true; } @@ -510,7 +530,7 @@ std::vector GePipeline() { // optimize actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); + actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); actions.emplace_back(std::make_pair("validate", ValidateAction)); return actions; @@ -523,7 +543,7 @@ std::vector VmPipeline() { actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); + actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub)); actions.emplace_back(std::make_pair("validate", ValidateAction)); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) diff --git a/mindspore/graph_utils/python_pass/__init__.py b/mindspore/graph_utils/python_pass/__init__.py index 26a8216295..d9fe61c870 100644 --- a/mindspore/graph_utils/python_pass/__init__.py +++ b/mindspore/graph_utils/python_pass/__init__.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Top-level reference to python pass.""" -from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm +"""Reference for python pass registration.""" +from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ + set_reopt __all__ = [ "registe_pass", "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "set_renorm" + "set_renorm", + "set_reopt" ] diff --git a/mindspore/graph_utils/python_pass/python_pass_register.py b/mindspore/graph_utils/python_pass/python_pass_register.py index bde94fef99..8e37c44c78 100644 --- a/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/mindspore/graph_utils/python_pass/python_pass_register.py @@ -23,7 +23,8 @@ __all__ = [ "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "set_renorm" + "set_renorm", + "set_reopt" ] class PyPassManager(PyPassManager_): r""" @@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_): raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") super().set_renorm(should_renorm) + def set_reopt(self, do_reopt): + if not isinstance(do_reopt, bool): + raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") + super().set_reopt(do_reopt) + def registe_pass(run_only_once=False): """ Registe python pass to specified pipeline phase which would be used in compilation. @@ -164,3 +170,17 @@ def set_renorm(should_renorm): """ ppm = PyPassManager() ppm.set_renorm(should_renorm) + +def set_reopt(do_reopt): + """ + Set whether or not to do optimization after modified graph in python pass(es). + + Args: + do_reopt(bool): whether or not to do optimization after modified graph in python pass(es). + + NOTE: + This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off + renormalization may BREAK the network. + """ + ppm = PyPassManager() + ppm.set_reopt(do_reopt) diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index 40e7ecd2da..9038229fca 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -20,7 +20,7 @@ from mindspore import context from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ - cancel_new_parameter + cancel_new_parameter, set_reopt from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm @@ -50,8 +50,8 @@ def test_softmax_relu(): @registe_pass(run_only_once=True) def softmax_relu_pass(): x = Any() - pattern = Call(P.Softmax(), inputs=[x]) - target = Call(P.ReLU(), inputs=[x]) + pattern = Call(P.Softmax(), [x]) + target = Call(P.ReLU(), [x]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) @@ -59,6 +59,23 @@ def test_softmax_relu(): assert "ReLU" in transformed_repr assert "Softmax" not in transformed_repr +def test_prim(): + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + x = Any() + sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()]) + pattern = Call(sigmoid_softmax_pattern, [x]) + target = Call(P.ReLU(), [x]) + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) + unregiste_pass(softmax_relu_pass) + assert "ReLU" in transformed_repr + assert "Softmax" not in transformed_repr + def test_softmax_relu_sigmoid(): """ Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)). @@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid(): def softmax_relu_pass(): x = Any() softmax_pattern = Prim(P.Softmax()) - pattern = Call(softmax_pattern, inputs=[x]) + pattern = Call(softmax_pattern, [x]) sigmoid_pattern = Prim(P.Sigmoid()) call_sigmoid = Call(sigmoid_pattern, [x]) relu_pattern = Prim(P.ReLU()) - target = Call(relu_pattern, inputs=[call_sigmoid]) + target = Call(relu_pattern, [call_sigmoid]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) @@ -98,13 +115,13 @@ def test_isin_pattern_0(): def softmax_relu_pass(): x = Any() softmax_pattern = Prim(P.Softmax()) - call_softmax = Call(softmax_pattern, inputs=[x]) + call_softmax = Call(softmax_pattern, [x]) relu_pattern = Prim(P.ReLU()) - call_relu = Call(relu_pattern, inputs=[x]) + call_relu = Call(relu_pattern, [x]) pattern = OneOf([call_softmax, call_relu]) relu6_pattern = Prim(P.ReLU6()) - target = Call(relu6_pattern, inputs=[x]) + target = Call(relu6_pattern, [x]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) unregiste_pass(softmax_relu_pass) @@ -122,13 +139,13 @@ def test_isin_pattern_1(): def softmax_neg_pass(): x = Any() softmax_pattern = Prim(P.Softmax()) - call_softmax = Call(softmax_pattern, inputs=[x]) + call_softmax = Call(softmax_pattern, [x]) relu_pattern = Prim(P.ReLU()) - call_relu = Call(relu_pattern, inputs=[x]) + call_relu = Call(relu_pattern, [x]) pattern = OneOf([call_softmax, call_relu]) neg_ops = Prim(P.Neg()) - target = Call(neg_ops, inputs=[pattern]) + target = Call(neg_ops, [pattern]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) unregiste_pass(softmax_neg_pass) @@ -141,6 +158,7 @@ def test_isnot_pattern_0(): Case: IsNot pass failed to match """ set_renorm(False) + set_reopt(False) class ConvBN(nn.Cell): def __init__(self): super(ConvBN, self).__init__() @@ -166,8 +184,8 @@ def test_isnot_pattern_0(): conv2d_prim = Prim("Conv2D") conv2d = Call(conv2d_prim) pattern_0 = NoneOf(conv2d) - pattern = Call(P.BatchNorm(), inputs=[pattern_0]) - target = Call(P.ReLU6(), inputs=[pattern_0]) + pattern = Call(P.BatchNorm(), [pattern_0]) + target = Call(P.ReLU6(), [pattern_0]) return pattern, target @registe_pass(run_only_once=True) @@ -202,9 +220,9 @@ def test_isnot_pattern_1(): matmul = Prim("MatMul") pattern_0 = NoneOf(matmul) softmax = P.Softmax() - pattern = Call(softmax, inputs=[pattern_0]) + pattern = Call(softmax, [pattern_0]) relu6 = P.ReLU6() - target = Call(relu6, inputs=[pattern_0]) + target = Call(relu6, [pattern_0]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) @@ -217,17 +235,18 @@ def test_newtensor_pattern(): Test NewTensor pattern in the target """ set_renorm(False) + set_reopt(False) inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() @registe_pass(run_only_once=True) def softmax_addn_pass(): x = Any() - pattern = Call(P.Softmax(), inputs=[x]) + pattern = Call(P.Softmax(), [x]) weight_tensor = Tensor(np.zeros([42]), mindspore.float16) new_weight = NewTensor(weight_tensor) - target = Call(P.AddN(), inputs=[x, new_weight]) + target = Call(P.AddN(), [x, new_weight]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) unregiste_pass(softmax_addn_pass) @@ -242,17 +261,19 @@ def test_newparameter_pattern(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() + set_renorm(False) + set_reopt(False) @registe_pass(run_only_once=True) def softmax_addn_pass(): x = Any() - pattern = Call(P.Softmax(), inputs=[x]) + pattern = Call(P.Softmax(), [x]) default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) new_para_0 = NewParameter("Merlin", default_tensor0) new_para_1 = NewParameter("Arthur", default_tensor1) - target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1]) - target = Call("make_tuple", inputs=[target_0]) + target_0 = Call(P.MatMul(), [new_para_0, new_para_1]) + target = Call("make_tuple", [target_0]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) unregiste_pass(softmax_addn_pass) @@ -267,13 +288,15 @@ def test_imm_target(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() + set_renorm(False) + set_reopt(False) @registe_pass(run_only_once=True) def softmax_pass(): x = Any() - pattern = Call(P.Softmax(), inputs=[x]) + pattern = Call(P.Softmax(), [x]) imm = Imm(0) - target_0 = Call("make_tuple", inputs=[pattern]) - target = Call("tuple_getitem", inputs=[target_0, imm]) + target_0 = Call("make_tuple", [pattern]) + target = Call("tuple_getitem", [target_0, imm]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) unregiste_pass(softmax_pass) @@ -290,14 +313,16 @@ def test_gen_new_parameter(): default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) new_para = NewParameter("Merlin", default_tensor) + set_renorm(False) + set_reopt(False) gen_new_parameter(new_para) @registe_pass(run_only_once=True) def softmax_make_tuple_pass(): x = Any() softmax = P.Softmax() - pattern = Call(softmax, inputs=[x]) + pattern = Call(softmax, [x]) - target = Call("make_tuple", inputs=[pattern, new_para]) + target = Call("make_tuple", [pattern, new_para]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) assert "Merlin" in transformed_repr