From 5d15cc9e226589b3c42b97cad7ded1884474c1b8 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Tue, 22 Mar 2022 11:41:03 +0800 Subject: [PATCH] Remove original ParseFor implementation. --- mindspore/ccsrc/pipeline/jit/action.cc | 2 +- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 263 +------------------- mindspore/ccsrc/pipeline/jit/parse/parse.h | 4 - tests/st/control/test_for_to_while.py | 5 +- 4 files changed, 10 insertions(+), 264 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 40b3cc8b3dd..0e3aec612ac 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -533,7 +533,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) { } bool EliminateUnusedParameterAction(const ResourcePtr &res) { - static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1"); + static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1"); if (!transform_tail_call_to_parallel_call) { return true; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index f8ec9375d8c..629fea338b3 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -68,7 +68,6 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { - max_for_loop_count_str_ = common::GetEnv("MS_DEV_FOR_TO_WHILE_LOOP"); support_fallback_ = common::GetEnv("MS_DEV_ENABLE_FALLBACK"); errcode_ = PARSE_SUCCESS; BuildMethodMap(); @@ -80,7 +79,7 @@ void Parser::BuildMethodMap() { stmt_method_map_["If"] = &Parser::ParseIf; stmt_method_map_["Assign"] = &Parser::ParseAssign; stmt_method_map_["While"] = &Parser::ParseWhile; - stmt_method_map_["For"] = &Parser::ParseForUnroll; + stmt_method_map_["For"] = &Parser::ParseFor; stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; stmt_method_map_["Global"] = &Parser::ParseGlobal; @@ -1514,7 +1513,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object << ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString(); (void)ignored_if_latter_call_graphs_.insert(after_block); } - static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1"); + static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1"); if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr && false_branch_graphs.second != nullptr) { true_branch_graphs.first = block; @@ -1610,263 +1609,16 @@ FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) { return block; } -int64_t Parser::GetForTransToWhileLoop() { - // int64 support 63bits positive num mostly. - constexpr auto max_num_length = 10; - if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) { - return MAX_FOR_LOOP_COUNT; - } - if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(), - [](char c) { return c < '0' || c > '9'; })) { - return MAX_FOR_LOOP_COUNT; - } - int64_t loop_count; - std::stringstream ss; - ss << max_for_loop_count_str_; - ss >> loop_count; - return loop_count; -} - -// A for loop will generate 3 functions :the test, the body, and the continuation -// for x in xs: -// body -// It is compiled to be following statement -// if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop -// else, ParseForLoop. Use loop var to implement for loop, which always sink loop FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For, create an if else statement"; - MS_EXCEPTION_IF_NULL(block); - // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT' - AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); - py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); - AnfNodePtr iter_node = ParseExprNode(block, iter_obj); - MS_EXCEPTION_IF_NULL(block->func_graph()); - CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); - CNodePtr bool_node = block->func_graph()->NewCNodeInOrder( - {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())}); - - // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' - FunctionBlockPtr true_block = nullptr; - FunctionBlockPtr false_block = nullptr; - { - TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); - true_block = MakeFunctionBlock(*this); - } - { - TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); - false_block = MakeFunctionBlock(*this); - } - - MakeConditionBlocks(block, true_block, false_block); - - FunctionBlockPtr after_block = nullptr; - { - TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); - after_block = MakeFunctionBlock(*this); - } - - FunctionBlockPtr true_end = ParseForIter(true_block, node); - true_end->Jump(after_block, {}); - - FunctionBlockPtr false_end = ParseForLoop(false_block, node); - false_end->Jump(after_block, {}); - - block->ConditionalJump(bool_node, true_block, false_block); - after_block->Mature(); - return after_block; -} - -// A for loop will generate 3 functions: the test, the body, and the continuation. -// for x in xs: -// body -// It is compiled to be following statement: -// it = iter(xs) -// while hastnext(it) -// x, it = next(it) -// body -FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); - AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); - // Generate the iterator apply - CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter); - MS_EXCEPTION_IF_NULL(iter_apply); - FunctionBlockPtr header_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(header_block); - MS_EXCEPTION_IF_NULL(header_block->func_graph()); - // Generate the hasnext apply which is a condition - ParameterPtr iter_param = header_block->func_graph()->add_parameter(); - CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext); - // Generate the body of the for statement - FunctionBlockPtr body_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(body_block); - body_block->AddPrevBlock(header_block); - MS_EXCEPTION_IF_NULL(body_block->func_graph()); - // Generate the iterator next apply - // Process as following: `app = next(it); target = app[0]; it = app[1];` - CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param}); - CNodePtr target_app = - body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast(0))}); - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - - CNodePtr iter2_app = - body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast(1))}); - WriteAssignVars(body_block, target_node, target_app); - - // Link the variable name with the target - auto it_info = std::make_shared(target_app->debug_info()); - MS_EXCEPTION_IF_NULL(iter_param->debug_info()); - MS_EXCEPTION_IF_NULL(iter2_app->debug_info()); - MS_EXCEPTION_IF_NULL(iter_apply->debug_info()); - iter_param->debug_info()->set_trace_info(it_info); - iter2_app->debug_info()->set_trace_info(it_info); - iter_apply->debug_info()->set_trace_info(it_info); - - FunctionBlockPtr after_block = nullptr; - { - TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); - after_block = MakeFunctionBlock(*this); - } - MS_EXCEPTION_IF_NULL(after_block); - after_block->AddPrevBlock(header_block); - - block->Jump(header_block, {iter_apply}); - body_block->Mature(); - header_block->ConditionalJump(cond_apply, body_block, after_block); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, iter2_app}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); - MS_EXCEPTION_IF_NULL(after_body_block->func_graph()); - if (after_body_block->func_graph()->get_return() == nullptr) { - after_body_block->Jump(header_block, {iter2_app}); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, {}); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. - return after_block; -} - -// A for loop will generate 3 functions: the test, the body, and the continuation. -// for x in xs: -// body -// It is compiled to be following statement: -// i = 0 -// while i < len(xs) -// x = xs[i] -// i = i + 1 -// body -FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For by loop variable"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - - // Get variable name of 'x' in statement 'for x in xs' - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - - // Create statement 'len(xs)' - py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); - AnfNodePtr iter_node = ParseExprNode(block, iter_obj); - MS_EXCEPTION_IF_NULL(iter_node); - MS_EXCEPTION_IF_NULL(block->func_graph()); - // Generate node for loop count and convert it to tensor, to make the loop not unroll - CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); - auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); - auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)}); - - CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len}); - - FunctionBlockPtr header_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(header_block); - MS_EXCEPTION_IF_NULL(header_block->func_graph()); - // Create loop variable 'i' - ParameterPtr loop_var = header_block->func_graph()->add_parameter(); - // Create loop condition 'i < len(xs)' - auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); - auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)}); - CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter}); - - // Generate the body of the for statement - FunctionBlockPtr body_block = GenerateBlock(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(body_block); - body_block->AddPrevBlock(header_block); - // Create 'x = xs[i]' - auto body_func_graph = body_block->func_graph(); - MS_EXCEPTION_IF_NULL(body_func_graph); - CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var}); - WriteAssignVars(body_block, target_node, target_var); - // Create 'i = i + 1' - auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations"); - auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)}); - auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)}); - auto add_tensor_node = - body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast(1))}); - CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node}); - body_block->WriteVariable(loop_var->name(), loop_var_inc); - - // Link the variable name with the target - auto it_info = std::make_shared(loop_var_inc->debug_info()); - MS_EXCEPTION_IF_NULL(loop_var->debug_info()); - MS_EXCEPTION_IF_NULL(len_iter->debug_info()); - loop_var->debug_info()->set_trace_info(it_info); - len_iter->debug_info()->set_trace_info(it_info); - - FunctionBlockPtr after_block = nullptr; - { - TraceGuard guard(std::make_shared(block->func_graph()->debug_info())); - after_block = MakeFunctionBlock(*this); - } - MS_EXCEPTION_IF_NULL(after_block); - after_block->AddPrevBlock(header_block); - - CNodePtr zero_tensor = - block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast(0))}); - block->Jump(header_block, {zero_tensor}); - body_block->Mature(); - - header_block->ConditionalJump(cond_node, body_block, after_block); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, loop_var_inc}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); - MS_EXCEPTION_IF_NULL(after_body_block->func_graph()); - if (after_body_block->func_graph()->get_return() == nullptr) { - after_body_block->Jump(header_block, {loop_var_inc}); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, {}); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. - return after_block; -} - -// Implement unroll for statement with tuple/getitem. -FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) { static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1"); if (transform_for_half_unroll_call) { return ParseForRepeat(block, node); } + return ParseForUnroll(block, node); +} + +// Implement unroll for statement with tuple/getitem. +FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For by loop variable"; MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); @@ -1943,6 +1695,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py: return after_block; } +// Implement for statement with repeat calling sub graph. FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast For by loop variable"; MS_EXCEPTION_IF_NULL(block); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 04426557b11..1549e140b69 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -126,8 +126,6 @@ class Parser { FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); // Process a for statement FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); - FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node); - FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node); FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node); FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node); // Process a function def statement @@ -299,7 +297,6 @@ class Parser { } // Return a make tuple for input elements list AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes); - int64_t GetForTransToWhileLoop(); // The shared_ptr will be hold by GraphManager, so just hold a weak ref here. static FuncGraphWeakPtr top_func_graph_; @@ -321,7 +318,6 @@ class Parser { std::map expr_method_map_; // Save current loops to support 'continue', 'break' statement. std::stack loops_; - string max_for_loop_count_str_; string support_fallback_; // The func graphs to transform tail call ir to independent call ir. diff --git a/tests/st/control/test_for_to_while.py b/tests/st/control/test_for_to_while.py index feb527e1065..9ad7dbaa580 100644 --- a/tests/st/control/test_for_to_while.py +++ b/tests/st/control/test_for_to_while.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import os import pytest from mindspore import context from mindspore import Tensor, nn @@ -21,8 +20,8 @@ from mindspore.ops import operations as P from mindspore.common import dtype as mstype grad_all = C.GradOperation(get_all=True) -context.set_context(device_target="Ascend") +# Although we don't transform for to while any more, we keep this test case. @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -53,7 +52,6 @@ def test_single_for_01(): y = Tensor([5], mstype.int32) z = Tensor([4], mstype.int32) - os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = '1' # graph mode context.set_context(mode=context.GRAPH_MODE) for_net = SingleForNet() @@ -67,7 +65,6 @@ def test_single_for_01(): net = GradNet(for_net) pynative_forward_res = for_net(x, y, z) pynative_backward_res = net(x, y, z) - os.environ['MS_DEV_FOR_TO_WHILE_LOOP'] = '' assert graph_forward_res == pynative_forward_res assert graph_backward_res == pynative_backward_res