forked from mindspore-Ecosystem/mindspore
!16332 Add environment variable to enable recursive evaluate.
From: @zh_qh Reviewed-by: @ginfung,@kingxian Signed-off-by: @ginfung
This commit is contained in:
commit
8cfb5b8aea
|
@ -80,7 +80,8 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
|
|||
// Increase & Check the func graph call depth.
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
engine->IncreaseStackFrameDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
if (engine->function_call_depth() - engine->stack_frame_depth() >
|
||||
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
|
@ -157,10 +158,33 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
return res_base;
|
||||
}
|
||||
|
||||
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
|
||||
const AnfNodePtr &func_node = fg->get_return();
|
||||
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
|
||||
if (node->func_graph() != fg || node->isa<ValueNode>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
return FOLLOW;
|
||||
});
|
||||
AbstractBasePtr res_base = nullptr;
|
||||
for (const auto &node : all_nodes) {
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context_);
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString();
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
res_base = node_eval_result->abstract();
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << res_base->ToString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
return res_base;
|
||||
}
|
||||
|
||||
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
if (engine->function_call_depth() - engine->stack_frame_depth() >
|
||||
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
|
@ -193,7 +217,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
|
||||
<< ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
|
||||
<< ", current function call depth: " << engine->function_call_depth();
|
||||
auto res_base = LaunchStackFrame(engine, fg);
|
||||
AbstractBasePtr res_base = nullptr;
|
||||
if (engine->enable_recursive_eval()) {
|
||||
res_base = LaunchRecursiveEval(engine, fg);
|
||||
} else {
|
||||
res_base = LaunchStackFrame(engine, fg);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
|
||||
<< ", evaluated abstract: " << res_base->ToString() << ", is stub: " << fg->stub();
|
||||
|
|
|
@ -205,6 +205,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
AnalysisContextPtr parent_context_;
|
||||
|
||||
private:
|
||||
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
|
||||
// Add functions for stack frame routine.
|
||||
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
|
||||
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
|
|
|
@ -66,7 +66,7 @@ class StackFrame : public Base {
|
|||
|
||||
AnfNodePtr &CurrentNode() {
|
||||
if (slot_index_ >= node_slots.size()) {
|
||||
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToAbstract()
|
||||
MS_LOG(EXCEPTION) << "The stack frame of " << func_graph_->ToString()
|
||||
<< " is invalid. Try to access frame sequence by index " << slot_index_
|
||||
<< ", while the size is " << node_slots.size() << ".";
|
||||
}
|
||||
|
|
|
@ -194,6 +194,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
stack_frame_depth_ = 0;
|
||||
stack_frame_max_depth_ = 0;
|
||||
forward_count_ = 0;
|
||||
|
||||
enable_recursive_eval_ = (common::GetEnv("ENV_RECURSIVE_EVAL") == "1");
|
||||
}
|
||||
~AnalysisEngine() = default;
|
||||
|
||||
|
@ -256,8 +258,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
}
|
||||
function_call_depth_--;
|
||||
}
|
||||
size_t function_call_depth() { return function_call_depth_; }
|
||||
size_t function_call_max_depth() { return function_call_max_depth_; }
|
||||
size_t function_call_depth() const { return function_call_depth_; }
|
||||
size_t function_call_max_depth() const { return function_call_max_depth_; }
|
||||
|
||||
void ResetStackFrameDepth() {
|
||||
stack_frame_depth_ = 0;
|
||||
|
@ -275,11 +277,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
}
|
||||
stack_frame_depth_--;
|
||||
}
|
||||
size_t stack_frame_depth() { return stack_frame_depth_; }
|
||||
size_t stack_frame_max_depth() { return stack_frame_max_depth_; }
|
||||
size_t stack_frame_depth() const { return stack_frame_depth_; }
|
||||
size_t stack_frame_max_depth() const { return stack_frame_max_depth_; }
|
||||
|
||||
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
|
||||
|
||||
bool enable_recursive_eval() const { return enable_recursive_eval_; }
|
||||
|
||||
private:
|
||||
// Should compare Args based on value other than pointer;
|
||||
struct EvaluatorArgs {
|
||||
|
@ -343,6 +347,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
|
||||
size_t forward_count_;
|
||||
|
||||
bool enable_recursive_eval_;
|
||||
|
||||
#ifdef DEBUG
|
||||
std::vector<AnfNodePtr> compute_conf_stack_;
|
||||
#endif
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test control ops """
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
@ -749,53 +750,6 @@ def test_while_scalar():
|
|||
out = net(x, y)
|
||||
|
||||
|
||||
def test_large_for_loop():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() # nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
for elem in range(1, 1900):
|
||||
x = self.flatten(x + elem)
|
||||
return x
|
||||
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=60)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net(t)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
assert 'Exceed function call depth limit 60' in str(err.value)
|
||||
|
||||
|
||||
def test_large_for_loop_with_continue_break():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() # nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
idx = 0
|
||||
for elem1 in range(200):
|
||||
idx = idx + 1
|
||||
if idx < 10:
|
||||
x = x + 0.5
|
||||
continue
|
||||
if idx > 500:
|
||||
break
|
||||
x = self.flatten(x + elem1)
|
||||
return x
|
||||
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=2000)
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
||||
|
||||
def test_mixed_precision_cast():
|
||||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
z = F.mixed_precision_cast(mstype.float16, x)
|
||||
|
@ -871,42 +825,6 @@ def test_parser_switch_layer_func_primitive():
|
|||
net(i, input1)
|
||||
|
||||
|
||||
def test_recursive_call():
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc = nn.Dense(10, 10) # padding=0
|
||||
# self.net2 = Net2()
|
||||
|
||||
def construct(self, x):
|
||||
net2 = Net2()
|
||||
x = net2(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.net = Net()
|
||||
self.fc = nn.Dense(10, 10)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.net(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=80)
|
||||
input_data = Tensor(np.identity(10).astype(np.float32))
|
||||
net = Net2()
|
||||
with pytest.raises(RuntimeError):
|
||||
net(input_data)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
||||
|
||||
def test_switch_layer_shape_join_failed():
|
||||
class AddFuncNet(nn.Cell):
|
||||
def __init__(self, funcs, new_func):
|
||||
|
@ -975,6 +893,29 @@ def test_switch_layer_dtype_join_failed():
|
|||
net(i, inp)
|
||||
|
||||
|
||||
def test_large_for_loop():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() # nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
for elem in range(1, 1900):
|
||||
x = self.flatten(x + elem)
|
||||
return x
|
||||
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '1'
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=60)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net(t)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '0'
|
||||
assert 'Exceed function call depth limit 60' in str(err.value)
|
||||
|
||||
|
||||
def test_large_for_loop_case2():
|
||||
class Menet(nn.Cell):
|
||||
def __init__(self, axis, flag_boottom, flag_top):
|
||||
|
@ -1000,9 +941,77 @@ def test_large_for_loop_case2():
|
|||
|
||||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Menet(axis=0, flag_boottom=True, flag_top=True)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '1'
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=80)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
net(x)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '0'
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
assert 'Exceed function call depth limit 80' in str(err.value)
|
||||
|
||||
|
||||
def test_large_for_loop_with_continue_break():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flatten = P.ReLU() # nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
idx = 0
|
||||
for elem1 in range(200):
|
||||
idx = idx + 1
|
||||
if idx < 10:
|
||||
x = x + 0.5
|
||||
continue
|
||||
if idx > 500:
|
||||
break
|
||||
x = self.flatten(x + elem1)
|
||||
return x
|
||||
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '1'
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=2000)
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '0'
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
||||
|
||||
def test_recursive_call():
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc = nn.Dense(10, 10) # padding=0
|
||||
# self.net2 = Net2()
|
||||
|
||||
def construct(self, x):
|
||||
net2 = Net2()
|
||||
x = net2(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.net = Net()
|
||||
self.fc = nn.Dense(10, 10)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.net(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '1'
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=80)
|
||||
input_data = Tensor(np.identity(10).astype(np.float32))
|
||||
net = Net2()
|
||||
with pytest.raises(RuntimeError):
|
||||
net(input_data)
|
||||
os.environ['ENV_RECURSIVE_EVAL'] = '0'
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
|
Loading…
Reference in New Issue