!20401 infer exception handle

Merge pull request !20401 from lanzhineng/infer_optv5
This commit is contained in:
i-robot 2021-07-23 01:46:46 +00:00 committed by Gitee
commit 088287f346
7 changed files with 89 additions and 67 deletions

View File

@ -89,9 +89,9 @@ void DumpInferStack(std::ostringstream &oss) {
}
std::vector<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>> infer_vec;
while (!graph_stack.empty()) {
auto top = graph_stack.top();
auto top = graph_stack.back();
infer_vec.push_back(top);
graph_stack.pop();
graph_stack.pop_back();
}
std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0;
@ -118,12 +118,12 @@ void TraceGraphEval() {
MS_LOG(INFO) << "Length of analysis graph stack is empty.";
return;
}
MS_LOG(ERROR) << "\n*******************************graph evaluate stack**********************************";
std::ostringstream oss;
oss << "\n*******************************graph evaluate stack**********************************";
oss << std::endl;
DumpInferStack(oss);
oss << "\n*************************************************************************************";
MS_LOG(ERROR) << oss.str();
MS_LOG(ERROR) << "\n*************************************************************************************";
}
class AnalyzeFailExporter : public AnfExporter {
@ -131,7 +131,7 @@ class AnalyzeFailExporter : public AnfExporter {
AnalyzeFailExporter() : AnfExporter(true, false) {}
~AnalyzeFailExporter() override = default;
bool ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack);
bool ExportFuncGraph(const std::string &filename, const TraceCNodeEvalStack &node_config_stack);
private:
void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx,
@ -339,8 +339,7 @@ void AnalyzeFailExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode,
ofs << "\n";
}
bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr> &node_config_stack) {
bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename, const TraceCNodeEvalStack &node_config_stack) {
if (node_config_stack.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return false;
@ -398,8 +397,7 @@ void GetEvalStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Length of analysis information stack is empty.";
return;
}
static int fileNumber = 0;
string file_name = "analyze_fail_" + std::to_string(fileNumber++) + ".dat";
string file_name = "analyze_fail.dat";
auto ms_om_path = common::GetEnv("MS_OM_PATH");
if (!ms_om_path.empty()) {
auto path = ms_om_path + "/" + file_name;
@ -443,62 +441,68 @@ void GetEvalStackInfo(std::ostringstream &oss) {
}
// Trace the graph evaluator stack
thread_local static std::stack<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
thread_local TraceGraphEvalStack graph_infer_stack;
// Trace the cnode infer debug info
thread_local static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
thread_local TraceCNodeEvalStack cnode_debug_stack{};
void TraceGraphEvalEnter(const abstract::AnalysisContextPtr &context, const abstract::AnfNodeConfigPtr &node) {
if (context == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null context";
}
(void)graph_infer_stack.emplace(std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>(context, node));
(void)graph_infer_stack.push_back(std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>(context, node));
}
void TraceGraphEvalLeave(const abstract::AnalysisContextPtr &context) {
if (context == nullptr || graph_infer_stack.empty()) {
MS_LOG(EXCEPTION) << "The context is null, or call stack is empty.";
}
if (context != graph_infer_stack.top().first) {
if (context != graph_infer_stack.back().first) {
MS_LOG(EXCEPTION) << "Different context: " << context->func_graph()->ToString() << ", "
<< graph_infer_stack.top().first->func_graph()->ToString();
<< graph_infer_stack.back().first->func_graph()->ToString();
}
graph_infer_stack.pop();
graph_infer_stack.pop_back();
}
void TraceGraphEvalStackPrepare(const TraceGraphEvalStack &graphEvals) {
graph_infer_stack.insert(graph_infer_stack.end(), graphEvals.begin(), graphEvals.end());
}
void TraceEvalCNodeStackPrepare(const TraceCNodeEvalStack &cnodeEvals) {
cnode_debug_stack.insert(cnode_debug_stack.end(), cnodeEvals.begin(), cnodeEvals.end());
}
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_config) { cnode_debug_stack.push_back(node_config); }
void TraceEvalCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
TraceCNodeEvalStack &GetCNodeDebugStack() { return cnode_debug_stack; }
std::stack<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphEvalStack() {
return graph_infer_stack;
}
TraceGraphEvalStack &GetCurrenGraphEvalStack() { return graph_infer_stack; }
void ClearTraceStack() {
while (!graph_infer_stack.empty()) {
graph_infer_stack.pop();
graph_infer_stack.pop_back();
}
cnode_debug_stack.clear();
}
void GetTraceStackInfo(std::ostringstream &oss) {
TraceGraphEval();
std::ostringstream trace_info;
GetEvalStackInfo(trace_info);
if (trace_info.str().empty()) {
DebugInfoPtr debug_info = TraceManager::GetParseOrResolveDebugInfo();
if (debug_info != nullptr) {
oss << "\n\n# " << trace::GetDebugInfo(debug_info);
}
} else {
oss << trace_info.str();
}
}
// Register trace provider to LogWriter.
struct TraceProviderRegister {
TraceProviderRegister() {
LogWriter::set_trace_provider([](std::ostringstream &oss) {
TraceGraphEval();
std::ostringstream trace_info;
GetEvalStackInfo(trace_info);
if (trace_info.str().empty()) {
DebugInfoPtr debug_info = TraceManager::GetParseOrResolveDebugInfo();
if (debug_info != nullptr) {
oss << "\n\n# " << trace::GetDebugInfo(debug_info);
}
} else {
oss << trace_info.str();
}
});
}
TraceProviderRegister() { LogWriter::set_trace_provider(GetTraceStackInfo); }
~TraceProviderRegister() = default;
} trace_provider_regsiter;

View File

@ -22,6 +22,7 @@
#include <vector>
#include <utility>
#include <stack>
#include <deque>
#include "utils/trace_base.h"
#include "utils/info.h"
@ -32,15 +33,20 @@
namespace mindspore {
namespace trace {
using TraceGraphEvalStack = std::deque<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>>;
using TraceCNodeEvalStack = std::vector<abstract::AnfNodeConfigPtr>;
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphEval();
void GetEvalStackInfo(std::ostringstream &oss);
void TraceGraphEvalEnter(const abstract::AnalysisContextPtr &context, const abstract::AnfNodeConfigPtr &node);
void TraceGraphEvalLeave(const abstract::AnalysisContextPtr &context);
void TraceGraphEvalStackPrepare(const TraceGraphEvalStack &graphEvals);
void TraceEvalCNodeStackPrepare(const TraceCNodeEvalStack &cnodeEvals);
void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceEvalCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphEvalStack();
TraceCNodeEvalStack &GetCNodeDebugStack();
TraceGraphEvalStack &GetCurrenGraphEvalStack();
void GetTraceStackInfo(std::ostringstream &oss);
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack();
} // namespace trace

View File

@ -138,14 +138,7 @@ bool CheckArgValid(const py::handle &arg) {
std::string GetCompileExceptionInfo() {
std::ostringstream oss;
trace::TraceGraphEval();
trace::GetEvalStackInfo(oss);
if (oss.str().empty()) {
DebugInfoPtr debug_info = TraceManager::GetParseOrResolveDebugInfo();
if (debug_info != nullptr) {
oss << "\n\n# " << trace::GetDebugInfo(debug_info);
}
}
trace::GetTraceStackInfo(oss);
return oss.str();
}

View File

@ -25,11 +25,23 @@ namespace mindspore {
namespace abstract {
AnalysisSchedule AnalysisSchedule::instance_;
void AnalysisSchedule::HandleException() {
void AnalysisSchedule::HandleException(const std::exception &ex) {
// Just record the first exception information.
if (!StaticAnalysisException::Instance().HasException()) {
StaticAnalysisException::Instance().SetException();
MS_LOG(DEBUG) << "Catch the eval exception.";
// If python Exception, record the eval stack.
if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
try {
MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
trace::GetTraceStackInfo(exceptionStream_);
if (!exceptionStream_.str().empty()) {
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream_.str();
}
} catch (const std::exception &e) {
// Ignored.
}
}
}
// Free all the locks. Let all the threads continue to run.
std::lock_guard<std::mutex> lock(lock_);
@ -38,6 +50,7 @@ void AnalysisSchedule::HandleException() {
}
asyncAbstractList_.clear();
}
void AnalysisSchedule::Wait() {
py::gil_scoped_release infer_gil_release;
EnterWaiting();
@ -50,6 +63,7 @@ void AnalysisSchedule::Wait() {
AnalysisResultCacheMgr::GetInstance().Todo();
}
MS_LOG(INFO) << "Infer finished.";
StaticAnalysisException::Instance().CheckException();
}
void AnalysisSchedule::SetNextRunnableImpl() {

View File

@ -44,13 +44,14 @@ class AnalysisSchedule {
static AnalysisSchedule &GetInstance() { return instance_; }
static void SetThreadID(const std::string &caller);
static std::string &GetThreadID();
void HandleException();
void HandleException(const std::exception &ex);
std::string GetExtendException() { return exceptionStream_.str(); }
void Wait();
void Reset() {
activeThreadCount_ = 1;
threadNum_ = 0;
exceptionStream_.clear();
}
void SetNextRunnable() {
@ -62,8 +63,6 @@ class AnalysisSchedule {
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
if (activeThreadCount_ == 0) {
SetNextRunnableImpl();
} else if (activeThreadCount_ < 0) {
MS_LOG(WARNING) << "There is something wrong. active thread count: " << activeThreadCount_;
}
}
@ -104,6 +103,7 @@ class AnalysisSchedule {
std::mutex lock_;
std::condition_variable condition_var_;
std::list<AsyncAbstractPtr> asyncAbstractList_;
std::ostringstream exceptionStream_;
};
template <typename KeyType, typename ValueType, typename CacheType>

View File

@ -512,12 +512,11 @@ EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList
EvalResultPtr result;
try {
result = this->Run(engine, args_conf_list, out_conf);
} catch (const std::exception &e) {
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval " << ToString() << " throw exception.";
AnalysisSchedule::GetInstance().HandleException();
AnalysisSchedule::GetInstance().HandleException(ex);
}
AnalysisSchedule::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result;
}
} // namespace abstract

View File

@ -143,12 +143,11 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
MS_EXCEPTION_IF_NULL(output_conf);
result.inferred = output_conf->ObtainEvalResult();
result.context = root_context;
} catch (const std::exception &e) {
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
AnalysisSchedule::GetInstance().HandleException();
AnalysisSchedule::GetInstance().HandleException(ex);
}
AnalysisSchedule::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result;
}
@ -262,7 +261,7 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
if (graph_stack.empty()) {
return;
}
auto top_context = graph_stack.top().first;
auto top_context = graph_stack.back().first;
auto top_context_fg = top_context->func_graph();
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
return;
@ -788,11 +787,17 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
std::string caller, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
AsyncAbstractPtr async_run_flag) {
AnalysisSchedule::SetThreadID(caller);
try {
trace::ClearTraceStack();
AsyncAbstractPtr async_run_flag, const trace::TraceGraphEvalStack &graph_evals,
const trace::TraceCNodeEvalStack &trace_c_node_evals) {
// Set threadID xxx.yyy.zzz for debug info.
if (IS_OUTPUT_ON(DEBUG)) {
AnalysisSchedule::SetThreadID(caller);
}
// Restore trace stack for dump stack when there is exception.
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
trace::TraceGraphEvalStackPrepare(graph_evals);
try {
// Wait for Signal to run
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
(void)async_run_flag->GetResult();
@ -819,11 +824,11 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &e1) {
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
async_result_main->SetResult(abstractErrPtr);
AnalysisSchedule::GetInstance().HandleException();
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
AnalysisSchedule::GetInstance().HandleException(e1);
try {
// Thread number will be drop when thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
@ -844,7 +849,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
if (eval_result == nullptr) {
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
} else {
return std::make_shared<EvalResult>(eval_result, nullptr);
@ -865,8 +870,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
// Add point to the async thread.
AnalysisSchedule::GetInstance().IncreaseThreadCount();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
auto thread = std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId,
branchAsyncResult, asyncResult_main, asyncRunOrder);
auto thread =
std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult,
asyncResult_main, asyncRunOrder, trace::GetCurrenGraphEvalStack(), trace::GetCNodeDebugStack());
thread.detach();
// Push to list of running loop
asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));