From 04524b6bd31669c6baaeca0990c3e25adc9dfa5b Mon Sep 17 00:00:00 2001 From: fary86 Date: Fri, 21 Aug 2020 11:29:20 +0800 Subject: [PATCH] Fix coredump caused by function call depth too large --- mindspore/ccsrc/pipeline/jit/init.cc | 2 + .../pipeline/jit/static_analysis/evaluator.cc | 8 +++- .../jit/static_analysis/static_analysis.cc | 1 + .../jit/static_analysis/static_analysis.h | 19 +++++++++- mindspore/context.py | 15 +++++++- mindspore/core/utils/ms_context.cc | 1 + mindspore/core/utils/ms_context.h | 10 +++++ tests/ut/python/ops/test_control_ops.py | 38 +++++++++++++++++++ 8 files changed, 91 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 3d4e7c6bdcc..9a2bb625c3e 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -110,6 +110,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") + .def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.") + .def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index f6ffda863ba..dd35f15dd6e 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr const AnfNodePtr &func_node = fg->get_return(); MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() - << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); + << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString() + << ", current function call depth: " << engine->function_call_depth(); AbstractBasePtr ret_base = nullptr; + engine->IncreaseFunctionCallDepth(); + if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { + MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; + } std::vector nodes = FastShadowSort(func_node); for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { const auto &node = *it; @@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); } + engine->DecreaseFunctionCallDepth(); MS_EXCEPTION_IF_NULL(ret_base); MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index acf891e5003..78d6a563c70 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); // Running the analyzer. + ResetFunctionCallDepth(); AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); MS_EXCEPTION_IF_NULL(root_context); MS_EXCEPTION_IF_NULL(root_context->func_graph()); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 70189328981..cfe667f252a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -185,7 +185,9 @@ struct PartialAppHasher { class AnalysisEngine : public std::enable_shared_from_this { public: AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) - : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} + : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { + function_call_depth_ = 0; + } ~AnalysisEngine() = default; // func_graph: The func_graph to analyze. @@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this { AnalysisCache cache_; std::unordered_map prim_py_evaluators_; + void ResetFunctionCallDepth() { function_call_depth_ = 0; } + + void IncreaseFunctionCallDepth() { function_call_depth_++; } + + void DecreaseFunctionCallDepth() { + if (function_call_depth_ == 0) { + MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it."; + } + function_call_depth_--; + } + + unsigned int function_call_depth() { return function_call_depth_; } + private: void SetUndeterminedFlag(const EvaluatorPtr &evaluator); EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, @@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this { const ConfigPtrList &args_conf_list); EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list); + // record current depth of function call statck + unsigned int function_call_depth_; #ifdef DEBUG std::vector compute_conf_stack_; diff --git a/mindspore/context.py b/mindspore/context.py index 1f5fe65f716..c92a9985142 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -234,6 +234,17 @@ class _Context: if not success: raise RuntimeError("Device id set failed!!!") + @property + def max_call_depth(self): + return self._context_handle.get_max_call_depth() + + @max_call_depth.setter + def max_call_depth(self, max_call_depth): + if max_call_depth <= 0: + raise ValueError( + "Max call depth must be greater than 0, but got {}".format(max_call_depth)) + self._context_handle.set_max_call_depth(max_call_depth) + @property def enable_auto_mixed_precision(self): return self._context_handle.get_auto_mixed_precision_flag() @@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs): full_batch (bool): Whether to load the whole batch on each device. Default: False. enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in data parallel training in the benefit of time and memory saving. + max_call_depth(int): Specify the function call depth limit. Default: 1000. Raises: @@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs): >>> context.set_auto_parallel_context(parameter_broadcast=False) >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") + >>> context.set_auto_parallel_context(max_call_depth=80) """ _set_auto_parallel_context(**kwargs) @@ -532,7 +545,7 @@ def reset_auto_parallel_context(): save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, - enable_sparse=bool) + enable_sparse=bool, max_call_depth=int) def set_context(**kwargs): """ Sets context for running environment. diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index a1ad034c009..42453d99042 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { } else { device_id_ = 0; } + max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; backend_policy_ = policy_map_[policy]; device_target_ = target; execution_mode_ = kPynativeMode; diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 9ad3259b24f..9d29ad5bdc8 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend"; const char kDavinciInferenceDevice[] = "AscendInference"; const char kDavinciDevice[] = "Davinci"; const char KNpuLog[] = "_npu_log"; +const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; + const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; // The default max available device memory is 1024GB. const float kDefaultMaxDeviceMemory = 1024; @@ -80,6 +82,13 @@ class MsContext { uint32_t device_id() const { return device_id_; } bool set_device_id(uint32_t device_id); + // uint32_t max_call_depth_ + uint32_t max_call_depth() const { return max_call_depth_; } + inline bool set_max_call_depth(uint32_t max_call_depth) { + max_call_depth_ = max_call_depth; + return true; + } + bool save_graphs_flag() const { return save_graphs_flag_; } void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } @@ -171,6 +180,7 @@ class MsContext { MsBackendPolicy backend_policy_; std::string device_target_; uint32_t device_id_; + uint32_t max_call_depth_; int execution_mode_; bool enable_pynative_infer_; bool enable_pynative_hook_; diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 26132165b5f..483485ea64d 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break(): x = self.flatten(x + elem1) return x + old_max_call_depth = context.get_context('max_call_depth') + context.set_context(max_call_depth=2000) t = Tensor(np.ones([2, 3], dtype=np.float32)) net = Net() net(t) + context.set_context(max_call_depth=old_max_call_depth) def test_mixed_precision_cast(): @@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive(): with pytest.raises(ValueError): net(i, input1) + + +def test_recursive_call(): + class Net(nn.Cell): + """ Net definition """ + + def __init__(self): + super(Net, self).__init__() + self.fc = nn.Dense(10, 10) # padding=0 + #self.net2 = Net2() + + def construct(self, x): + net2 = Net2() + x = net2(x) + out = self.fc(x) + return out + + class Net2(nn.Cell): + def __init__(self): + super(Net2, self).__init__() + self.net = Net() + self.fc = nn.Dense(10, 10) + def construct(self, x): + x = self.net(x) + out = self.fc(x) + return out + + context.set_context(mode=context.GRAPH_MODE, save_graphs=False) + old_max_call_depth = context.get_context('max_call_depth') + context.set_context(max_call_depth=80) + input_data = Tensor(np.identity(10).astype(np.float32)) + net = Net2() + with pytest.raises(RuntimeError): + net(input_data) + context.set_context(max_call_depth=old_max_call_depth)