From 4b2b46679aced35569fb6e40bf312cf9f7129b5f Mon Sep 17 00:00:00 2001 From: wuyongkang Date: Thu, 9 Jul 2020 10:41:51 +0800 Subject: [PATCH 1/2] Revert "Optimization for ApplyTransform function" This reverts commit 02dd305bb030d993e754a056c159288b5943d90c. --- mindspore/ccsrc/optimizer/opt.cc | 22 +++++++-------------- tests/ut/cpp/common/py_func_graph_fetcher.h | 7 ++----- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index e6addae76e8..5e893cf1aa7 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -92,18 +92,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode return result; } -static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) { - if (node->isa() || node->isa()) { +static bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { return false; } - - if (IsValueNode(node) || IsValueNode(node)) { - if (!all_nodes.contains(node)) { - return false; - } + if (node->isa() || node->isa()) { + return true; + } + if (IsValueNode(node) || IsValueNode(node)) { return true; } - return false; } @@ -126,15 +124,9 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo todo.pop_front(); // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen) { + if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; } - - auto fg = node->func_graph(); - if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) { - continue; - } - node->seen_ = seen; // select nodes that this transform can be applied. diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index 9d374fcd601..98552a96b54 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -22,7 +22,6 @@ #include "ir/primitive.h" #include "ir/manager.h" #include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse.h" #include "./common.h" @@ -48,10 +47,9 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, true); + std::shared_ptr manager = mindspore::Manage(func_graph, false); mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::ResolveAll(manager); - func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) { @@ -73,9 +71,8 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str()); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, true); + std::shared_ptr manager = mindspore::Manage(func_graph, false); mindspore::parse::ResolveAll(manager); - func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) { From 41229ed01dad1b99a27f55ee96430dffbd38ca66 Mon Sep 17 00:00:00 2001 From: wuyongkang Date: Thu, 9 Jul 2020 14:51:33 +0800 Subject: [PATCH 2/2] Fix bug of for i, j in enumerate(items) --- mindspore/ccsrc/pipeline/parse/parse.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index 351a83124eb..1d306d9ca4c 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -1152,7 +1152,6 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o // get varibale name of 'x' in statement 'for x in xs' py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - auto name_id = py::cast(python_adapter::GetPyObjAttr(target_node, "id")); // create statement 'len(xs)' py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); @@ -1174,13 +1173,11 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o body_block->AddPrevBlock(header_block); // create 'x = xs[i]' CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var}); - target_var->debug_info()->set_name(name_id); - body_block->WriteVariable(name_id, target_var); + WriteAssignVars(body_block, target_node, target_var); // create 'i = i + 1' CNodePtr loop_var_inc = body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)}); body_block->WriteVariable(loop_var->name(), loop_var_inc); - loop_var_inc->debug_info()->set_name(name_id); // link the variable name with the target auto it_info = std::make_shared(loop_var_inc->debug_info());