|
|
|
@ -578,12 +578,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
|
|
|
|
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
|
|
|
|
static std::mutex fg_lock;
|
|
|
|
|
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
|
|
|
|
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
|
|
|
|
if (fg_eval == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto fg = fg_eval->func_graph();
|
|
|
|
@ -594,7 +594,9 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fg_parent);
|
|
|
|
|
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
|
|
|
|
|
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
|
|
|
@ -733,11 +735,20 @@ bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
|
|
|
|
|
|
|
|
|
|
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
|
|
|
|
|
std::string caller, AsyncAbstractResultPtr async_result_branch,
|
|
|
|
|
AsyncAbstractResultPtr async_result_main) {
|
|
|
|
|
AsyncAbstractResultPtr async_result_main, bool first, AsyncAbstractResultPtr async_first_Result) {
|
|
|
|
|
AnalysisResultCacheMgr::UpdateCaller(caller);
|
|
|
|
|
// Wait for the first fg to run
|
|
|
|
|
if (!first) {
|
|
|
|
|
(void)async_first_Result->GetResult();
|
|
|
|
|
}
|
|
|
|
|
try {
|
|
|
|
|
// Acquire GIL
|
|
|
|
|
py::gil_scoped_acquire pyGuard;
|
|
|
|
|
// Notify the second fg to go
|
|
|
|
|
if (first) {
|
|
|
|
|
async_first_Result->JoinResult(std::make_shared<AbstractScalar>(1));
|
|
|
|
|
}
|
|
|
|
|
trace::ClearTraceStack();
|
|
|
|
|
auto result = eval->Run(engine, args_conf_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result->abstract());
|
|
|
|
@ -754,6 +765,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|
|
|
|
<< " value = " << async_result_branch->TryGetResult()->ToString();
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
|
|
|
|
|
trace::GetEvalStackInfo(oss);
|
|
|
|
|
if (!oss.str().empty()) {
|
|
|
|
|
MS_LOG(ERROR) << oss.str();
|
|
|
|
@ -786,7 +798,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|
|
|
|
<< " Please check the code if there are recursive functions.";
|
|
|
|
|
}
|
|
|
|
|
if (eval_result->abstract()->isa<AbstractError>()) {
|
|
|
|
|
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
|
|
|
|
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
|
|
|
|
StaticAnalysisException::Instance().CheckException();
|
|
|
|
|
}
|
|
|
|
|
return eval_result;
|
|
|
|
@ -796,22 +808,23 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|
|
|
|
AsyncAbstractResultPtr asyncResult_main = std::make_shared<AsyncAbstractResult>();
|
|
|
|
|
AsyncAbstractResultPtr asyncResult0 = std::make_shared<AsyncAbstractResult>();
|
|
|
|
|
AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
|
|
|
|
|
AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();
|
|
|
|
|
|
|
|
|
|
SetUndeterminedFlag(evaluators[0]);
|
|
|
|
|
SetUndeterminedFlag(evaluators[1]);
|
|
|
|
|
bool firstRun = !SetUndeterminedFlag(evaluators[0]);
|
|
|
|
|
(void)SetUndeterminedFlag(evaluators[1]);
|
|
|
|
|
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
|
|
|
|
|
// Add point to infer thread
|
|
|
|
|
HealthPointMgr::GetInstance().AddPoint();
|
|
|
|
|
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
|
|
|
|
|
out_conf, threadId, asyncResult0, asyncResult_main);
|
|
|
|
|
out_conf, threadId, asyncResult0, asyncResult_main, firstRun, asyncFirstRunResult);
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
|
|
|
|
|
// Add point to infer thread
|
|
|
|
|
HealthPointMgr::GetInstance().AddPoint();
|
|
|
|
|
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
|
|
|
|
|
out_conf, threadId, asyncResult1, asyncResult_main);
|
|
|
|
|
out_conf, threadId, asyncResult1, asyncResult_main, !firstRun, asyncFirstRunResult);
|
|
|
|
|
|
|
|
|
|
// Wait for async threads to finish.
|
|
|
|
|
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
|
|
|
|
@ -824,7 +837,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|
|
|
|
<< " Please check the code if there are recursive functions.";
|
|
|
|
|
}
|
|
|
|
|
if (branchResult->isa<AbstractError>()) {
|
|
|
|
|
MS_LOG(ERROR) << "async " << out_conf->node()->ToString() << " threw exception.";
|
|
|
|
|
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
|
|
|
|
|
StaticAnalysisException::Instance().CheckException();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
|
|
|
|
@ -848,12 +861,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|
|
|
|
}
|
|
|
|
|
out_specs.push_back(result1);
|
|
|
|
|
} else {
|
|
|
|
|
if (asyncResult0->TryGetResult()) {
|
|
|
|
|
if (asyncResult0->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
|
|
|
|
|
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
|
|
|
|
|
<< " value0=" << asyncResult0->GetResult()->ToString();
|
|
|
|
|
out_specs.push_back(asyncResult0->GetResult());
|
|
|
|
|
}
|
|
|
|
|
if (asyncResult1->TryGetResult()) {
|
|
|
|
|
if (asyncResult1->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
|
|
|
|
|
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
|
|
|
|
|
<< " value1=" << asyncResult1->GetResult()->ToString();
|
|
|
|
|
out_specs.push_back(asyncResult1->GetResult());
|
|
|
|
@ -879,7 +892,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
for (const auto &eval : evaluators) {
|
|
|
|
|
SetUndeterminedFlag(eval);
|
|
|
|
|
(void)SetUndeterminedFlag(eval);
|
|
|
|
|
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
|
|
|
|
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
|
|
|
|
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
|
|
|
|