infer opt: determined fun graph first runs & opt exception handle

This commit is contained in:
lanzhineng 2021-06-28 11:17:29 +08:00
parent f026c4ae3e
commit e2adafed3b
5 changed files with 33 additions and 20 deletions

View File

@ -504,11 +504,11 @@ void GetEvalStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Length of analysis information stack is empty.";
return;
}
string file_name = "analyze_fail.dat";
static int fileNumber = 0;
string file_name = "analyze_fail" + std::to_string(fileNumber++) + ".dat";
auto ms_om_path = common::GetEnv("MS_OM_PATH");
if (!ms_om_path.empty()) {
auto path = ms_om_path + "/" + "analyze_fail.dat";
auto path = ms_om_path + "/" + file_name;
auto realpath = Common::GetRealPath(path);
if (!realpath.has_value()) {
MS_EXCEPTION(ValueError) << "Get real path failed. path=" << path;

View File

@ -34,6 +34,7 @@
namespace mindspore {
namespace abstract {
constexpr size_t kInferTimeout = 1800; // 60*30 30min, next pr will change the solution of endless.
constexpr size_t kInferTryTimeout = 3; // 3 microsecond.
class HealthPointMgr {
public:
@ -44,7 +45,7 @@ class HealthPointMgr {
void DropPoint() {
std::unique_lock<std::mutex> lock(lock_);
auto time = std::chrono::microseconds(1);
auto time = std::chrono::microseconds(kInferTryTimeout);
auto cond = condition_var_.wait_for(lock, time, [this] { return point_ > 1; });
if (cond) {
--point_;
@ -61,6 +62,8 @@ class HealthPointMgr {
condition_var_.notify_all();
}
int point() { return point_; }
private:
HealthPointMgr() = default;
int point_{1};
@ -220,8 +223,6 @@ class AsyncResult {
if (ms == 0) {
return result_;
}
// Check if enter endless loop
HealthPointScopedDrop health_point_check;
auto time = std::chrono::microseconds(ms);
// Wait for ms.
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });

View File

@ -578,12 +578,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
#endif
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
static std::mutex fg_lock;
std::lock_guard<std::mutex> infer_lock(fg_lock);
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) {
return;
return false;
}
auto fg = fg_eval->func_graph();
@ -594,7 +594,9 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
return true;
}
return false;
}
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
@ -733,11 +735,20 @@ bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
std::string caller, AsyncAbstractResultPtr async_result_branch,
AsyncAbstractResultPtr async_result_main) {
AsyncAbstractResultPtr async_result_main, bool first, AsyncAbstractResultPtr async_first_Result) {
AnalysisResultCacheMgr::UpdateCaller(caller);
// Wait for the first fg to run
if (!first) {
(void)async_first_Result->GetResult();
}
try {
// Acquire GIL
py::gil_scoped_acquire pyGuard;
// Notify the second fg to go
if (first) {
async_first_Result->JoinResult(std::make_shared<AbstractScalar>(1));
}
trace::ClearTraceStack();
auto result = eval->Run(engine, args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(result->abstract());
@ -754,6 +765,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &e) {
std::ostringstream oss;
oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
trace::GetEvalStackInfo(oss);
if (!oss.str().empty()) {
MS_LOG(ERROR) << oss.str();
@ -786,7 +798,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
<< " Please check the code if there are recursive functions.";
}
if (eval_result->abstract()->isa<AbstractError>()) {
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
return eval_result;
@ -796,22 +808,23 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
AsyncAbstractResultPtr asyncResult_main = std::make_shared<AsyncAbstractResult>();
AsyncAbstractResultPtr asyncResult0 = std::make_shared<AsyncAbstractResult>();
AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();
SetUndeterminedFlag(evaluators[0]);
SetUndeterminedFlag(evaluators[1]);
bool firstRun = !SetUndeterminedFlag(evaluators[0]);
(void)SetUndeterminedFlag(evaluators[1]);
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult0, asyncResult_main);
out_conf, threadId, asyncResult0, asyncResult_main, firstRun, asyncFirstRunResult);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult1, asyncResult_main);
out_conf, threadId, asyncResult1, asyncResult_main, !firstRun, asyncFirstRunResult);
// Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
@ -824,7 +837,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
<< " Please check the code if there are recursive functions.";
}
if (branchResult->isa<AbstractError>()) {
MS_LOG(ERROR) << "async " << out_conf->node()->ToString() << " threw exception.";
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
@ -848,12 +861,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
}
out_specs.push_back(result1);
} else {
if (asyncResult0->TryGetResult()) {
if (asyncResult0->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
<< " value0=" << asyncResult0->GetResult()->ToString();
out_specs.push_back(asyncResult0->GetResult());
}
if (asyncResult1->TryGetResult()) {
if (asyncResult1->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
<< " value1=" << asyncResult1->GetResult()->ToString();
out_specs.push_back(asyncResult1->GetResult());
@ -879,7 +892,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
return conf->ObtainEvalResult()->abstract();
});
for (const auto &eval : evaluators) {
SetUndeterminedFlag(eval);
(void)SetUndeterminedFlag(eval);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

View File

@ -292,7 +292,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
bool SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
bool *continue_flag);

View File

@ -94,7 +94,6 @@ class StaticAnalysisException {
std::lock_guard<std::mutex> lock(lock_);
if (exception_ptr_ != nullptr) {
auto tmp_exception_ptr = exception_ptr_;
exception_ptr_ = nullptr;
std::rethrow_exception(tmp_exception_ptr);
}
}