!31694 Remove original ParseFor implementation.

Merge pull request !31694 from 张清华/opt_for_unroll
This commit is contained in:
i-robot 2022-03-22 23:42:42 +00:00 committed by Gitee
commit 263edeadbb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 10 additions and 264 deletions

View File

@ -533,7 +533,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
}
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
if (!transform_tail_call_to_parallel_call) {
return true;
}

View File

@ -68,7 +68,6 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
max_for_loop_count_str_ = common::GetEnv("MS_DEV_FOR_TO_WHILE_LOOP");
support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
errcode_ = PARSE_SUCCESS;
BuildMethodMap();
@ -80,7 +79,7 @@ void Parser::BuildMethodMap() {
stmt_method_map_["If"] = &Parser::ParseIf;
stmt_method_map_["Assign"] = &Parser::ParseAssign;
stmt_method_map_["While"] = &Parser::ParseWhile;
stmt_method_map_["For"] = &Parser::ParseForUnroll;
stmt_method_map_["For"] = &Parser::ParseFor;
stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
stmt_method_map_["Global"] = &Parser::ParseGlobal;
@ -1513,7 +1512,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
(void)ignored_if_latter_call_graphs_.insert(after_block);
}
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
false_branch_graphs.second != nullptr) {
true_branch_graphs.first = block;
@ -1609,263 +1608,16 @@ FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
return block;
}
int64_t Parser::GetForTransToWhileLoop() {
// int64 support 63bits positive num mostly.
constexpr auto max_num_length = 10;
if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
return MAX_FOR_LOOP_COUNT;
}
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
[](char c) { return c < '0' || c > '9'; })) {
return MAX_FOR_LOOP_COUNT;
}
int64_t loop_count;
std::stringstream ss;
ss << max_for_loop_count_str_;
ss >> loop_count;
return loop_count;
}
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
// It is compiled to be following statement
// if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop
// else, ParseForLoop. Use loop var to implement for loop, which always sink loop
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
MS_EXCEPTION_IF_NULL(block);
// Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
MS_EXCEPTION_IF_NULL(block->func_graph());
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
// Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
FunctionBlockPtr true_end = ParseForIter(true_block, node);
true_end->Jump(after_block, {});
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
false_end->Jump(after_block, {});
block->ConditionalJump(bool_node, true_block, false_block);
after_block->Mature();
return after_block;
}
// A for loop will generate 3 functions: the test, the body, and the continuation.
// for x in xs:
// body
// It is compiled to be following statement:
// it = iter(xs)
// while hastnext(it)
// x, it = next(it)
// body
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
// Generate the iterator apply
CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
MS_EXCEPTION_IF_NULL(iter_apply);
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
MS_EXCEPTION_IF_NULL(header_block->func_graph());
// Generate the hasnext apply which is a condition
ParameterPtr iter_param = header_block->func_graph()->add_parameter();
CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
// Generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
MS_EXCEPTION_IF_NULL(body_block->func_graph());
// Generate the iterator next apply
// Process as following: `app = next(it); target = app[0]; it = app[1];`
CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
CNodePtr target_app =
body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
CNodePtr iter2_app =
body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
WriteAssignVars(body_block, target_node, target_app);
// Link the variable name with the target
auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
MS_EXCEPTION_IF_NULL(iter_param->debug_info());
MS_EXCEPTION_IF_NULL(iter2_app->debug_info());
MS_EXCEPTION_IF_NULL(iter_apply->debug_info());
iter_param->debug_info()->set_trace_info(it_info);
iter2_app->debug_info()->set_trace_info(it_info);
iter_apply->debug_info()->set_trace_info(it_info);
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
MS_EXCEPTION_IF_NULL(after_block);
after_block->AddPrevBlock(header_block);
block->Jump(header_block, {iter_apply});
body_block->Mature();
header_block->ConditionalJump(cond_apply, body_block, after_block);
// Parse loop body statements with loop context.
LoopContext loop_context{&loops_, header_block, iter2_app};
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, {iter2_app});
}
header_block->Mature();
after_block->Mature();
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;
}
// No 'break', no end_block.
return after_block;
}
// A for loop will generate 3 functions: the test, the body, and the continuation.
// for x in xs:
// body
// It is compiled to be following statement:
// i = 0
// while i < len(xs)
// x = xs[i]
// i = i + 1
// body
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
// Get variable name of 'x' in statement 'for x in xs'
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
// Create statement 'len(xs)'
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
MS_EXCEPTION_IF_NULL(iter_node);
MS_EXCEPTION_IF_NULL(block->func_graph());
// Generate node for loop count and convert it to tensor, to make the loop not unroll
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
MS_EXCEPTION_IF_NULL(header_block->func_graph());
// Create loop variable 'i'
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
// Create loop condition 'i < len(xs)'
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
// Generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
// Create 'x = xs[i]'
auto body_func_graph = body_block->func_graph();
MS_EXCEPTION_IF_NULL(body_func_graph);
CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
WriteAssignVars(body_block, target_node, target_var);
// Create 'i = i + 1'
auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
auto add_tensor_node =
body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
body_block->WriteVariable(loop_var->name(), loop_var_inc);
// Link the variable name with the target
auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
MS_EXCEPTION_IF_NULL(loop_var->debug_info());
MS_EXCEPTION_IF_NULL(len_iter->debug_info());
loop_var->debug_info()->set_trace_info(it_info);
len_iter->debug_info()->set_trace_info(it_info);
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
after_block = MakeFunctionBlock(*this);
}
MS_EXCEPTION_IF_NULL(after_block);
after_block->AddPrevBlock(header_block);
CNodePtr zero_tensor =
block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
block->Jump(header_block, {zero_tensor});
body_block->Mature();
header_block->ConditionalJump(cond_node, body_block, after_block);
// Parse loop body statements with loop context.
LoopContext loop_context{&loops_, header_block, loop_var_inc};
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, {loop_var_inc});
}
header_block->Mature();
after_block->Mature();
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;
}
// No 'break', no end_block.
return after_block;
}
// Implement unroll for statement with tuple/getitem.
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
if (transform_for_half_unroll_call) {
return ParseForRepeat(block, node);
}
return ParseForUnroll(block, node);
}
// Implement unroll for statement with tuple/getitem.
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
@ -1942,6 +1694,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
return after_block;
}
// Implement for statement with repeat calling sub graph.
FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
MS_EXCEPTION_IF_NULL(block);

View File

@ -126,8 +126,6 @@ class Parser {
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
// Process a for statement
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node);
// Process a function def statement
@ -299,7 +297,6 @@ class Parser {
}
// Return a make tuple for input elements list
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
int64_t GetForTransToWhileLoop();
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
static FuncGraphWeakPtr top_func_graph_;
@ -321,7 +318,6 @@ class Parser {
std::map<std::string, ExprFunc> expr_method_map_;
// Save current loops to support 'continue', 'break' statement.
std::stack<Loop> loops_;
string max_for_loop_count_str_;
string support_fallback_;
// The func graphs to transform tail call ir to independent call ir.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from mindspore import context
from mindspore import Tensor, nn
@ -21,8 +20,8 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
# Although we don't transform for to while any more, we keep this test case.
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -53,7 +52,6 @@ def test_single_for_01():
y = Tensor([5], mstype.int32)
z = Tensor([4], mstype.int32)
os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = '1'
# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_net = SingleForNet()
@ -67,7 +65,6 @@ def test_single_for_01():
net = GradNet(for_net)
pynative_forward_res = for_net(x, y, z)
pynative_backward_res = net(x, y, z)
os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = ''
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res