forked from mindspore-Ecosystem/mindspore
!4898 Fix coredump caused by function call depth too large
Merge pull request !4898 from fary86/fix_call_depth_too_large
This commit is contained in:
commit
94a109f476
|
@ -113,6 +113,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
|
||||
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
|
||||
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
|
||||
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
|
||||
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
|
||||
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
|
||||
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
|
||||
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
|
||||
|
|
|
@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
const AnfNodePtr &func_node = fg->get_return();
|
||||
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
|
||||
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
|
||||
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
|
||||
<< ", current function call depth: " << engine->function_call_depth();
|
||||
AbstractBasePtr ret_base = nullptr;
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << ".";
|
||||
}
|
||||
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
|
||||
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
|
||||
const auto &node = *it;
|
||||
|
@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
|
||||
}
|
||||
engine->DecreaseFunctionCallDepth();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(ret_base);
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
|
||||
|
|
|
@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
|
||||
|
||||
// Running the analyzer.
|
||||
ResetFunctionCallDepth();
|
||||
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
|
||||
MS_EXCEPTION_IF_NULL(root_context);
|
||||
MS_EXCEPTION_IF_NULL(root_context->func_graph());
|
||||
|
|
|
@ -185,7 +185,9 @@ struct PartialAppHasher {
|
|||
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||
public:
|
||||
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
|
||||
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {}
|
||||
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
|
||||
function_call_depth_ = 0;
|
||||
}
|
||||
~AnalysisEngine() = default;
|
||||
|
||||
// func_graph: The func_graph to analyze.
|
||||
|
@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
AnalysisCache cache_;
|
||||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
void ResetFunctionCallDepth() { function_call_depth_ = 0; }
|
||||
|
||||
void IncreaseFunctionCallDepth() { function_call_depth_++; }
|
||||
|
||||
void DecreaseFunctionCallDepth() {
|
||||
if (function_call_depth_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
|
||||
}
|
||||
function_call_depth_--;
|
||||
}
|
||||
|
||||
unsigned int function_call_depth() { return function_call_depth_; }
|
||||
|
||||
private:
|
||||
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
|
||||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
|
@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
const ConfigPtrList &args_conf_list);
|
||||
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
// record current depth of function call statck
|
||||
unsigned int function_call_depth_;
|
||||
|
||||
#ifdef DEBUG
|
||||
std::vector<AnfNodePtr> compute_conf_stack_;
|
||||
|
|
|
@ -234,6 +234,17 @@ class _Context:
|
|||
if not success:
|
||||
raise RuntimeError("Device id set failed!!!")
|
||||
|
||||
@property
|
||||
def max_call_depth(self):
|
||||
return self._context_handle.get_max_call_depth()
|
||||
|
||||
@max_call_depth.setter
|
||||
def max_call_depth(self, max_call_depth):
|
||||
if max_call_depth <= 0:
|
||||
raise ValueError(
|
||||
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
|
||||
self._context_handle.set_max_call_depth(max_call_depth)
|
||||
|
||||
@property
|
||||
def enable_auto_mixed_precision(self):
|
||||
return self._context_handle.get_auto_mixed_precision_flag()
|
||||
|
@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
|
||||
data parallel training in the benefit of time and memory saving.
|
||||
max_call_depth(int): Specify the function call depth limit. Default: 1000.
|
||||
|
||||
|
||||
Raises:
|
||||
|
@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
>>> context.set_auto_parallel_context(parameter_broadcast=False)
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(max_call_depth=80)
|
||||
"""
|
||||
_set_auto_parallel_context(**kwargs)
|
||||
|
||||
|
@ -532,7 +545,7 @@ def reset_auto_parallel_context():
|
|||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
||||
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
|
||||
enable_sparse=bool)
|
||||
enable_sparse=bool, max_call_depth=int)
|
||||
def set_context(**kwargs):
|
||||
"""
|
||||
Sets context for running environment.
|
||||
|
|
|
@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
} else {
|
||||
device_id_ = 0;
|
||||
}
|
||||
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT;
|
||||
backend_policy_ = policy_map_[policy];
|
||||
device_target_ = target;
|
||||
execution_mode_ = kPynativeMode;
|
||||
|
|
|
@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend";
|
|||
const char kDavinciInferenceDevice[] = "AscendInference";
|
||||
const char kDavinciDevice[] = "Davinci";
|
||||
const char KNpuLog[] = "_npu_log";
|
||||
const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
|
||||
|
||||
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
|
||||
// The default max available device memory is 1024GB.
|
||||
const float kDefaultMaxDeviceMemory = 1024;
|
||||
|
@ -80,6 +82,13 @@ class MsContext {
|
|||
uint32_t device_id() const { return device_id_; }
|
||||
bool set_device_id(uint32_t device_id);
|
||||
|
||||
// uint32_t max_call_depth_
|
||||
uint32_t max_call_depth() const { return max_call_depth_; }
|
||||
inline bool set_max_call_depth(uint32_t max_call_depth) {
|
||||
max_call_depth_ = max_call_depth;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool save_graphs_flag() const { return save_graphs_flag_; }
|
||||
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
|
||||
|
||||
|
@ -171,6 +180,7 @@ class MsContext {
|
|||
MsBackendPolicy backend_policy_;
|
||||
std::string device_target_;
|
||||
uint32_t device_id_;
|
||||
uint32_t max_call_depth_;
|
||||
int execution_mode_;
|
||||
bool enable_pynative_infer_;
|
||||
bool enable_pynative_hook_;
|
||||
|
|
|
@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break():
|
|||
x = self.flatten(x + elem1)
|
||||
return x
|
||||
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=2000)
|
||||
t = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
net = Net()
|
||||
net(t)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
||||
|
||||
def test_mixed_precision_cast():
|
||||
|
@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive():
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
net(i, input1)
|
||||
|
||||
|
||||
def test_recursive_call():
|
||||
class Net(nn.Cell):
|
||||
""" Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc = nn.Dense(10, 10) # padding=0
|
||||
#self.net2 = Net2()
|
||||
|
||||
def construct(self, x):
|
||||
net2 = Net2()
|
||||
x = net2(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.net = Net()
|
||||
self.fc = nn.Dense(10, 10)
|
||||
def construct(self, x):
|
||||
x = self.net(x)
|
||||
out = self.fc(x)
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
old_max_call_depth = context.get_context('max_call_depth')
|
||||
context.set_context(max_call_depth=80)
|
||||
input_data = Tensor(np.identity(10).astype(np.float32))
|
||||
net = Net2()
|
||||
with pytest.raises(RuntimeError):
|
||||
net(input_data)
|
||||
context.set_context(max_call_depth=old_max_call_depth)
|
||||
|
|
Loading…
Reference in New Issue