forked from mindspore-Ecosystem/mindspore
!2954 Revert "Optimization for ApplyTransform function"
Merge pull request !2954 from Kang/optimization
This commit is contained in:
commit
4bdd8e16a2
|
@ -92,18 +92,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
|
|||
return result;
|
||||
}
|
||||
|
||||
static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) {
|
||||
if (node->isa<CNode>() || node->isa<Parameter>()) {
|
||||
static bool isTraversable(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
|
||||
if (!all_nodes.contains(node)) {
|
||||
return false;
|
||||
}
|
||||
if (node->isa<CNode>() || node->isa<Parameter>()) {
|
||||
return true;
|
||||
}
|
||||
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -126,15 +124,9 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
|
|||
todo.pop_front();
|
||||
|
||||
// check whether this node has been matched.
|
||||
if (node == nullptr || node->seen_ == seen) {
|
||||
if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto fg = node->func_graph();
|
||||
if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
node->seen_ = seen;
|
||||
|
||||
// select nodes that this transform can be applied.
|
||||
|
|
|
@ -1152,7 +1152,6 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
|
||||
// get varibale name of 'x' in statement 'for x in xs'
|
||||
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
|
||||
|
||||
// create statement 'len(xs)'
|
||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
|
||||
|
@ -1174,13 +1173,11 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
body_block->AddPrevBlock(header_block);
|
||||
// create 'x = xs[i]'
|
||||
CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
|
||||
target_var->debug_info()->set_name(name_id);
|
||||
body_block->WriteVariable(name_id, target_var);
|
||||
WriteAssignVars(body_block, target_node, target_var);
|
||||
// create 'i = i + 1'
|
||||
CNodePtr loop_var_inc =
|
||||
body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
|
||||
body_block->WriteVariable(loop_var->name(), loop_var_inc);
|
||||
loop_var_inc->debug_info()->set_name(name_id);
|
||||
|
||||
// link the variable name with the target
|
||||
auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "./common.h"
|
||||
|
@ -48,10 +47,9 @@ class PyFuncGraphFetcher {
|
|||
py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
|
||||
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
|
||||
if (doResolve_) {
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true);
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
|
||||
mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
|
||||
mindspore::parse::ResolveAll(manager);
|
||||
func_graph = BasicClone(func_graph);
|
||||
}
|
||||
return func_graph;
|
||||
} catch (py::error_already_set& e) {
|
||||
|
@ -73,9 +71,8 @@ class PyFuncGraphFetcher {
|
|||
py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
|
||||
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
|
||||
if (doResolve_) {
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, true);
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
|
||||
mindspore::parse::ResolveAll(manager);
|
||||
func_graph = BasicClone(func_graph);
|
||||
}
|
||||
return func_graph;
|
||||
} catch (py::error_already_set& e) {
|
||||
|
|
Loading…
Reference in New Issue