diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index af306effe0c..1f68eefb5c4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -515,8 +515,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector // Perform a conditional jump using switch operation. // The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call, - const AnfNodePtr &false_block_call) { +CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call, + const AnfNodePtr &false_block_call) { MS_EXCEPTION_IF_NULL(true_block_call); MS_EXCEPTION_IF_NULL(false_block_call); if (func_graph_->get_return() != nullptr) { @@ -527,13 +527,14 @@ void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePt func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, true_block_call, false_block_call}); CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app}); func_graph_->set_output(switch_app_new); + return switch_app_new; } -void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block, - const FunctionBlockPtr &false_block) { +CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block) { MS_EXCEPTION_IF_NULL(true_block); MS_EXCEPTION_IF_NULL(false_block); - ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())); + return ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph())); } // Create cnode for the assign statement like 'self.target = source'. 
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h index d19139b637f..1826333b07f 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -63,10 +63,10 @@ class FunctionBlock : public std::enable_shared_from_this { CNodePtr ForceToWhileCond(const AnfNodePtr &cond); void Jump(const FunctionBlockPtr &block, const std::vector &args); AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); - void ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call, - const AnfNodePtr &false_block_call); - void ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block, - const FunctionBlockPtr &false_block); + CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call, + const AnfNodePtr &false_block_call); + CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block); // Create cnode for the assign statement like self.target = source. 
void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source); void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 23ddaa0c53b..3f99d1e7de5 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -131,9 +131,9 @@ void Parser::CleanParserResource() { void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr &ast) { // Check whether the functions referred by this function and itself are missing 'return' statement - auto mng = Manage(fn, false); + auto manager = Manage(fn, false); MS_EXCEPTION_IF_NULL(ast); - for (const auto &func_graph : mng->func_graphs()) { + for (const auto &func_graph : manager->func_graphs()) { MS_EXCEPTION_IF_NULL(func_graph); if (func_graph->get_return() != nullptr) { continue; @@ -151,37 +151,43 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr> GetFreeVariable(const FuncGraphPtr &func_graph) { + // Considering the performance, we didn't use Manager here. + std::vector> free_variables; + std::vector nodes = + TopoSort(func_graph->get_return(), SuccIncoming, [&func_graph](const AnfNodePtr &node) -> IncludeType { + MS_EXCEPTION_IF_NULL(node); + // Not follow FV's inputs. + if (node->func_graph() != nullptr && node->func_graph() != func_graph) { + return NOFOLLOW; + } + return FOLLOW; + }); + for (auto &node : nodes) { + // Only check Non-FV CNode. 
+ auto cnode = dyn_cast(node); + if (cnode == nullptr || cnode->func_graph() != func_graph) { + continue; + } + + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + auto &input = cnode->input(i); + if (input->func_graph() != nullptr && input->func_graph() != func_graph) { + (void)free_variables.emplace_back(std::make_pair(cnode, i)); + constexpr auto recur_2 = 2; + MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2); + } + } + } + return free_variables; +} + void Parser::LiftRolledBodyGraphFV() { for (auto &rolled_call_pair : rolled_body_calls_) { auto rolled_call_cnode = rolled_call_pair.first; auto rolled_graph = rolled_call_pair.second->func_graph(); MS_EXCEPTION_IF_NULL(rolled_graph); - std::vector> free_variables; - std::vector nodes = - TopoSort(rolled_graph->get_return(), SuccIncoming, [&rolled_graph](const AnfNodePtr &node) -> IncludeType { - MS_EXCEPTION_IF_NULL(node); - // Not follow FV's inputs. - if (node->func_graph() != nullptr && node->func_graph() != rolled_graph) { - return NOFOLLOW; - } - return FOLLOW; - }); - for (auto &node : nodes) { - // Only check Non-FV CNode. 
- auto cnode = dyn_cast(node); - if (cnode == nullptr || cnode->func_graph() != rolled_graph) { - continue; - } - - for (size_t i = 0; i < cnode->inputs().size(); ++i) { - auto &input = cnode->input(i); - if (input->func_graph() != nullptr && input->func_graph() != rolled_graph) { - (void)free_variables.emplace_back(std::pair(cnode, i)); - constexpr auto recur_2 = 2; - MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2); - } - } - } + const auto &free_variables = GetFreeVariable(rolled_graph); for (auto &free_node_pair : free_variables) { auto &cnode = free_node_pair.first; auto index = free_node_pair.second; @@ -197,6 +203,46 @@ void Parser::LiftRolledBodyGraphFV() { } } +void Parser::LiftIfBranchGraphFV() { + for (auto &branch_call_tuple : if_branch_calls_) { + auto call_cnode = std::get<0>(branch_call_tuple); + auto true_branch_graph = std::get<1>(branch_call_tuple)->func_graph(); + auto false_branch_graph = std::get<2>(branch_call_tuple)->func_graph(); + const auto &true_free_variables = GetFreeVariable(true_branch_graph); + const auto &false_free_variables = GetFreeVariable(false_branch_graph); + // Handle true branch. + for (auto &free_node_pair : true_free_variables) { + auto &cnode = free_node_pair.first; + auto index = free_node_pair.second; + // Move the free variable to parent. + auto &free_node = cnode->input(index); + call_cnode->add_input(free_node); + // Change the free variable to the parameter. + auto parameter = true_branch_graph->add_parameter(); + cnode->set_input(index, parameter); + // Add an unused parameter in other branch. + (void)false_branch_graph->add_parameter(); + constexpr auto recur_2 = 2; + MS_LOG(DEBUG) << "True branch, change FV: " << cnode->DebugString(recur_2); + } + // Handle false branch. + for (auto &free_node_pair : false_free_variables) { + auto &cnode = free_node_pair.first; + auto index = free_node_pair.second; + // Move the free variable to parent. 
+ auto &free_node = cnode->input(index); + call_cnode->add_input(free_node); + // Change the free variable to the parameter. + auto parameter = false_branch_graph->add_parameter(); + cnode->set_input(index, parameter); + // Add an unused parameter in other branch. + (void)true_branch_graph->add_parameter(); + constexpr auto recur_2 = 2; + MS_LOG(DEBUG) << "False branch, change FV: " << cnode->DebugString(recur_2); + } + } +} + namespace { void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph, size_t middle_graph_output_cnode_size, bool use_arguments_pack) { @@ -322,6 +368,8 @@ void Parser::TransformParallelCall() { << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}"; } + // Lift inner, then lift outer. + LiftIfBranchGraphFV(); LiftRolledBodyGraphFV(); }
+ if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block)); + } + if (after_block->prev_blocks().empty()) { after_block->SetAsDeadBlock(); } @@ -1887,7 +1941,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py: << ", middle: " << loop_graphs.second->func_graph()->ToString() << "}"; // Record the rolled body function, for later lifting operation. if (rolled_body_call != nullptr) { - (void)rolled_body_calls_.emplace_back(std::pair(rolled_body_call, rolled_body_block)); + (void)rolled_body_calls_.emplace_back(std::make_pair(rolled_body_call, rolled_body_block)); constexpr int recursive_level = 2; MS_LOG(DEBUG) << "Record rolled body call: {CNode: " << rolled_body_call->DebugString(recursive_level) << ", rolled_graph: " << rolled_body_block->ToString() << "}"; @@ -2509,13 +2563,13 @@ void Parser::RemoveUnnecessaryPhis() { if (removable_phis.empty()) { return; } - auto mng = Manage(func_graph_, false); + auto manager = Manage(func_graph_, false); // Replace the nodes // Remove from inside to outside for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) { auto phi = phis[LongToSize(idx)]; auto new_node = FindPhis(removable_phis, phi); - mng->Replace(phi, new_node); + manager->Replace(phi, new_node); } // Remove the parameter for (FunctionBlockPtr &block : func_block_list_) { diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 14568e5feaf..dae433dc906 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -208,6 +209,7 @@ class Parser { // Transform tail call to parallel call. void TransformParallelCall(); void LiftRolledBodyGraphFV(); + void LiftIfBranchGraphFV(); // If Tensor is present as type, not Tensor(xxx), should not make InterpretNode. 
bool IsTensorType(const AnfNodePtr &node, const std::string &script_text) const; @@ -332,6 +334,8 @@ class Parser { // The func graphs to transform tail call ir to independent call ir. // Contains: {former_graph, middle_graph}, latter_graph is no need. std::vector> parallel_call_graphs_; + // The true branch and false branch call info. of if statement. + std::vector> if_branch_calls_; // The rolled_body callers info. for later lifting operation. std::vector> rolled_body_calls_; // Add exception for if parallel transform. diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 2d87f256e5b..06822a11eb7 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -1042,7 +1042,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectornode()->func_graph(); for (auto eval : evaluators) { MS_EXCEPTION_IF_NULL(eval); - (void)SetUndeterminedFlag(eval, possible_parent_fg); + SetUndeterminedFlag(eval, possible_parent_fg); const auto current_inf = EvaluatorArgs(eval, args_spec_list); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. 
diff --git a/tests/st/control/test_for_half_unroll.py b/tests/st/control/test_for_half_unroll.py new file mode 100644 index 00000000000..a53bd7b4d8d --- /dev/null +++ b/tests/st/control/test_for_half_unroll.py @@ -0,0 +1,69 @@ +import os +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell + +context.set_context(mode=context.GRAPH_MODE) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_for_half_unroll_basic(): + """ + Feature: Half unroll compile optimization for for statement. + Description: Only test for statement. + Expectation: Correct result and no exception. + """ + class ForLoopBasic(Cell): + def __init__(self): + super().__init__() + self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32))) + + def construct(self, x): + output = x + for i in self.array: + output += i + + return output + + net = ForLoopBasic() + x = Tensor(np.array(10).astype(np.int32)) + os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1' + res = net(x) + os.environ['MS_DEV_FOR_HALF_UNROLL'] = '' + assert res == 25 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_for_half_unroll_if(): + """ + Feature: Half unroll compile optimization for for statement. + Description: Test for-in statements. + Expectation: Correct result and no exception. + """ + class ForLoopIf(Cell): + def __init__(self): + super().__init__() + self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32))) + + def construct(self, x): + output = x + for i in self.array: + if i < 10: + output += i + + return output + + net = ForLoopIf() + x = Tensor(np.array(10).astype(np.int32)) + os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1' + res = net(x) + os.environ['MS_DEV_FOR_HALF_UNROLL'] = '' + assert res == 15