Lift all fv in if true and false branches.

This commit is contained in:
Zhang Qinghua 2022-03-25 15:08:24 +08:00
parent d4346e26d7
commit 4147bb0ad5
6 changed files with 170 additions and 42 deletions

View File

@ -515,8 +515,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector
// Perform a conditional jump using switch operation.
// The first CNode select graph with condition, and than execute this graph
void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
const AnfNodePtr &false_block_call) {
CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
const AnfNodePtr &false_block_call) {
MS_EXCEPTION_IF_NULL(true_block_call);
MS_EXCEPTION_IF_NULL(false_block_call);
if (func_graph_->get_return() != nullptr) {
@ -527,13 +527,14 @@ void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePt
func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, true_block_call, false_block_call});
CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
func_graph_->set_output(switch_app_new);
return switch_app_new;
}
void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
MS_EXCEPTION_IF_NULL(true_block);
MS_EXCEPTION_IF_NULL(false_block);
ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
return ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
}
// Create cnode for the assign statement like 'self.target = source'.

View File

@ -63,10 +63,10 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
const AnfNodePtr &false_block_call);
void ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block);
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
const AnfNodePtr &false_block_call);
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block);
// Create cnode for the assign statement like self.target = source.
void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source);
void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); }

View File

@ -131,9 +131,9 @@ void Parser::CleanParserResource() {
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
// Check whether the functions referred by this function and itself are missing 'return' statement
auto mng = Manage(fn, false);
auto manager = Manage(fn, false);
MS_EXCEPTION_IF_NULL(ast);
for (const auto &func_graph : mng->func_graphs()) {
for (const auto &func_graph : manager->func_graphs()) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->get_return() != nullptr) {
continue;
@ -151,37 +151,43 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
}
}
std::vector<std::pair<CNodePtr, size_t>> GetFreeVariable(const FuncGraphPtr &func_graph) {
// Considering the performance, we didn't use Manager here.
std::vector<std::pair<CNodePtr, size_t>> free_variables;
std::vector<AnfNodePtr> nodes =
TopoSort(func_graph->get_return(), SuccIncoming, [&func_graph](const AnfNodePtr &node) -> IncludeType {
MS_EXCEPTION_IF_NULL(node);
// Not follow FV's inputs.
if (node->func_graph() != nullptr && node->func_graph() != func_graph) {
return NOFOLLOW;
}
return FOLLOW;
});
for (auto &node : nodes) {
// Only check Non-FV CNode.
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->func_graph() != func_graph) {
continue;
}
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
auto &input = cnode->input(i);
if (input->func_graph() != nullptr && input->func_graph() != func_graph) {
(void)free_variables.emplace_back(std::make_pair(cnode, i));
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2);
}
}
}
return free_variables;
}
void Parser::LiftRolledBodyGraphFV() {
for (auto &rolled_call_pair : rolled_body_calls_) {
auto rolled_call_cnode = rolled_call_pair.first;
auto rolled_graph = rolled_call_pair.second->func_graph();
MS_EXCEPTION_IF_NULL(rolled_graph);
std::vector<std::pair<CNodePtr, size_t>> free_variables;
std::vector<AnfNodePtr> nodes =
TopoSort(rolled_graph->get_return(), SuccIncoming, [&rolled_graph](const AnfNodePtr &node) -> IncludeType {
MS_EXCEPTION_IF_NULL(node);
// Not follow FV's inputs.
if (node->func_graph() != nullptr && node->func_graph() != rolled_graph) {
return NOFOLLOW;
}
return FOLLOW;
});
for (auto &node : nodes) {
// Only check Non-FV CNode.
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->func_graph() != rolled_graph) {
continue;
}
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
auto &input = cnode->input(i);
if (input->func_graph() != nullptr && input->func_graph() != rolled_graph) {
(void)free_variables.emplace_back(std::pair(cnode, i));
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2);
}
}
}
const auto &free_variables = GetFreeVariable(rolled_graph);
for (auto &free_node_pair : free_variables) {
auto &cnode = free_node_pair.first;
auto index = free_node_pair.second;
@ -197,6 +203,46 @@ void Parser::LiftRolledBodyGraphFV() {
}
}
void Parser::LiftIfBranchGraphFV() {
for (auto &branch_call_tuple : if_branch_calls_) {
auto call_cnode = std::get<0>(branch_call_tuple);
auto true_branch_graph = std::get<1>(branch_call_tuple)->func_graph();
auto false_branch_graph = std::get<2>(branch_call_tuple)->func_graph();
const auto &true_free_variables = GetFreeVariable(true_branch_graph);
const auto &false_free_variables = GetFreeVariable(false_branch_graph);
// Handle true branch.
for (auto &free_node_pair : true_free_variables) {
auto &cnode = free_node_pair.first;
auto index = free_node_pair.second;
// Move the free variable to parent.
auto &free_node = cnode->input(index);
call_cnode->add_input(free_node);
// Change the free variable to the parameter.
auto parameter = true_branch_graph->add_parameter();
cnode->set_input(index, parameter);
// Add a unused parameter in other branch.
(void)false_branch_graph->add_parameter();
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "True branch, change FV: " << cnode->DebugString(recur_2);
}
// Handle false branch.
for (auto &free_node_pair : false_free_variables) {
auto &cnode = free_node_pair.first;
auto index = free_node_pair.second;
// Move the free variable to parent.
auto &free_node = cnode->input(index);
call_cnode->add_input(free_node);
// Change the free variable to the parameter.
auto parameter = false_branch_graph->add_parameter();
cnode->set_input(index, parameter);
// Add a unused parameter in other branch.
(void)true_branch_graph->add_parameter();
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "False branch, change FV: " << cnode->DebugString(recur_2);
}
}
}
namespace {
void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
@ -322,6 +368,8 @@ void Parser::TransformParallelCall() {
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
}
// Lift inner, then lift outer.
LiftIfBranchGraphFV();
LiftRolledBodyGraphFV();
}
@ -1600,7 +1648,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
<< ", false_end: " << false_end->ToString();
}
block->ConditionalJump(bool_node, true_block, false_block);
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
// Record the former, middle, latter graphs info.
if (true_end->func_graph()->get_return() != nullptr || false_end->func_graph()->get_return() != nullptr) {
@ -1622,6 +1670,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
<< ", middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
}
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
if (transform_for_half_unroll_call) {
// Lift the if branches in for statement.
if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
}
if (after_block->prev_blocks().empty()) {
after_block->SetAsDeadBlock();
}
@ -1887,7 +1941,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
<< ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
// Record the rolled body function, for later lifting operation.
if (rolled_body_call != nullptr) {
(void)rolled_body_calls_.emplace_back(std::pair(rolled_body_call, rolled_body_block));
(void)rolled_body_calls_.emplace_back(std::make_pair(rolled_body_call, rolled_body_block));
constexpr int recursive_level = 2;
MS_LOG(DEBUG) << "Record rolled body call: {CNode: " << rolled_body_call->DebugString(recursive_level)
<< ", rolled_graph: " << rolled_body_block->ToString() << "}";
@ -2509,13 +2563,13 @@ void Parser::RemoveUnnecessaryPhis() {
if (removable_phis.empty()) {
return;
}
auto mng = Manage(func_graph_, false);
auto manager = Manage(func_graph_, false);
// Replace the nodes
// Remove from inside to outside
for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
auto phi = phis[LongToSize(idx)];
auto new_node = FindPhis(removable_phis, phi);
mng->Replace(phi, new_node);
manager->Replace(phi, new_node);
}
// Remove the parameter
for (FunctionBlockPtr &block : func_block_list_) {

View File

@ -21,6 +21,7 @@
#include <limits>
#include <utility>
#include <tuple>
#include <vector>
#include <string>
#include <map>
@ -208,6 +209,7 @@ class Parser {
// Transform tail call to parallel call.
void TransformParallelCall();
void LiftRolledBodyGraphFV();
void LiftIfBranchGraphFV();
// If Tensor is present as type, not Tensor(xxx), should not make InterpretNode.
bool IsTensorType(const AnfNodePtr &node, const std::string &script_text) const;
@ -332,6 +334,8 @@ class Parser {
// The func graphs to transform tail call ir to independent call ir.
// Contains: {former_graph, middle_graph}, latter_graph is no need.
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> parallel_call_graphs_;
// The true branch and false branch call info. of if statement.
std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_;
// The rolled_body callers info. for later lifting operation.
std::vector<std::pair<CNodePtr, FunctionBlockPtr>> rolled_body_calls_;
// Add exception for if parallel transform.

View File

@ -1042,7 +1042,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto possible_parent_fg = out_conf->node()->func_graph();
for (auto eval : evaluators) {
MS_EXCEPTION_IF_NULL(eval);
(void)SetUndeterminedFlag(eval, possible_parent_fg);
SetUndeterminedFlag(eval, possible_parent_fg);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

View File

@ -0,0 +1,69 @@
import os
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_half_unroll_basic():
"""
Feature: Half unroll compile optimization for for statement.
Description: Only test for statement.
Expectation: Correct result and no exception.
"""
class ForLoopBasic(Cell):
def __init__(self):
super().__init__()
self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32)))
def construct(self, x):
output = x
for i in self.array:
output += i
return output
net = ForLoopBasic()
x = Tensor(np.array(10).astype(np.int32))
os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1'
res = net(x)
os.environ['MS_DEV_FOR_HALF_UNROLL'] = ''
assert res == 25
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_half_unroll_if():
"""
Feature: Half unroll compile optimization for for statement.
Description: Test for-in statements.
Expectation: Correct result and no exception.
"""
class ForLoopIf(Cell):
def __init__(self):
super().__init__()
self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32)))
def construct(self, x):
output = x
for i in self.array:
if i < 10:
output += i
return output
net = ForLoopIf()
x = Tensor(np.array(10).astype(np.int32))
os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1'
res = net(x)
os.environ['MS_DEV_FOR_HALF_UNROLL'] = ''
assert res == 15