forked from mindspore-Ecosystem/mindspore
!31694 Remove original ParseFor implementation.
Merge pull request !31694 from 张清华/opt_for_unroll
commit 263edeadbb
@@ -533,7 +533,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
}

bool EliminateUnusedParameterAction(const ResourcePtr &res) {
-  static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
+  static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
  if (!transform_tail_call_to_parallel_call) {
    return true;
  }
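The renamed switch is read once and cached on the C++ side, so it has to be set before any graph is compiled. A hypothetical usage sketch in Python, following the same pattern the test below uses for MS_DEV_FOR_TO_WHILE_LOOP (the surrounding network and context setup are assumed, not shown):

    import os

    # Assumed usage: enable the tail-call-to-parallel-call transform before graph compilation.
    os.environ['MS_DEV_IF_PARALLEL_CALL'] = '1'
    # ... build and run the network in GRAPH_MODE here ...
    os.environ['MS_DEV_IF_PARALLEL_CALL'] = ''   # restore the default afterwards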
@@ -68,7 +68,6 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();

Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
-  max_for_loop_count_str_ = common::GetEnv("MS_DEV_FOR_TO_WHILE_LOOP");
  support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
  errcode_ = PARSE_SUCCESS;
  BuildMethodMap();
@@ -80,7 +79,7 @@ void Parser::BuildMethodMap() {
  stmt_method_map_["If"] = &Parser::ParseIf;
  stmt_method_map_["Assign"] = &Parser::ParseAssign;
  stmt_method_map_["While"] = &Parser::ParseWhile;
-  stmt_method_map_["For"] = &Parser::ParseForUnroll;
+  stmt_method_map_["For"] = &Parser::ParseFor;
  stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
  stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
  stmt_method_map_["Global"] = &Parser::ParseGlobal;
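stmt_method_map_ is a plain dispatch table from AST statement kind to parser method. A toy Python equivalent of the pattern (class, handlers, and node strings are illustrative, not MindSpore API):

    class MiniParser:
        def __init__(self):
            # AST statement kind -> bound handler, mirroring stmt_method_map_.
            self.stmt_method_map = {
                "While": self.parse_while,
                "For": self.parse_for,   # after this change, "For" is handled by ParseFor
            }

        def parse_while(self, node):
            return f"while-block({node})"

        def parse_for(self, node):
            return f"for-block({node})"

        def parse_statement(self, kind, node):
            return self.stmt_method_map[kind](node)

    print(MiniParser().parse_statement("For", "ast.For"))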
@@ -1513,7 +1512,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
                  << ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
    (void)ignored_if_latter_call_graphs_.insert(after_block);
  }
-  static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
+  static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
  if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
      false_branch_graphs.second != nullptr) {
    true_branch_graphs.first = block;
@@ -1609,263 +1608,16 @@ FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
  return block;
}

int64_t Parser::GetForTransToWhileLoop() {
  // int64 support 63bits positive num mostly.
  constexpr auto max_num_length = 10;
  if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
    return MAX_FOR_LOOP_COUNT;
  }
  if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
                  [](char c) { return c < '0' || c > '9'; })) {
    return MAX_FOR_LOOP_COUNT;
  }
  int64_t loop_count;
  std::stringstream ss;
  ss << max_for_loop_count_str_;
  ss >> loop_count;
  return loop_count;
}
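The removed helper turns MS_DEV_FOR_TO_WHILE_LOOP into the unroll threshold, falling back to the built-in maximum for empty, over-long, or non-numeric values. A rough Python rendering of the same checks (the default value below is a placeholder; the real MAX_FOR_LOOP_COUNT is defined in the C++ sources):

    import os

    MAX_FOR_LOOP_COUNT = 600  # placeholder default, for illustration only

    def get_for_trans_to_while_loop():
        s = os.environ.get("MS_DEV_FOR_TO_WHILE_LOOP", "")
        # Mirror the C++ checks: reject empty, longer than 10 characters, or non-digit input.
        if not s or len(s) > 10 or not s.isdigit():
            return MAX_FOR_LOOP_COUNT
        return int(s)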

// A for loop will generate 3 functions: the test, the body, and the continuation.
// for x in xs:
//   body
// It is compiled to the following statement:
//   if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement the for loop, which always unrolls the loop.
//   else, ParseForLoop. Use a loop variable to implement the for loop, which always sinks the loop.
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For, create an if else statement";
  MS_EXCEPTION_IF_NULL(block);
  // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(block->func_graph());
  CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
    {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});

  // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
  FunctionBlockPtr true_block = nullptr;
  FunctionBlockPtr false_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
    true_block = MakeFunctionBlock(*this);
  }
  {
    TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
    false_block = MakeFunctionBlock(*this);
  }

  MakeConditionBlocks(block, true_block, false_block);

  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }

  FunctionBlockPtr true_end = ParseForIter(true_block, node);
  true_end->Jump(after_block, {});

  FunctionBlockPtr false_end = ParseForLoop(false_block, node);
  false_end->Jump(after_block, {});

  block->ConditionalJump(bool_node, true_block, false_block);
  after_block->Mature();
  return after_block;
}
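At the Python level, the removed ParseFor amounts to picking a lowering strategy by sequence length; a rough sketch under that reading (names and the threshold value are illustrative, not MindSpore API; the two strategies themselves are sketched after ParseForIter and ParseForLoop below):

    def lower_for_by_length(xs, body, max_loop_cnt=600):  # threshold value is illustrative
        if len(xs) < max_loop_cnt:
            # ParseForIter path: iterate via iter/hasnext/next, the loop gets unrolled.
            for x in iter(xs):
                body(x)
        else:
            # ParseForLoop path: index-based while loop, the loop is sunk instead of unrolled.
            i = 0
            while i < len(xs):
                body(xs[i])
                i = i + 1

    lower_for_by_length([1, 2, 3], print)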

// A for loop will generate 3 functions: the test, the body, and the continuation.
// for x in xs:
//   body
// It is compiled to the following statement:
//   it = iter(xs)
//   while hasnext(it)
//     x, it = next(it)
//     body
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  // Generate the iterator apply
  CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
  MS_EXCEPTION_IF_NULL(iter_apply);
  FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  MS_EXCEPTION_IF_NULL(header_block->func_graph());
  // Generate the hasnext apply which is a condition
  ParameterPtr iter_param = header_block->func_graph()->add_parameter();
  CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
  // Generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  MS_EXCEPTION_IF_NULL(body_block->func_graph());
  // Generate the iterator next apply
  // Process as follows: `app = next(it); target = app[0]; it = app[1];`
  CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  CNodePtr target_app =
    body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");

  CNodePtr iter2_app =
    body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
  WriteAssignVars(body_block, target_node, target_app);

  // Link the variable name with the target
  auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
  MS_EXCEPTION_IF_NULL(iter_param->debug_info());
  MS_EXCEPTION_IF_NULL(iter2_app->debug_info());
  MS_EXCEPTION_IF_NULL(iter_apply->debug_info());
  iter_param->debug_info()->set_trace_info(it_info);
  iter2_app->debug_info()->set_trace_info(it_info);
  iter_apply->debug_info()->set_trace_info(it_info);

  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  MS_EXCEPTION_IF_NULL(after_block);
  after_block->AddPrevBlock(header_block);

  block->Jump(header_block, {iter_apply});
  body_block->Mature();
  header_block->ConditionalJump(cond_apply, body_block, after_block);

  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, iter2_app};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, {iter2_app});
  }

  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, {});
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}
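The pseudo-code in the comment maps directly to the iterator protocol. A small executable Python model of the ParseForIter lowering, where hasnext/next are modelled with a (sequence, position) pair (that representation is an assumption of this sketch, not the graph-level iterator):

    def run_for_iter(xs, body):
        it = (xs, 0)                                  # it = iter(xs)
        while it[1] < len(it[0]):                     # while hasnext(it)
            x, it = it[0][it[1]], (it[0], it[1] + 1)  # x, it = next(it)
            body(x)                                   # body

    run_for_iter([1, 2, 3], print)                    # prints 1, 2, 3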

// A for loop will generate 3 functions: the test, the body, and the continuation.
// for x in xs:
//   body
// It is compiled to the following statement:
//   i = 0
//   while i < len(xs)
//     x = xs[i]
//     i = i + 1
//     body
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);

  // Get variable name of 'x' in statement 'for x in xs'
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");

  // Create statement 'len(xs)'
  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iter_node);
  MS_EXCEPTION_IF_NULL(block->func_graph());
  // Generate node for loop count and convert it to tensor, to make the loop not unroll
  CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});

  CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});

  FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  MS_EXCEPTION_IF_NULL(header_block->func_graph());
  // Create loop variable 'i'
  ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  // Create loop condition 'i < len(xs)'
  auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
  CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});

  // Generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // Create 'x = xs[i]'
  auto body_func_graph = body_block->func_graph();
  MS_EXCEPTION_IF_NULL(body_func_graph);
  CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  WriteAssignVars(body_block, target_node, target_var);
  // Create 'i = i + 1'
  auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
  auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
  auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  auto add_tensor_node =
    body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
  CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
  body_block->WriteVariable(loop_var->name(), loop_var_inc);

  // Link the variable name with the target
  auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  MS_EXCEPTION_IF_NULL(loop_var->debug_info());
  MS_EXCEPTION_IF_NULL(len_iter->debug_info());
  loop_var->debug_info()->set_trace_info(it_info);
  len_iter->debug_info()->set_trace_info(it_info);

  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  MS_EXCEPTION_IF_NULL(after_block);
  after_block->AddPrevBlock(header_block);

  CNodePtr zero_tensor =
    block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
  block->Jump(header_block, {zero_tensor});
  body_block->Mature();

  header_block->ConditionalJump(cond_node, body_block, after_block);

  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, loop_var_inc};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  MS_EXCEPTION_IF_NULL(after_body_block->func_graph());
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, {loop_var_inc});
  }

  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, {});
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}
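Likewise, the ParseForLoop lowering corresponds to an index-based while loop; a minimal executable Python model of the comment's pseudo-code (the conversion of the loop counter to a tensor is omitted here):

    def run_for_loop(xs, body):
        i = 0                    # i = 0
        while i < len(xs):       # while i < len(xs)
            x = xs[i]            # x = xs[i]
            i = i + 1            # i = i + 1
            body(x)              # body

    run_for_loop([4, 5, 6], print)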

// Implement unroll for statement with tuple/getitem.
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
  if (transform_for_half_unroll_call) {
    return ParseForRepeat(block, node);
  }
  return ParseForUnroll(block, node);
}
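The surviving ParseFor only dispatches on MS_DEV_FOR_HALF_UNROLL: full unrolling by default, or repeated calls to a single body sub-graph when half-unroll is requested. A hedged Python sketch of that distinction (function names and the emitted strings are illustrative, not MindSpore behaviour in detail):

    import os

    def lower_for(xs, emit):
        if os.environ.get("MS_DEV_FOR_HALF_UNROLL") == "1":
            # ParseForRepeat: one body sub-graph, called once per element.
            def body_graph(x):
                emit(f"call body_graph({x})")
            for x in xs:
                body_graph(x)
        else:
            # ParseForUnroll: one inlined copy of the body per element (tuple/getitem).
            for i in range(len(xs)):
                emit(f"inline body with xs[{i}]")

    lower_for([7, 8], print)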

// Implement unroll for statement with tuple/getitem.
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
@@ -1942,6 +1694,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
  return after_block;
}

// Implement for statement with repeat calling sub graph.
FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
@@ -126,8 +126,6 @@ class Parser {
  FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
  // Process a for statement
  FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
-  FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
-  FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
  FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node);
  FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node);
  // Process a function def statement
@@ -299,7 +297,6 @@ class Parser {
  }
  // Return a make tuple for input elements list
  AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
-  int64_t GetForTransToWhileLoop();

  // The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
  static FuncGraphWeakPtr top_func_graph_;
@@ -321,7 +318,6 @@ class Parser {
  std::map<std::string, ExprFunc> expr_method_map_;
  // Save current loops to support 'continue', 'break' statement.
  std::stack<Loop> loops_;
-  string max_for_loop_count_str_;
  string support_fallback_;

  // The func graphs to transform tail call ir to independent call ir.
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-import os
import pytest
from mindspore import context
from mindspore import Tensor, nn
@@ -21,8 +20,8 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")

# Although we don't transform for to while any more, we keep this test case.
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -53,7 +52,6 @@ def test_single_for_01():
    y = Tensor([5], mstype.int32)
    z = Tensor([4], mstype.int32)

-    os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = '1'
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    for_net = SingleForNet()
@@ -67,7 +65,6 @@ def test_single_for_01():
    net = GradNet(for_net)
    pynative_forward_res = for_net(x, y, z)
    pynative_backward_res = net(x, y, z)
-    os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = ''

    assert graph_forward_res == pynative_forward_res
    assert graph_backward_res == pynative_backward_res