!5136 fix large for loop runtime error due to missing backend operators

Merge pull request !5136 from fary86/fix_large_for_runtime_error
This commit is contained in:
mindspore-ci-bot 2020-08-25 22:26:25 +08:00 committed by Gitee
commit 971716f4b9
2 changed files with 46 additions and 3 deletions

View File

@ -19,6 +19,7 @@
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
#include <limits>
#include <vector>
#include <string>
#include <map>
@ -50,7 +51,11 @@ enum ParseStatusCode : int {
// max loop count of for statement; when the loop count is less than this value, the for loop will be unrolled,
// otherwise it will be sunk (i.e. not unrolled)
const int MAX_FOR_LOOP_COUNT = 600;
// NOTE: A for loop that is not unrolled depends on the backend operators `tuple_getitem` and `scalar_add`, which
// are not implemented yet, so MAX_FOR_LOOP_COUNT is set here to the int max limit to override the default value
// `600`. As a result every for loop is unrolled. Memory will not be exhausted by this: an exception is raised
// when the function call depth exceeds the limit `context.get_context('max_call_depth')`.
const int MAX_FOR_LOOP_COUNT = std::numeric_limits<int>::max();
class AstNodeType;
class ParseAst;

View File

@ -773,13 +773,18 @@ def test_large_for_loop():
self.flatten = P.ReLU() #nn.Flatten()
def construct(self, x):
for elem in range(1, 19000):
for elem in range(1, 1900):
x = self.flatten(x + elem)
return x
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=60)
with pytest.raises(RuntimeError) as err:
net(t)
context.set_context(max_call_depth=old_max_call_depth)
assert 'Exceed function call depth limit 60' in str(err.value)
def test_large_for_loop_with_continue_break():
@ -986,3 +991,36 @@ def test_switch_layer_dtype_join_failed():
with pytest.raises(TypeError) as err:
net(i, inp)
def test_large_for_loop_case2():
    """A 1500-iteration for loop must trip the call-depth guard, not hang or OOM.

    Since MAX_FOR_LOOP_COUNT was raised to INT_MAX, every for loop is unrolled;
    compiling this network therefore exceeds the configured max_call_depth (80)
    and must raise a RuntimeError with the expected message.
    """
    class Menet(nn.Cell):
        def __init__(self, axis, flag_bottom, flag_top):
            super(Menet, self).__init__()
            self.squeeze = P.Squeeze(axis)
            self.expanddims = P.ExpandDims()
            self.flatten = nn.Flatten()
            self.neg = P.Neg()
            self.axis = axis
            self.flag_bottom = flag_bottom
            self.flag_top = flag_top

        def construct(self, x):
            if self.flag_bottom:
                x = self.neg(x)
            # Unrolled 1500 times at compile time -> deep call chain.
            for _ in range(0, 1500):
                x = self.expanddims(x, self.axis)
                x = self.squeeze(x)
                x = self.flatten(x)
            if self.flag_top:
                x = self.neg(x)
            return x

    x = Tensor(np.ones([2, 3], dtype=np.float32))
    net = Menet(axis=0, flag_bottom=True, flag_top=True)
    old_max_call_depth = context.get_context('max_call_depth')
    context.set_context(max_call_depth=80)
    try:
        with pytest.raises(RuntimeError) as err:
            net(x)
    finally:
        # BUG FIX: restore max_call_depth even when the expected error is not
        # raised, so a failure here cannot leak the lowered limit into later tests.
        context.set_context(max_call_depth=old_max_call_depth)
    assert 'Exceed function call depth limit 80' in str(err.value)