forked from mindspore-Ecosystem/mindspore
!5136 fix large for loop runtime error due to lacking of backend operators
Merge pull request !5136 from fary86/fix_large_for_runtime_error
This commit is contained in:
commit
971716f4b9
|
@ -19,6 +19,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
@ -50,7 +51,11 @@ enum ParseStatusCode : int {
|
|||
|
||||
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
const int MAX_FOR_LOOP_COUNT = 600;
|
||||
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
|
||||
// not implemented, so here set MAX_FOR_LOOP_COUNT to int max limit to override default value `600`. This will make
|
||||
// the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised
|
||||
// when function call depth execeeds the limit `context.get_context('max_call_depth')`.
|
||||
const int MAX_FOR_LOOP_COUNT = std::numeric_limits<int>::max();
|
||||
|
||||
class AstNodeType;
|
||||
class ParseAst;
|
||||
|
|
|
@ -773,13 +773,18 @@ def test_large_for_loop():
|
|||
self.flatten = P.ReLU() #nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
for elem in range(1, 19000):
|
||||
for elem in range(1, 1900):
|
||||
x = self.flatten(x + elem)
|
||||
return x
|
||||
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=60)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net(t)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
assert 'Exceed function call depth limit 60' in str(err.value)
|
||||
|
||||
|
||||
def test_large_for_loop_with_continue_break():
|
||||
|
@ -986,3 +991,36 @@ def test_switch_layer_dtype_join_failed():
|
|||
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(i, inp)
|
||||
|
||||
|
||||
def test_large_for_loop_case2():
|
||||
class Menet(nn.Cell):
|
||||
def __init__(self, axis, flag_boottom, flag_top):
|
||||
super(Menet, self).__init__()
|
||||
self.squeeze = P.Squeeze(axis)
|
||||
self.expanddims = P.ExpandDims()
|
||||
self.flatten = nn.Flatten()
|
||||
self.neg = P.Neg()
|
||||
self.axis = axis
|
||||
self.flag_boottom = flag_boottom
|
||||
self.flag_top = flag_top
|
||||
|
||||
def construct(self, x):
|
||||
if self.flag_boottom:
|
||||
x = self.neg(x)
|
||||
for i in range(0, 1500):
|
||||
x = self.expanddims(x, self.axis)
|
||||
x = self.squeeze(x)
|
||||
x = self.flatten(x)
|
||||
if self.flag_top:
|
||||
x = self.neg(x)
|
||||
return x
|
||||
|
||||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Menet(axis=0, flag_boottom=True, flag_top=True)
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=80)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net(x)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
assert 'Exceed function call depth limit 80' in str(err.value)
|
||||
|
|
Loading…
Reference in New Issue