clear exception_ptr

This commit is contained in:
lanzhineng 2021-07-07 21:45:09 +08:00
parent c6017df10d
commit 34aae1bef4
5 changed files with 172 additions and 101 deletions

View File

@ -15,26 +15,39 @@
*/ */
#include "pipeline/jit/static_analysis/async_eval_result.h" #include "pipeline/jit/static_analysis/async_eval_result.h"
#include <chrono> #include <debug/trace.h>
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "debug/common.h" #include "debug/common.h"
#include "pipeline/jit/base.h" #include "pipeline/jit/base.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "abstract/utils.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
HealthPointMgr HealthPointMgr::instance_; HealthPointMgr HealthPointMgr::instance_;
void HealthPointMgr::Clear() { point_ = 1; } void HealthPointMgr::Clear() {
MS_LOG(DEBUG) << " Point = " << point_;
point_ = 1;
}
void HealthPointMgr::HandleException() { void HealthPointMgr::HandleException() {
// Just record the first exception information.
if (!StaticAnalysisException::Instance().HasException()) {
std::ostringstream oss;
trace::GetEvalStackInfo(oss);
if (!oss.str().empty()) {
MS_LOG(ERROR) << oss.str();
}
StaticAnalysisException::Instance().SetException();
}
// Free all the locks. Let all the threads continue to run.
std::lock_guard<std::recursive_mutex> lock(lock_); std::lock_guard<std::recursive_mutex> lock(lock_);
for (auto &item : asyncAbstractList_) { for (auto &item : asyncAbstractList_) {
item->SetRunable(); item->SetRunable();
} }
asyncAbstractList_.clear(); asyncAbstractList_.clear();
} }
void HealthPointMgr::SetNextRunable() { void HealthPointMgr::SetNextRunable() {
std::lock_guard<std::recursive_mutex> lock(lock_); std::lock_guard<std::recursive_mutex> lock(lock_);
if (asyncAbstractList_.empty()) { if (asyncAbstractList_.empty()) {
@ -46,9 +59,9 @@ void HealthPointMgr::SetNextRunable() {
[](const auto &item) { return item->HasResult(); }); [](const auto &item) { return item->HasResult(); });
if (it == asyncAbstractList_.end()) { if (it == asyncAbstractList_.end()) {
// Enter endless loop if there is not ready result. // Enter endless loop if there is not ready result.
MS_LOG(EXCEPTION) << "Enter endless loop. Please check the code. point = " << HealthPointMgr::GetInstance().point() MS_LOG(EXCEPTION) << "Enter endless loop. There is not more node that can been evaluated. Please check the code.";
<< " Called times : " << asyncAbstractList_.front()->count();
} }
// Move the not-ready entries in front of the first ready one to the back of the list.
asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it); asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
asyncAbstractList_.erase(asyncAbstractList_.begin(), it); asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
@ -65,8 +78,10 @@ void AnalysisResultCacheMgr::Clear() {
cache_.clear(); cache_.clear();
switch_cache_.clear(); switch_cache_.clear();
todo_.clear(); todo_.clear();
waiting_.clear();
} }
// The thread id format is XXXX.YYYY.ZZZZ
thread_local static std::string local_threadid; thread_local static std::string local_threadid;
void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) { void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
std::ostringstream buffer; std::ostringstream buffer;
@ -76,7 +91,7 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; } std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future) { void AnalysisResultCacheMgr::PushToWait(std::future<void> &&future) {
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(lock_);
waiting_.emplace_back(std::move(future)); waiting_.emplace_back(std::move(future));
} }
@ -94,6 +109,7 @@ void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
switch_cache_.set(conf, async_eval_result); switch_cache_.set(conf, async_eval_result);
} }
} }
AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) { AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
// don't call lock_.lock(). switch_cache is protected. and it waits for result. // don't call lock_.lock(). switch_cache is protected. and it waits for result.
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf); AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
@ -125,7 +141,7 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c
return nullptr; return nullptr;
} }
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr arg) { void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
if (arg == nullptr) { if (arg == nullptr) {
MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr"; MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
@ -159,6 +175,14 @@ void AnalysisResultCacheMgr::Todo() {
while (!todo_.empty()) { while (!todo_.empty()) {
AnfNodeConfigPtr conf = todo_.front(); AnfNodeConfigPtr conf = todo_.front();
todo_.pop_front(); todo_.pop_front();
if (GetValue(conf) == nullptr) {
MS_LOG(WARNING) << conf->node()->ToString() << " not in globleCache";
continue;
}
if (TryGetSwitchValue(conf) == nullptr) {
MS_LOG(WARNING) << conf->node()->ToString() << " not in switchCache";
continue;
}
if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) { if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {
MS_LOG(WARNING) << " Switch Value is not eq. " MS_LOG(WARNING) << " Switch Value is not eq. "
<< " switchCache: " << TryGetSwitchValue(conf)->ToString() << " switchCache: " << TryGetSwitchValue(conf)->ToString()
@ -172,7 +196,6 @@ void AnalysisResultCacheMgr::Wait() {
// Check all the async to finish. // Check all the async to finish.
HealthPointScopedDrop hpCheck; HealthPointScopedDrop hpCheck;
while (true) { while (true) {
StaticAnalysisException::Instance().CheckException();
lock_.lock(); lock_.lock();
if (waiting_.empty()) { if (waiting_.empty()) {
lock_.unlock(); lock_.unlock();
@ -188,6 +211,7 @@ void AnalysisResultCacheMgr::Wait() {
if (IS_OUTPUT_ON(DEBUG)) { if (IS_OUTPUT_ON(DEBUG)) {
Todo(); Todo();
} }
MS_LOG(INFO) << "Infer finished.";
} }
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) { std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {

View File

@ -51,7 +51,7 @@ class HealthPointMgr {
if (point_ == 0) { if (point_ == 0) {
SetNextRunable(); SetNextRunable();
} else if (point_ < 0) { } else if (point_ < 0) {
MS_LOG(EXCEPTION) << "There is something wrong."; MS_LOG(WARNING) << "There is something wrong. point = " << point_;
} }
} }
@ -66,7 +66,7 @@ class HealthPointMgr {
++point_; ++point_;
} }
int point() { return point_; } int point() const { return point_; }
void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) { void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) {
std::lock_guard<std::recursive_mutex> lock(lock_); std::lock_guard<std::recursive_mutex> lock(lock_);
@ -187,48 +187,47 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
~AsyncAbstract() = default; ~AsyncAbstract() = default;
// Wait // Wait
AbstractBasePtr GetResult() { AbstractBasePtr GetResult() {
StaticAnalysisException::Instance().CheckException(); static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
static StaticAnalysisException &exceptionMgr = StaticAnalysisException::Instance();
exceptionMgr.CheckException();
std::unique_lock<std::mutex> lock(lock_); std::unique_lock<std::mutex> lock(lock_);
while (true) { while (true) {
++count_; ++count_;
// The point should be dropped if it can't run. It will be added when it can run. // The point should be dropped if it can't run. It will be added when it can run.
bool hasDropPoint = false; bool hasDropPoint = false;
if (!runable_) { if (!runable_) {
HealthPointMgr::GetInstance().DropPoint(); healthPointMgr.DropPoint();
hasDropPoint = true; hasDropPoint = true;
} }
MS_LOG(DEBUG) << this << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0); MS_LOG(DEBUG) << this << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
condition_var_.wait(lock, [this] { return runable_; }); condition_var_.wait(lock, [this] { return runable_; });
if (hasDropPoint) {
healthPointMgr.AddPoint();
}
MS_LOG(DEBUG) << this << " continue runable: " << runable_ << " result: " << (result_ ? result_.get() : 0); MS_LOG(DEBUG) << this << " continue runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
StaticAnalysisException::Instance().CheckException();
exceptionMgr.CheckException();
runable_ = false; runable_ = false;
if (result_ != nullptr) { if (result_ != nullptr) {
if (hasDropPoint) {
HealthPointMgr::GetInstance().AddPoint();
}
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0); MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
return result_; return result_;
} }
// Push to list // Push to list
HealthPointMgr::GetInstance().Add2Schedule(shared_from_this()); healthPointMgr.Add2Schedule(shared_from_this());
if (hasDropPoint) {
HealthPointMgr::GetInstance().AddPoint();
}
// Notify the next asyncAbastract to run. // Notify the next asyncAbastract to run.
HealthPointMgr::GetInstance().SetNextRunable(); healthPointMgr.SetNextRunable();
MS_LOG(DEBUG) << this << " SetNextRunable " MS_LOG(DEBUG) << this << " SetNextRunable "
<< " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0) << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0)
<< " point:" << HealthPointMgr::GetInstance().point(); << " point:" << healthPointMgr.point();
} }
return nullptr;
} }
void SetRunable() { void SetRunable() {
MS_LOG(DEBUG) << this << " Runable."; MS_LOG(DEBUG) << this << " Runable.";
runable_ = true; runable_ = true;
condition_var_.notify_one(); condition_var_.notify_one();
} }
int count() { return count_; } int count() const { return count_; }
bool HasResult() { return result_ != nullptr; } bool HasResult() { return result_ != nullptr; }
// Not wait // Not wait
@ -287,7 +286,7 @@ class AnalysisResultCacheMgr {
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); } inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
// Wait for async Eval(conf) to finish. // Wait for async Eval(conf) to finish.
void Wait(); void Wait();
void PushTowait(std::future<void> &&future); void PushToWait(std::future<void> &&future);
void PushTodo(const AnfNodeConfigPtr &conf); void PushTodo(const AnfNodeConfigPtr &conf);
void Todo(); void Todo();
static void UpdateCaller(const std::string &caller); static void UpdateCaller(const std::string &caller);
@ -295,7 +294,7 @@ class AnalysisResultCacheMgr {
void InitSwitchValue(const AnfNodeConfigPtr &conf); void InitSwitchValue(const AnfNodeConfigPtr &conf);
AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf); AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf); AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr vale); void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale);
private: private:
using AnalysisConfigAsyncResultMap = using AnalysisConfigAsyncResultMap =

View File

@ -509,9 +509,15 @@ EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrLis
} }
EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
auto result = this->Run(engine, args_conf_list, out_conf); EvalResultPtr result;
try {
result = this->Run(engine, args_conf_list, out_conf);
} catch (const std::exception &e) {
MS_LOG(WARNING) << "Eval " << ToString() << " throw exception.";
HealthPointMgr::GetInstance().HandleException();
}
AnalysisResultCacheMgr::GetInstance().Wait(); AnalysisResultCacheMgr::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result; return result;
} }
} // namespace abstract } // namespace abstract

View File

@ -120,32 +120,37 @@ bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeCon
AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
StaticAnalysisException::Instance().ClearException(); StaticAnalysisException::Instance().ClearException();
HealthPointMgr::GetInstance().Clear(); HealthPointMgr::GetInstance().Clear();
ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
MS_EXCEPTION_IF_NULL(func_graph_manager_);
func_graph_manager_->AddFuncGraph(func_graph);
root_func_graph_ = func_graph;
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer.
ResetFunctionCallDepth();
ResetStackFrameDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph());
AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
AnalysisResult result; AnalysisResult result;
MS_EXCEPTION_IF_NULL(output_conf); try {
result.inferred = output_conf->ObtainEvalResult(); ConfigPtrList args_conf_list;
result.context = root_context; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
MS_EXCEPTION_IF_NULL(func_graph_manager_);
func_graph_manager_->AddFuncGraph(func_graph);
root_func_graph_ = func_graph;
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer.
ResetFunctionCallDepth();
ResetStackFrameDepth();
AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list);
MS_EXCEPTION_IF_NULL(root_context);
MS_EXCEPTION_IF_NULL(root_context->func_graph());
AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";
MS_EXCEPTION_IF_NULL(output_conf);
result.inferred = output_conf->ObtainEvalResult();
result.context = root_context;
} catch (const std::exception &e) {
MS_LOG(WARNING) << "Eval " << func_graph->ToString() << " threw exception.";
HealthPointMgr::GetInstance().HandleException();
}
AnalysisResultCacheMgr::GetInstance().Wait(); AnalysisResultCacheMgr::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result; return result;
} }
@ -374,6 +379,8 @@ void AnalysisEngine::ClearEvaluatorCache() {
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr()); MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
evaluator->evaluator_cache_mgr()->Clear(); evaluator->evaluator_cache_mgr()->Clear();
} }
// Release the stored exception to avoid a hang at exit.
StaticAnalysisException::Instance().ClearException();
} }
void AnalysisEngine::Clear() { void AnalysisEngine::Clear() {
@ -789,40 +796,38 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
AsyncAbstractPtr async_run_flag) { AsyncAbstractPtr async_run_flag) {
AnalysisResultCacheMgr::UpdateCaller(caller); AnalysisResultCacheMgr::UpdateCaller(caller);
try { try {
trace::ClearTraceStack();
// Wait for Signal to run // Wait for Signal to run
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting."; MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
(void)async_run_flag->GetResult(); (void)async_run_flag->GetResult();
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running."; MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running.";
// Acquire GIL // Acquire GIL for eval to callback python.
py::gil_scoped_acquire pyGuard; EvalResultPtr result;
trace::ClearTraceStack(); {
auto result = eval->Run(engine, args_conf_list, out_conf); py::gil_scoped_acquire pyGuard;
result = eval->Run(engine, args_conf_list, out_conf);
}
MS_EXCEPTION_IF_NULL(result); MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(result->abstract()); MS_EXCEPTION_IF_NULL(result->abstract());
// Broaden the result of switch(c,t,f)() // Broaden the result of switch(c,t,f)()
auto broadAbstract = result->abstract()->Broaden(); auto broadAbstract = result->abstract()->Broaden();
// Let main thread to continue. // Notify the thread of waiting for switch node and the main thread to continue.
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract); AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
async_result_branch->JoinResult(broadAbstract); async_result_branch->JoinResult(broadAbstract);
async_result_main->JoinResult(broadAbstract); async_result_main->JoinResult(broadAbstract);
// Health Point will be drop when thread exits.
HealthPointMgr::GetInstance().DropPoint();
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString() MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get() << " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString(); << " value = " << async_result_branch->TryGetResult()->ToString();
// Decrease infer thread.
HealthPointMgr::GetInstance().DropPoint();
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::ostringstream oss; MS_LOG(WARNING) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception."; auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
trace::GetEvalStackInfo(oss);
if (!oss.str().empty()) {
MS_LOG(ERROR) << oss.str();
}
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>(oss.str()), out_conf->node());
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr); AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
async_result_main->JoinResult(abstractErrPtr); async_result_main->JoinResult(abstractErrPtr);
StaticAnalysisException::Instance().SetException();
HealthPointMgr::GetInstance().HandleException(); HealthPointMgr::GetInstance().HandleException();
} }
} }
@ -830,61 +835,54 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators, EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) { const ConfigPtrList &args_conf_list) {
// Release GIL; MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
static AnalysisResultCacheMgr &resultCacheMgr = AnalysisResultCacheMgr::GetInstance();
// Release GIL for C++
py::gil_scoped_release infer_gil_release; py::gil_scoped_release infer_gil_release;
// Wait for the last switch node to finish. // Wait for the last switch node to finish.
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString(); MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf); auto eval_result = resultCacheMgr.GetSwitchValue(out_conf);
if (eval_result == nullptr) { if (eval_result == nullptr) {
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString(); MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf); resultCacheMgr.InitSwitchValue(out_conf);
} else { } else {
if (eval_result->isa<AbstractTimeOut>() || eval_result->isa<AbstractError>()) {
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
return std::make_shared<EvalResult>(eval_result, nullptr); return std::make_shared<EvalResult>(eval_result, nullptr);
} }
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph(); auto possible_parent_fg = out_conf->node()->func_graph();
// Eval result of the branches and main.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
std::string threadId = AnalysisResultCacheMgr::GetThreadid(); std::string threadId = AnalysisResultCacheMgr::GetThreadid();
// Eval result of the main.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
// Eval result of the branches
std::vector<AsyncAbstractPtr> branchAsyncResults; std::vector<AsyncAbstractPtr> branchAsyncResults;
for (auto &evaluator : evaluators) { for (auto &evaluator : evaluators) {
SetUndeterminedFlag(evaluator, possible_parent_fg);
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>(); AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
// Control the order to run. // Control the order to run.
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>(); AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
SetUndeterminedFlag(evaluator, possible_parent_fg); // Add point to the async thread.
healthPointMgr.AddPoint();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString(); MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
threadId, branchAsyncResult, asyncResult_main, asyncRunOrder); threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
// Wait for async threads to finish. // Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future)); resultCacheMgr.PushToWait(std::move(future));
// Push to list of running loop // Push to list of running loop
asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1)); asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order healthPointMgr.Add2Schedule(asyncRunOrder); // Activate order
branchAsyncResults.emplace_back(std::move(branchAsyncResult)); branchAsyncResults.emplace_back(std::move(branchAsyncResult));
} }
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString() MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString(); << " or " << evaluators[1]->ToString() << "...";
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order healthPointMgr.Add2Schedule(asyncResult_main); // Third order
auto firstResult = asyncResult_main->GetResult(); auto firstResult = asyncResult_main->GetResult();
if (firstResult == nullptr || firstResult->isa<AbstractTimeOut>()) { MS_EXCEPTION_IF_NULL(firstResult);
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
<< " Please check the code if there are recursive functions.";
}
if (firstResult->isa<AbstractError>()) {
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = " MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
<< firstResult->ToString(); << firstResult->ToString();
@ -893,19 +891,15 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
if (NeedWaitForBranches(firstResult)) { if (NeedWaitForBranches(firstResult)) {
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString(); MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]); healthPointMgr.Add2Schedule(branchAsyncResults[i]);
auto result = branchAsyncResults[i]->GetResult(); auto result = branchAsyncResults[i]->GetResult();
if (result == nullptr || result->isa<AbstractTimeOut>()) { MS_EXCEPTION_IF_NULL(result);
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
<< " Please check the code if there is recursive function.";
}
out_specs.push_back(result); out_specs.push_back(result);
} }
} else { } else {
// Next time to get the result of branches. // Give one more chance to wait for the result of the branches.
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); healthPointMgr.Add2Schedule(asyncResult_main);
(void)asyncResult_main->GetResult(); (void)asyncResult_main->GetResult();
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
// Not wait to get the result of branch. // Not wait to get the result of branch.
auto result = branchAsyncResults[i]->TryGetResult(); auto result = branchAsyncResults[i]->TryGetResult();

View File

@ -51,6 +51,38 @@ def fr(x):
return z return z
# Graph-compiled fixture: recurses on x - 1 while x > 0, then returns NOT_DEF.
# NOT_DEF is not defined anywhere in the visible source; the paired test
# (test_python_error) expects a NameError mentioning 'not defined'.
@ms_function
def f_pythonerr(x):
    if x > 0:
        return f_pythonerr(x - 1)
    # Base case deliberately references an undefined name.
    return NOT_DEF
def test_python_error():
    """Check that an undefined Python name inside a graph function surfaces as NameError.

    Runs f_pythonerr in GRAPH_MODE on a one-element int32 tensor and requires
    a NameError whose message contains 'not defined'.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor([5], mstype.int32)
    # pytest.raises fails the test when no NameError is raised; a bare
    # try/except would let the test pass silently in that case.
    with pytest.raises(NameError) as excinfo:
        f_pythonerr(x)
    assert 'not defined' in str(excinfo.value)
# Graph-compiled fixture with no terminating branch: both the x > 0 path and
# the fall-through path call the function again. The paired test
# (test_recrusive_endless) expects a RuntimeError mentioning 'endless loop'.
@ms_function
def f_recrusive_endless(x):
    if x > 0:
        return f_recrusive_endless(x - 1)
    # No base case: the non-positive path recurses as well.
    return f_recrusive_endless(x + 1)
def test_recrusive_endless():
    """Check that a recursion with no terminating branch is reported as an endless loop.

    Runs f_recrusive_endless in GRAPH_MODE on a one-element int32 tensor and
    requires a RuntimeError whose message contains 'endless loop'.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor([5], mstype.int32)
    # pytest.raises fails the test when no RuntimeError is raised; a bare
    # try/except would let the test pass silently in that case.
    with pytest.raises(RuntimeError) as excinfo:
        f_recrusive_endless(x)
    assert 'endless loop' in str(excinfo.value)
def test_endless(): def test_endless():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32) x = Tensor([5], mstype.int32)
@ -60,6 +92,22 @@ def test_endless():
assert 'endless loop' in str(e) assert 'endless loop' in str(e)
# Graph-compiled fixture: while x > 0 it returns f_ok(x - 1) + 1, otherwise ONE.
# ONE is not defined in the visible source — presumably a module-level
# constant; verify against the rest of the test file.
@ms_function
def f_ok(x):
    if x > 0:
        return f_ok(x - 1) + 1
    return ONE
@pytest.mark.skip(reason="backend is not supported yet")
def test_f_ok():
    """Evaluate f_ok in GRAPH_MODE on Tensor([3]) and compare with Tensor([4])."""
    context.set_context(mode=context.GRAPH_MODE)
    actual = f_ok(Tensor([3], mstype.int32))
    wanted = Tensor([4], mstype.int32)
    assert actual == wanted
@pytest.mark.skip(reason="backend is not supported yet") @pytest.mark.skip(reason="backend is not supported yet")
def test_recrusive_fun(): def test_recrusive_fun():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)