forked from mindspore-Ecosystem/mindspore
Remove original ParseFor implementation.
This commit is contained in:
parent
791adb1dc8
commit
5d15cc9e22
|
@ -533,7 +533,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
|
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
|
||||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
|
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
|
||||||
if (!transform_tail_call_to_parallel_call) {
|
if (!transform_tail_call_to_parallel_call) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,6 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
|
||||||
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
|
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
|
||||||
|
|
||||||
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
|
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
|
||||||
max_for_loop_count_str_ = common::GetEnv("MS_DEV_FOR_TO_WHILE_LOOP");
|
|
||||||
support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
|
||||||
errcode_ = PARSE_SUCCESS;
|
errcode_ = PARSE_SUCCESS;
|
||||||
BuildMethodMap();
|
BuildMethodMap();
|
||||||
|
@ -80,7 +79,7 @@ void Parser::BuildMethodMap() {
|
||||||
stmt_method_map_["If"] = &Parser::ParseIf;
|
stmt_method_map_["If"] = &Parser::ParseIf;
|
||||||
stmt_method_map_["Assign"] = &Parser::ParseAssign;
|
stmt_method_map_["Assign"] = &Parser::ParseAssign;
|
||||||
stmt_method_map_["While"] = &Parser::ParseWhile;
|
stmt_method_map_["While"] = &Parser::ParseWhile;
|
||||||
stmt_method_map_["For"] = &Parser::ParseForUnroll;
|
stmt_method_map_["For"] = &Parser::ParseFor;
|
||||||
stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
|
stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
|
||||||
stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
|
stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
|
||||||
stmt_method_map_["Global"] = &Parser::ParseGlobal;
|
stmt_method_map_["Global"] = &Parser::ParseGlobal;
|
||||||
|
@ -1514,7 +1513,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
||||||
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
|
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
|
||||||
(void)ignored_if_latter_call_graphs_.insert(after_block);
|
(void)ignored_if_latter_call_graphs_.insert(after_block);
|
||||||
}
|
}
|
||||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
|
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
|
||||||
if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
|
if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
|
||||||
false_branch_graphs.second != nullptr) {
|
false_branch_graphs.second != nullptr) {
|
||||||
true_branch_graphs.first = block;
|
true_branch_graphs.first = block;
|
||||||
|
@ -1610,263 +1609,16 @@ FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Parser::GetForTransToWhileLoop() {
|
|
||||||
// int64 support 63bits positive num mostly.
|
|
||||||
constexpr auto max_num_length = 10;
|
|
||||||
if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
|
|
||||||
return MAX_FOR_LOOP_COUNT;
|
|
||||||
}
|
|
||||||
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
|
|
||||||
[](char c) { return c < '0' || c > '9'; })) {
|
|
||||||
return MAX_FOR_LOOP_COUNT;
|
|
||||||
}
|
|
||||||
int64_t loop_count;
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << max_for_loop_count_str_;
|
|
||||||
ss >> loop_count;
|
|
||||||
return loop_count;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A for loop will generate 3 functions :the test, the body, and the continuation
|
|
||||||
// for x in xs:
|
|
||||||
// body
|
|
||||||
// It is compiled to be following statement
|
|
||||||
// if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop
|
|
||||||
// else, ParseForLoop. Use loop var to implement for loop, which always sink loop
|
|
||||||
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
|
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
|
||||||
// Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
|
|
||||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
|
||||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
|
|
||||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
|
||||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
|
||||||
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
|
|
||||||
CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
|
|
||||||
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
|
|
||||||
|
|
||||||
// Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
|
|
||||||
FunctionBlockPtr true_block = nullptr;
|
|
||||||
FunctionBlockPtr false_block = nullptr;
|
|
||||||
{
|
|
||||||
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
|
|
||||||
true_block = MakeFunctionBlock(*this);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
|
|
||||||
false_block = MakeFunctionBlock(*this);
|
|
||||||
}
|
|
||||||
|
|
||||||
MakeConditionBlocks(block, true_block, false_block);
|
|
||||||
|
|
||||||
FunctionBlockPtr after_block = nullptr;
|
|
||||||
{
|
|
||||||
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
|
|
||||||
after_block = MakeFunctionBlock(*this);
|
|
||||||
}
|
|
||||||
|
|
||||||
FunctionBlockPtr true_end = ParseForIter(true_block, node);
|
|
||||||
true_end->Jump(after_block, {});
|
|
||||||
|
|
||||||
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
|
|
||||||
false_end->Jump(after_block, {});
|
|
||||||
|
|
||||||
block->ConditionalJump(bool_node, true_block, false_block);
|
|
||||||
after_block->Mature();
|
|
||||||
return after_block;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A for loop will generate 3 functions: the test, the body, and the continuation.
|
|
||||||
// for x in xs:
|
|
||||||
// body
|
|
||||||
// It is compiled to be following statement:
|
|
||||||
// it = iter(xs)
|
|
||||||
// while hastnext(it)
|
|
||||||
// x, it = next(it)
|
|
||||||
// body
|
|
||||||
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
|
|
||||||
MS_LOG(DEBUG) << "Process ast For";
|
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
|
||||||
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
|
||||||
AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
|
|
||||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
|
||||||
AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
|
|
||||||
// Generate the iterator apply
|
|
||||||
CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
|
|
||||||
MS_EXCEPTION_IF_NULL(iter_apply);
|
|
||||||
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
|
||||||
MS_EXCEPTION_IF_NULL(header_block);
|
|
||||||
MS_EXCEPTION_IF_NULL(header_block->func_graph());
|
|
||||||
// Generate the hasnext apply which is a condition
|
|
||||||
ParameterPtr iter_param = header_block->func_graph()->add_parameter();
|
|
||||||
CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
|
|
||||||
// Generate the body of the for statement
|
|
||||||
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
|
||||||
MS_EXCEPTION_IF_NULL(body_block);
|
|
||||||
body_block->AddPrevBlock(header_block);
|
|
||||||
MS_EXCEPTION_IF_NULL(body_block->func_graph());
|
|
||||||
// Generate the iterator next apply
|
|
||||||
// Process as following: `app = next(it); target = app[0]; it = app[1];`
|
|
||||||
CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
|
|
||||||
CNodePtr target_app =
|
|
||||||
body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
|
|
||||||
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
|
|
||||||
|
|
||||||
CNodePtr iter2_app =
|
|
||||||
body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
|
|
||||||
WriteAssignVars(body_block, target_node, target_app);
|
|
||||||
|
|
||||||
// Link the variable name with the target
|
|
||||||
auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
|
|
||||||
MS_EXCEPTION_IF_NULL(iter_param->debug_info());
|
|
||||||
MS_EXCEPTION_IF_NULL(iter2_app->debug_info());
|
|
||||||
MS_EXCEPTION_IF_NULL(iter_apply->debug_info());
|
|
||||||
iter_param->debug_info()->set_trace_info(it_info);
|
|
||||||
iter2_app->debug_info()->set_trace_info(it_info);
|
|
||||||
iter_apply->debug_info()->set_trace_info(it_info);
|
|
||||||
|
|
||||||
FunctionBlockPtr after_block = nullptr;
|
|
||||||
{
|
|
||||||
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
|
|
||||||
after_block = MakeFunctionBlock(*this);
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(after_block);
|
|
||||||
after_block->AddPrevBlock(header_block);
|
|
||||||
|
|
||||||
block->Jump(header_block, {iter_apply});
|
|
||||||
body_block->Mature();
|
|
||||||
header_block->ConditionalJump(cond_apply, body_block, after_block);
|
|
||||||
|
|
||||||
// Parse loop body statements with loop context.
|
|
||||||
LoopContext loop_context{&loops_, header_block, iter2_app};
|
|
||||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
|
||||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
|
|
||||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
|
||||||
after_body_block->Jump(header_block, {iter2_app});
|
|
||||||
}
|
|
||||||
|
|
||||||
header_block->Mature();
|
|
||||||
after_block->Mature();
|
|
||||||
auto &end_block = loop_context.EndBlock();
|
|
||||||
if (end_block) {
|
|
||||||
// end_block exists if we encounter 'break' in loop body.
|
|
||||||
after_block->Jump(end_block, {});
|
|
||||||
end_block->Mature();
|
|
||||||
return end_block;
|
|
||||||
}
|
|
||||||
// No 'break', no end_block.
|
|
||||||
return after_block;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A for loop will generate 3 functions: the test, the body, and the continuation.
|
|
||||||
// for x in xs:
|
|
||||||
// body
|
|
||||||
// It is compiled to be following statement:
|
|
||||||
// i = 0
|
|
||||||
// while i < len(xs)
|
|
||||||
// x = xs[i]
|
|
||||||
// i = i + 1
|
|
||||||
// body
|
|
||||||
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
|
|
||||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
|
||||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
|
||||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
|
||||||
|
|
||||||
// Get variable name of 'x' in statement 'for x in xs'
|
|
||||||
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
|
|
||||||
|
|
||||||
// Create statement 'len(xs)'
|
|
||||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
|
|
||||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
|
||||||
MS_EXCEPTION_IF_NULL(iter_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
|
||||||
// Generate node for loop count and convert it to tensor, to make the loop not unroll
|
|
||||||
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
|
|
||||||
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
|
|
||||||
auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
|
|
||||||
|
|
||||||
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
|
|
||||||
|
|
||||||
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
|
||||||
MS_EXCEPTION_IF_NULL(header_block);
|
|
||||||
MS_EXCEPTION_IF_NULL(header_block->func_graph());
|
|
||||||
// Create loop variable 'i'
|
|
||||||
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
|
|
||||||
// Create loop condition 'i < len(xs)'
|
|
||||||
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
|
|
||||||
auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
|
|
||||||
CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
|
|
||||||
|
|
||||||
// Generate the body of the for statement
|
|
||||||
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
|
||||||
MS_EXCEPTION_IF_NULL(body_block);
|
|
||||||
body_block->AddPrevBlock(header_block);
|
|
||||||
// Create 'x = xs[i]'
|
|
||||||
auto body_func_graph = body_block->func_graph();
|
|
||||||
MS_EXCEPTION_IF_NULL(body_func_graph);
|
|
||||||
CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
|
|
||||||
WriteAssignVars(body_block, target_node, target_var);
|
|
||||||
// Create 'i = i + 1'
|
|
||||||
auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
|
|
||||||
auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
|
|
||||||
auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
|
|
||||||
auto add_tensor_node =
|
|
||||||
body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
|
|
||||||
CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
|
|
||||||
body_block->WriteVariable(loop_var->name(), loop_var_inc);
|
|
||||||
|
|
||||||
// Link the variable name with the target
|
|
||||||
auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
|
|
||||||
MS_EXCEPTION_IF_NULL(loop_var->debug_info());
|
|
||||||
MS_EXCEPTION_IF_NULL(len_iter->debug_info());
|
|
||||||
loop_var->debug_info()->set_trace_info(it_info);
|
|
||||||
len_iter->debug_info()->set_trace_info(it_info);
|
|
||||||
|
|
||||||
FunctionBlockPtr after_block = nullptr;
|
|
||||||
{
|
|
||||||
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
|
|
||||||
after_block = MakeFunctionBlock(*this);
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(after_block);
|
|
||||||
after_block->AddPrevBlock(header_block);
|
|
||||||
|
|
||||||
CNodePtr zero_tensor =
|
|
||||||
block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
|
|
||||||
block->Jump(header_block, {zero_tensor});
|
|
||||||
body_block->Mature();
|
|
||||||
|
|
||||||
header_block->ConditionalJump(cond_node, body_block, after_block);
|
|
||||||
|
|
||||||
// Parse loop body statements with loop context.
|
|
||||||
LoopContext loop_context{&loops_, header_block, loop_var_inc};
|
|
||||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
|
||||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
|
|
||||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
|
||||||
after_body_block->Jump(header_block, {loop_var_inc});
|
|
||||||
}
|
|
||||||
|
|
||||||
header_block->Mature();
|
|
||||||
after_block->Mature();
|
|
||||||
auto &end_block = loop_context.EndBlock();
|
|
||||||
if (end_block) {
|
|
||||||
// end_block exists if we encounter 'break' in loop body.
|
|
||||||
after_block->Jump(end_block, {});
|
|
||||||
end_block->Mature();
|
|
||||||
return end_block;
|
|
||||||
}
|
|
||||||
// No 'break', no end_block.
|
|
||||||
return after_block;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement unroll for statement with tuple/getitem.
|
|
||||||
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
|
|
||||||
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
|
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
|
||||||
if (transform_for_half_unroll_call) {
|
if (transform_for_half_unroll_call) {
|
||||||
return ParseForRepeat(block, node);
|
return ParseForRepeat(block, node);
|
||||||
}
|
}
|
||||||
|
return ParseForUnroll(block, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement unroll for statement with tuple/getitem.
|
||||||
|
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
||||||
|
@ -1943,6 +1695,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
||||||
return after_block;
|
return after_block;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Implement for statement with repeat calling sub graph.
|
||||||
FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
|
FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
|
|
|
@ -126,8 +126,6 @@ class Parser {
|
||||||
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
|
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
|
||||||
// Process a for statement
|
// Process a for statement
|
||||||
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
|
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
|
||||||
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
|
|
||||||
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
|
|
||||||
FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node);
|
FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node);
|
||||||
FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node);
|
FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node);
|
||||||
// Process a function def statement
|
// Process a function def statement
|
||||||
|
@ -299,7 +297,6 @@ class Parser {
|
||||||
}
|
}
|
||||||
// Return a make tuple for input elements list
|
// Return a make tuple for input elements list
|
||||||
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
|
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
|
||||||
int64_t GetForTransToWhileLoop();
|
|
||||||
|
|
||||||
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
||||||
static FuncGraphWeakPtr top_func_graph_;
|
static FuncGraphWeakPtr top_func_graph_;
|
||||||
|
@ -321,7 +318,6 @@ class Parser {
|
||||||
std::map<std::string, ExprFunc> expr_method_map_;
|
std::map<std::string, ExprFunc> expr_method_map_;
|
||||||
// Save current loops to support 'continue', 'break' statement.
|
// Save current loops to support 'continue', 'break' statement.
|
||||||
std::stack<Loop> loops_;
|
std::stack<Loop> loops_;
|
||||||
string max_for_loop_count_str_;
|
|
||||||
string support_fallback_;
|
string support_fallback_;
|
||||||
|
|
||||||
// The func graphs to transform tail call ir to independent call ir.
|
// The func graphs to transform tail call ir to independent call ir.
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import os
|
|
||||||
import pytest
|
import pytest
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor, nn
|
from mindspore import Tensor, nn
|
||||||
|
@ -21,8 +20,8 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
grad_all = C.GradOperation(get_all=True)
|
grad_all = C.GradOperation(get_all=True)
|
||||||
context.set_context(device_target="Ascend")
|
|
||||||
|
|
||||||
|
# Although we don't transform for to while any more, we keep this test case.
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@ -53,7 +52,6 @@ def test_single_for_01():
|
||||||
y = Tensor([5], mstype.int32)
|
y = Tensor([5], mstype.int32)
|
||||||
z = Tensor([4], mstype.int32)
|
z = Tensor([4], mstype.int32)
|
||||||
|
|
||||||
os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = '1'
|
|
||||||
# graph mode
|
# graph mode
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
for_net = SingleForNet()
|
for_net = SingleForNet()
|
||||||
|
@ -67,7 +65,6 @@ def test_single_for_01():
|
||||||
net = GradNet(for_net)
|
net = GradNet(for_net)
|
||||||
pynative_forward_res = for_net(x, y, z)
|
pynative_forward_res = for_net(x, y, z)
|
||||||
pynative_backward_res = net(x, y, z)
|
pynative_backward_res = net(x, y, z)
|
||||||
os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = ''
|
|
||||||
|
|
||||||
assert graph_forward_res == pynative_forward_res
|
assert graph_forward_res == pynative_forward_res
|
||||||
assert graph_backward_res == pynative_backward_res
|
assert graph_backward_res == pynative_backward_res
|
||||||
|
|
Loading…
Reference in New Issue