!16332 Add environment variable to enable recursive evaluate.

From: @zh_qh
Reviewed-by: @ginfung,@kingxian
Signed-off-by: @ginfung
This commit is contained in:
mindspore-ci-bot 2021-05-17 11:00:40 +08:00 committed by Gitee
commit 8cfb5b8aea
5 changed files with 137 additions and 91 deletions

View File

@ -80,7 +80,8 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
// Increase & Check the func graph call depth.
engine->IncreaseFunctionCallDepth();
engine->IncreaseStackFrameDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
if (engine->function_call_depth() - engine->stack_frame_depth() >
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", (function call depth: " << engine->function_call_depth()
@ -157,10 +158,33 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
return res_base;
}
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
const AnfNodePtr &func_node = fg->get_return();
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
if (node->func_graph() != fg || node->isa<ValueNode>()) {
return EXCLUDE;
}
return FOLLOW;
});
AbstractBasePtr res_base = nullptr;
for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
res_base = node_eval_result->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << res_base->ToString();
}
MS_EXCEPTION_IF_NULL(res_base);
return res_base;
}
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
MS_EXCEPTION_IF_NULL(engine);
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
if (engine->function_call_depth() - engine->stack_frame_depth() >
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", (function call depth: " << engine->function_call_depth()
@ -193,7 +217,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
<< ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
<< ", current function call depth: " << engine->function_call_depth();
auto res_base = LaunchStackFrame(engine, fg);
AbstractBasePtr res_base = nullptr;
if (engine->enable_recursive_eval()) {
res_base = LaunchRecursiveEval(engine, fg);
} else {
res_base = LaunchStackFrame(engine, fg);
}
MS_EXCEPTION_IF_NULL(res_base);
MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
<< ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();

View File

@ -205,6 +205,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
AnalysisContextPtr parent_context_;
private:
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
// Add functions for stack frame routine.
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,

View File

@ -66,7 +66,7 @@ class StackFrame : public Base {
AnfNodePtr &CurrentNode() {
if (slot_index_ >= node_slots.size()) {
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToAbstract()
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToString()
<< " is invalid. Try to access frame sequence by index " << slot_index_
<< ", while the size is " << node_slots.size() << ".";
}

View File

@ -194,6 +194,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
stack_frame_depth_ = 0;
stack_frame_max_depth_ = 0;
forward_count_ = 0;
enable_recursive_eval_ = (common::GetEnv("ENV_RECURSIVE_EVAL") == "1");
}
~AnalysisEngine() = default;
@ -256,8 +258,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
}
function_call_depth_--;
}
size_t function_call_depth() { return function_call_depth_; }
size_t function_call_max_depth() { return function_call_max_depth_; }
size_t function_call_depth() const { return function_call_depth_; }
size_t function_call_max_depth() const { return function_call_max_depth_; }
void ResetStackFrameDepth() {
stack_frame_depth_ = 0;
@ -275,11 +277,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
}
stack_frame_depth_--;
}
size_t stack_frame_depth() { return stack_frame_depth_; }
size_t stack_frame_max_depth() { return stack_frame_max_depth_; }
size_t stack_frame_depth() const { return stack_frame_depth_; }
size_t stack_frame_max_depth() const { return stack_frame_max_depth_; }
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
bool enable_recursive_eval() const { return enable_recursive_eval_; }
private:
// Should compare Args based on value other than pointer;
struct EvaluatorArgs {
@ -343,6 +347,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
size_t forward_count_;
bool enable_recursive_eval_;
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;
#endif

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test control ops """
import os
import numpy as np
import pytest
@ -749,53 +750,6 @@ def test_while_scalar():
out = net(x, y)
def test_large_for_loop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() # nn.Flatten()
def construct(self, x):
for elem in range(1, 1900):
x = self.flatten(x + elem)
return x
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=60)
with pytest.raises(RuntimeError) as err:
net(t)
context.set_context(max_call_depth=old_max_call_depth)
assert 'Exceed function call depth limit 60' in str(err.value)
def test_large_for_loop_with_continue_break():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() # nn.Flatten()
def construct(self, x):
idx = 0
for elem1 in range(200):
idx = idx + 1
if idx < 10:
x = x + 0.5
continue
if idx > 500:
break
x = self.flatten(x + elem1)
return x
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=2000)
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
context.set_context(max_call_depth=old_max_call_depth)
def test_mixed_precision_cast():
x = Tensor(np.ones([2, 3], dtype=np.float32))
z = F.mixed_precision_cast(mstype.float16, x)
@ -871,42 +825,6 @@ def test_parser_switch_layer_func_primitive():
net(i, input1)
def test_recursive_call():
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Dense(10, 10) # padding=0
# self.net2 = Net2()
def construct(self, x):
net2 = Net2()
x = net2(x)
out = self.fc(x)
return out
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.net = Net()
self.fc = nn.Dense(10, 10)
def construct(self, x):
x = self.net(x)
out = self.fc(x)
return out
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=80)
input_data = Tensor(np.identity(10).astype(np.float32))
net = Net2()
with pytest.raises(RuntimeError):
net(input_data)
context.set_context(max_call_depth=old_max_call_depth)
def test_switch_layer_shape_join_failed():
class AddFuncNet(nn.Cell):
def __init__(self, funcs, new_func):
@ -975,6 +893,29 @@ def test_switch_layer_dtype_join_failed():
net(i, inp)
def test_large_for_loop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() # nn.Flatten()
def construct(self, x):
for elem in range(1, 1900):
x = self.flatten(x + elem)
return x
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
os.environ['ENV_RECURSIVE_EVAL'] = '1'
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=60)
with pytest.raises(RuntimeError) as err:
net(t)
context.set_context(max_call_depth=old_max_call_depth)
os.environ['ENV_RECURSIVE_EVAL'] = '0'
assert 'Exceed function call depth limit 60' in str(err.value)
def test_large_for_loop_case2():
class Menet(nn.Cell):
def __init__(self, axis, flag_boottom, flag_top):
@ -1000,9 +941,77 @@ def test_large_for_loop_case2():
x = Tensor(np.ones([2, 3], dtype=np.float32))
net = Menet(axis=0, flag_boottom=True, flag_top=True)
os.environ['ENV_RECURSIVE_EVAL'] = '1'
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=80)
with pytest.raises(RuntimeError) as err:
net(x)
os.environ['ENV_RECURSIVE_EVAL'] = '0'
context.set_context(max_call_depth=old_max_call_depth)
assert 'Exceed function call depth limit 80' in str(err.value)
def test_large_for_loop_with_continue_break():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.ReLU() # nn.Flatten()
def construct(self, x):
idx = 0
for elem1 in range(200):
idx = idx + 1
if idx < 10:
x = x + 0.5
continue
if idx > 500:
break
x = self.flatten(x + elem1)
return x
os.environ['ENV_RECURSIVE_EVAL'] = '1'
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=2000)
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)
os.environ['ENV_RECURSIVE_EVAL'] = '0'
context.set_context(max_call_depth=old_max_call_depth)
def test_recursive_call():
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Dense(10, 10) # padding=0
# self.net2 = Net2()
def construct(self, x):
net2 = Net2()
x = net2(x)
out = self.fc(x)
return out
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.net = Net()
self.fc = nn.Dense(10, 10)
def construct(self, x):
x = self.net(x)
out = self.fc(x)
return out
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
os.environ['ENV_RECURSIVE_EVAL'] = '1'
old_max_call_depth = context.get_context('max_call_depth')
context.set_context(max_call_depth=80)
input_data = Tensor(np.identity(10).astype(np.float32))
net = Net2()
with pytest.raises(RuntimeError):
net(input_data)
os.environ['ENV_RECURSIVE_EVAL'] = '0'
context.set_context(max_call_depth=old_max_call_depth)