forked from mindspore-Ecosystem/mindspore
Lift all fv in if true and false branches.
This commit is contained in:
parent
d4346e26d7
commit
4147bb0ad5
|
@ -515,7 +515,7 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector
|
|||
|
||||
// Perform a conditional jump using switch operation.
|
||||
// The first CNode select graph with condition, and than execute this graph
|
||||
void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
|
||||
CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
|
||||
const AnfNodePtr &false_block_call) {
|
||||
MS_EXCEPTION_IF_NULL(true_block_call);
|
||||
MS_EXCEPTION_IF_NULL(false_block_call);
|
||||
|
@ -527,13 +527,14 @@ void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePt
|
|||
func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, true_block_call, false_block_call});
|
||||
CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
|
||||
func_graph_->set_output(switch_app_new);
|
||||
return switch_app_new;
|
||||
}
|
||||
|
||||
void FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
|
||||
CNodePtr FunctionBlock::ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) {
|
||||
MS_EXCEPTION_IF_NULL(true_block);
|
||||
MS_EXCEPTION_IF_NULL(false_block);
|
||||
ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
|
||||
return ConditionalJump(cond_node, NewValueNode(true_block->func_graph()), NewValueNode(false_block->func_graph()));
|
||||
}
|
||||
|
||||
// Create cnode for the assign statement like 'self.target = source'.
|
||||
|
|
|
@ -63,9 +63,9 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
|
||||
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
|
||||
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
|
||||
void ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
|
||||
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const AnfNodePtr &true_block_call,
|
||||
const AnfNodePtr &false_block_call);
|
||||
void ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
|
||||
CNodePtr ConditionalJump(const AnfNodePtr &cond_node, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block);
|
||||
// Create cnode for the assign statement like self.target = source.
|
||||
void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source);
|
||||
|
|
|
@ -131,9 +131,9 @@ void Parser::CleanParserResource() {
|
|||
|
||||
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
|
||||
// Check whether the functions referred by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(fn, false);
|
||||
auto manager = Manage(fn, false);
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
for (const auto &func_graph : mng->func_graphs()) {
|
||||
for (const auto &func_graph : manager->func_graphs()) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->get_return() != nullptr) {
|
||||
continue;
|
||||
|
@ -151,17 +151,14 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
|
|||
}
|
||||
}
|
||||
|
||||
void Parser::LiftRolledBodyGraphFV() {
|
||||
for (auto &rolled_call_pair : rolled_body_calls_) {
|
||||
auto rolled_call_cnode = rolled_call_pair.first;
|
||||
auto rolled_graph = rolled_call_pair.second->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(rolled_graph);
|
||||
std::vector<std::pair<CNodePtr, size_t>> GetFreeVariable(const FuncGraphPtr &func_graph) {
|
||||
// Considering the performance, we didn't use Manager here.
|
||||
std::vector<std::pair<CNodePtr, size_t>> free_variables;
|
||||
std::vector<AnfNodePtr> nodes =
|
||||
TopoSort(rolled_graph->get_return(), SuccIncoming, [&rolled_graph](const AnfNodePtr &node) -> IncludeType {
|
||||
TopoSort(func_graph->get_return(), SuccIncoming, [&func_graph](const AnfNodePtr &node) -> IncludeType {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Not follow FV's inputs.
|
||||
if (node->func_graph() != nullptr && node->func_graph() != rolled_graph) {
|
||||
if (node->func_graph() != nullptr && node->func_graph() != func_graph) {
|
||||
return NOFOLLOW;
|
||||
}
|
||||
return FOLLOW;
|
||||
|
@ -169,19 +166,28 @@ void Parser::LiftRolledBodyGraphFV() {
|
|||
for (auto &node : nodes) {
|
||||
// Only check Non-FV CNode.
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr || cnode->func_graph() != rolled_graph) {
|
||||
if (cnode == nullptr || cnode->func_graph() != func_graph) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
|
||||
auto &input = cnode->input(i);
|
||||
if (input->func_graph() != nullptr && input->func_graph() != rolled_graph) {
|
||||
(void)free_variables.emplace_back(std::pair(cnode, i));
|
||||
if (input->func_graph() != nullptr && input->func_graph() != func_graph) {
|
||||
(void)free_variables.emplace_back(std::make_pair(cnode, i));
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
return free_variables;
|
||||
}
|
||||
|
||||
void Parser::LiftRolledBodyGraphFV() {
|
||||
for (auto &rolled_call_pair : rolled_body_calls_) {
|
||||
auto rolled_call_cnode = rolled_call_pair.first;
|
||||
auto rolled_graph = rolled_call_pair.second->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(rolled_graph);
|
||||
const auto &free_variables = GetFreeVariable(rolled_graph);
|
||||
for (auto &free_node_pair : free_variables) {
|
||||
auto &cnode = free_node_pair.first;
|
||||
auto index = free_node_pair.second;
|
||||
|
@ -197,6 +203,46 @@ void Parser::LiftRolledBodyGraphFV() {
|
|||
}
|
||||
}
|
||||
|
||||
void Parser::LiftIfBranchGraphFV() {
|
||||
for (auto &branch_call_tuple : if_branch_calls_) {
|
||||
auto call_cnode = std::get<0>(branch_call_tuple);
|
||||
auto true_branch_graph = std::get<1>(branch_call_tuple)->func_graph();
|
||||
auto false_branch_graph = std::get<2>(branch_call_tuple)->func_graph();
|
||||
const auto &true_free_variables = GetFreeVariable(true_branch_graph);
|
||||
const auto &false_free_variables = GetFreeVariable(false_branch_graph);
|
||||
// Handle true branch.
|
||||
for (auto &free_node_pair : true_free_variables) {
|
||||
auto &cnode = free_node_pair.first;
|
||||
auto index = free_node_pair.second;
|
||||
// Move the free variable to parent.
|
||||
auto &free_node = cnode->input(index);
|
||||
call_cnode->add_input(free_node);
|
||||
// Change the free variable to the parameter.
|
||||
auto parameter = true_branch_graph->add_parameter();
|
||||
cnode->set_input(index, parameter);
|
||||
// Add a unused parameter in other branch.
|
||||
(void)false_branch_graph->add_parameter();
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(DEBUG) << "True branch, change FV: " << cnode->DebugString(recur_2);
|
||||
}
|
||||
// Handle false branch.
|
||||
for (auto &free_node_pair : false_free_variables) {
|
||||
auto &cnode = free_node_pair.first;
|
||||
auto index = free_node_pair.second;
|
||||
// Move the free variable to parent.
|
||||
auto &free_node = cnode->input(index);
|
||||
call_cnode->add_input(free_node);
|
||||
// Change the free variable to the parameter.
|
||||
auto parameter = false_branch_graph->add_parameter();
|
||||
cnode->set_input(index, parameter);
|
||||
// Add a unused parameter in other branch.
|
||||
(void)true_branch_graph->add_parameter();
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(DEBUG) << "False branch, change FV: " << cnode->DebugString(recur_2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
|
||||
size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
|
||||
|
@ -322,6 +368,8 @@ void Parser::TransformParallelCall() {
|
|||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
}
|
||||
|
||||
// Lift inner, then lift outer.
|
||||
LiftIfBranchGraphFV();
|
||||
LiftRolledBodyGraphFV();
|
||||
}
|
||||
|
||||
|
@ -1600,7 +1648,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
|
||||
<< ", false_end: " << false_end->ToString();
|
||||
}
|
||||
block->ConditionalJump(bool_node, true_block, false_block);
|
||||
auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
|
||||
|
||||
// Record the former, middle, latter graphs info.
|
||||
if (true_end->func_graph()->get_return() != nullptr || false_end->func_graph()->get_return() != nullptr) {
|
||||
|
@ -1622,6 +1670,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
<< ", middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
|
||||
}
|
||||
|
||||
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
|
||||
if (transform_for_half_unroll_call) {
|
||||
// Lift the if branches in for statement.
|
||||
if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
|
||||
}
|
||||
|
||||
if (after_block->prev_blocks().empty()) {
|
||||
after_block->SetAsDeadBlock();
|
||||
}
|
||||
|
@ -1887,7 +1941,7 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
<< ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
|
||||
// Record the rolled body function, for later lifting operation.
|
||||
if (rolled_body_call != nullptr) {
|
||||
(void)rolled_body_calls_.emplace_back(std::pair(rolled_body_call, rolled_body_block));
|
||||
(void)rolled_body_calls_.emplace_back(std::make_pair(rolled_body_call, rolled_body_block));
|
||||
constexpr int recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "Record rolled body call: {CNode: " << rolled_body_call->DebugString(recursive_level)
|
||||
<< ", rolled_graph: " << rolled_body_block->ToString() << "}";
|
||||
|
@ -2509,13 +2563,13 @@ void Parser::RemoveUnnecessaryPhis() {
|
|||
if (removable_phis.empty()) {
|
||||
return;
|
||||
}
|
||||
auto mng = Manage(func_graph_, false);
|
||||
auto manager = Manage(func_graph_, false);
|
||||
// Replace the nodes
|
||||
// Remove from inside to outside
|
||||
for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
|
||||
auto phi = phis[LongToSize(idx)];
|
||||
auto new_node = FindPhis(removable_phis, phi);
|
||||
mng->Replace(phi, new_node);
|
||||
manager->Replace(phi, new_node);
|
||||
}
|
||||
// Remove the parameter
|
||||
for (FunctionBlockPtr &block : func_block_list_) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
@ -208,6 +209,7 @@ class Parser {
|
|||
// Transform tail call to parallel call.
|
||||
void TransformParallelCall();
|
||||
void LiftRolledBodyGraphFV();
|
||||
void LiftIfBranchGraphFV();
|
||||
|
||||
// If Tensor is present as type, not Tensor(xxx), should not make InterpretNode.
|
||||
bool IsTensorType(const AnfNodePtr &node, const std::string &script_text) const;
|
||||
|
@ -332,6 +334,8 @@ class Parser {
|
|||
// The func graphs to transform tail call ir to independent call ir.
|
||||
// Contains: {former_graph, middle_graph}, latter_graph is no need.
|
||||
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> parallel_call_graphs_;
|
||||
// The true branch and false branch call info. of if statement.
|
||||
std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_;
|
||||
// The rolled_body callers info. for later lifting operation.
|
||||
std::vector<std::pair<CNodePtr, FunctionBlockPtr>> rolled_body_calls_;
|
||||
// Add exception for if parallel transform.
|
||||
|
|
|
@ -1042,7 +1042,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
for (auto eval : evaluators) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
(void)SetUndeterminedFlag(eval, possible_parent_fg);
|
||||
SetUndeterminedFlag(eval, possible_parent_fg);
|
||||
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_half_unroll_basic():
|
||||
"""
|
||||
Feature: Half unroll compile optimization for for statement.
|
||||
Description: Only test for statement.
|
||||
Expectation: Correct result and no exception.
|
||||
"""
|
||||
class ForLoopBasic(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32)))
|
||||
|
||||
def construct(self, x):
|
||||
output = x
|
||||
for i in self.array:
|
||||
output += i
|
||||
|
||||
return output
|
||||
|
||||
net = ForLoopBasic()
|
||||
x = Tensor(np.array(10).astype(np.int32))
|
||||
os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1'
|
||||
res = net(x)
|
||||
os.environ['MS_DEV_FOR_HALF_UNROLL'] = ''
|
||||
assert res == 25
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_half_unroll_if():
|
||||
"""
|
||||
Feature: Half unroll compile optimization for for statement.
|
||||
Description: Test for-in statements.
|
||||
Expectation: Correct result and no exception.
|
||||
"""
|
||||
class ForLoopIf(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.array = (Tensor(np.array(10).astype(np.int32)), Tensor(np.array(5).astype(np.int32)))
|
||||
|
||||
def construct(self, x):
|
||||
output = x
|
||||
for i in self.array:
|
||||
if i < 10:
|
||||
output += i
|
||||
|
||||
return output
|
||||
|
||||
net = ForLoopIf()
|
||||
x = Tensor(np.array(10).astype(np.int32))
|
||||
os.environ['MS_DEV_FOR_HALF_UNROLL'] = '1'
|
||||
res = net(x)
|
||||
os.environ['MS_DEV_FOR_HALF_UNROLL'] = ''
|
||||
assert res == 15
|
Loading…
Reference in New Issue