!2765 fix large for loop segment fault
Merge pull request !2765 from fary86/fix_large_for_loop
This commit is contained in:
commit
323a80c620
|
@ -294,6 +294,12 @@ extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
|||
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
||||
extern const PrimitivePtr kPrimIsIndexedSlices;
|
||||
|
||||
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
||||
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
||||
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
const int MAX_FOR_LOOP_COUNT = 200;
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
|
|
|
@ -95,7 +95,7 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
|
|||
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: condition, true branch, false branch
|
||||
if (args_spec_list.size() != 3) {
|
||||
|
@ -108,6 +108,11 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
|
|||
auto fb = args_spec_list[2];
|
||||
MS_EXCEPTION_IF_NULL(cond);
|
||||
|
||||
auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG);
|
||||
if (unroll_flag != nullptr && GetValue<int>(unroll_flag) == 0) {
|
||||
return tb->Join(fb);
|
||||
}
|
||||
|
||||
ValuePtr v = cond->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
// for tensor as condition, keeps both true and false branch.
|
||||
|
|
|
@ -208,6 +208,11 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
|
|||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
|
||||
// and continue
|
||||
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
|
||||
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
|
||||
}
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
|
|
|
@ -294,13 +294,18 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node)
|
|||
// Perform a conditional jump using switch operation.
|
||||
// The first CNode select graph with condition, and than execute this graph
|
||||
void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) {
|
||||
const FunctionBlockPtr &false_block, bool unroll_loop) {
|
||||
if (func_graph()->get_return() != nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
|
||||
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
|
||||
}
|
||||
// Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch'
|
||||
auto prim_switch = std::make_shared<Primitive>(prim::kPrimSwitch->name());
|
||||
if (!unroll_loop) {
|
||||
prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0));
|
||||
}
|
||||
CNodePtr switch_app =
|
||||
func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
|
||||
func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
CNodePtr switch_app_new = func_graph()->NewCNode({switch_app});
|
||||
func_graph()->set_output(switch_app_new);
|
||||
|
|
|
@ -59,7 +59,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
|
||||
void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
|
||||
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
|
||||
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock);
|
||||
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
|
||||
bool unroll_loop = true);
|
||||
// record the assign statement of self.xx weight parameter ,which will use state_setitem op
|
||||
void SetStateAssgin(const AnfNodePtr &target, const std::string &readid);
|
||||
void AddAutoDepend(const AnfNodePtr &target);
|
||||
|
|
|
@ -1002,6 +1002,7 @@ CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::
|
|||
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
|
||||
return block->func_graph()->NewCNode({op_iter, iter_anf_node});
|
||||
}
|
||||
|
||||
CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
|
||||
const AnfNodePtr &op_hasnext) {
|
||||
MS_EXCEPTION_IF_NULL(header_block);
|
||||
|
@ -1018,12 +1019,57 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
|
|||
// A for loop will generate 3 functions :the test, the body, and the continuation
|
||||
// for x in xs:
|
||||
// body
|
||||
// it compiled to be following statement
|
||||
// it is compiled to be following statement
|
||||
// if len(xs) < max_loop_cnt:
|
||||
// ParseForIter() // use iter to implement for loop, which always unroll loop
|
||||
// else:
|
||||
// ParseForLoop() // use loop var to implement for loop, which always sink loop
|
||||
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
|
||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
|
||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
||||
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
||||
CNodePtr bool_node = block->func_graph()->NewCNode(
|
||||
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
|
||||
|
||||
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
|
||||
MakeConditionBlocks(block, true_block, false_block);
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
|
||||
FunctionBlockPtr true_end = ParseForIter(true_block, node);
|
||||
true_end->Jump(after_block, nullptr);
|
||||
|
||||
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
|
||||
false_end->Jump(after_block, nullptr);
|
||||
|
||||
block->ConditionalJump(bool_node, true_block, false_block);
|
||||
after_block->Mature();
|
||||
return after_block;
|
||||
}
|
||||
|
||||
// A for loop will generate 3 functions :the test, the body, and the continuation
|
||||
// for x in xs:
|
||||
// body
|
||||
// it is compiled to be following statement
|
||||
// it = iter(xs)
|
||||
// while hastnext(it)
|
||||
// x, it = next(it)
|
||||
// body
|
||||
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
||||
FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
||||
|
@ -1088,6 +1134,91 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
|
|||
// No 'break', no end_block.
|
||||
return after_block;
|
||||
}
|
||||
|
||||
// A for loop will generate 3 functions :the test, the body, and the continuation
|
||||
// for x in xs:
|
||||
// body
|
||||
// it is compiled to be following statement
|
||||
// i = 0
|
||||
// while i < len(xs)
|
||||
// x = xs[i]
|
||||
// i = i + 1
|
||||
// body
|
||||
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
|
||||
// get varibale name of 'x' in statement 'for x in xs'
|
||||
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(target_node, "id"));
|
||||
|
||||
// create statement 'len(xs)'
|
||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
|
||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
||||
MS_EXCEPTION_IF_NULL(iter_node);
|
||||
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
||||
|
||||
FunctionBlockPtr header_block =
|
||||
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(header_block);
|
||||
// create loop variable 'i'
|
||||
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
|
||||
// create loop condition 'i < len(xs)'
|
||||
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
|
||||
|
||||
// generate the body of the for statement
|
||||
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(body_block);
|
||||
body_block->AddPrevBlock(header_block);
|
||||
// create 'x = xs[i]'
|
||||
CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
|
||||
target_var->debug_info()->set_name(name_id);
|
||||
body_block->WriteVariable(name_id, target_var);
|
||||
// create 'i = i + 1'
|
||||
CNodePtr loop_var_inc =
|
||||
body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
|
||||
body_block->WriteVariable(loop_var->name(), loop_var_inc);
|
||||
loop_var_inc->debug_info()->set_name(name_id);
|
||||
|
||||
// link the variable name with the target
|
||||
auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
|
||||
loop_var->debug_info()->set_trace_info(it_info);
|
||||
len_iter->debug_info()->set_trace_info(it_info);
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
|
||||
MS_EXCEPTION_IF_NULL(after_block);
|
||||
TraceManager::EndTrace();
|
||||
after_block->AddPrevBlock(header_block);
|
||||
|
||||
block->Jump(header_block, NewValueNode(0));
|
||||
body_block->Mature();
|
||||
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block, false);
|
||||
|
||||
// Parse loop body statements with loop context.
|
||||
LoopContext loop_context{&loops_, header_block, loop_var_inc};
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
||||
after_body_block->Jump(header_block, loop_var_inc);
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
after_block->Mature();
|
||||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
// No 'break', no end_block.
|
||||
return after_block;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast IfExp";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
|
|
|
@ -106,6 +106,8 @@ class Parser {
|
|||
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a for statement
|
||||
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a function def statement
|
||||
FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a augment assign
|
||||
|
|
|
@ -87,6 +87,7 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
|
|||
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
|
||||
|
||||
// define the common name
|
||||
const char NAMED_PRIMITIVE_LEN[] = "len";
|
||||
const char NAMED_PRIMITIVE_ITER[] = "iter";
|
||||
const char NAMED_PRIMITIVE_NEXT[] = "next";
|
||||
const char NAMED_PRIMITIVE_GETITEM[] = "getitem";
|
||||
|
|
|
@ -621,11 +621,8 @@ void Pipeline::Run() {
|
|||
draw::Draw(base_name + ".dot", graph);
|
||||
// generate IR file in human readable format
|
||||
DumpIR(base_name + ".ir", graph);
|
||||
|
||||
// generate IR file in a heavily commented format, which can also be reloaded
|
||||
if (action.first != "parse") {
|
||||
ExportIR(base_name + ".dat", std::to_string(i), graph);
|
||||
}
|
||||
ExportIR(base_name + ".dat", std::to_string(i), graph);
|
||||
}
|
||||
#ifdef MS_DEBUG
|
||||
// Dump graph cnode list
|
||||
|
|
|
@ -600,3 +600,42 @@ def test_while_tensor():
|
|||
x = Tensor(np.ones([6, 8, 10], np.int32))
|
||||
y = Tensor(np.ones([6, 8, 10], np.int32))
|
||||
out = net(x, y)
|
||||
|
||||
|
||||
def test_large_for_loop():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() #nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
for elem in range(1, 19000):
|
||||
x = self.flatten(x + elem)
|
||||
return x
|
||||
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
|
||||
|
||||
def test_large_for_loop_with_continue_break():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() #nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
idx = 0
|
||||
for elem1 in range(200):
|
||||
idx = idx + 1
|
||||
if idx < 10:
|
||||
x = x + 0.5
|
||||
continue
|
||||
if idx > 500:
|
||||
break
|
||||
x = self.flatten(x + elem1)
|
||||
return x
|
||||
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
|
|
Loading…
Reference in New Issue