infer opt: determined func graph runs first & optimize exception handling
parent f026c4ae3e
commit e2adafed3b
@@ -504,11 +504,11 @@ void GetEvalStackInfo(std::ostringstream &oss) {
     MS_LOG(INFO) << "Length of analysis information stack is empty.";
     return;
   }
 
-  string file_name = "analyze_fail.dat";
+  static int fileNumber = 0;
+  string file_name = "analyze_fail" + std::to_string(fileNumber++) + ".dat";
   auto ms_om_path = common::GetEnv("MS_OM_PATH");
   if (!ms_om_path.empty()) {
-    auto path = ms_om_path + "/" + "analyze_fail.dat";
+    auto path = ms_om_path + "/" + file_name;
     auto realpath = Common::GetRealPath(path);
     if (!realpath.has_value()) {
       MS_EXCEPTION(ValueError) << "Get real path failed. path=" << path;

@@ -34,6 +34,7 @@
 namespace mindspore {
 namespace abstract {
 constexpr size_t kInferTimeout = 1800;  // 60*30 30min, next pr will change the solution of endless.
+constexpr size_t kInferTryTimeout = 3;  // 3 microsecond.
 
 class HealthPointMgr {
  public:
@@ -44,7 +45,7 @@ class HealthPointMgr {
 
   void DropPoint() {
     std::unique_lock<std::mutex> lock(lock_);
-    auto time = std::chrono::microseconds(1);
+    auto time = std::chrono::microseconds(kInferTryTimeout);
     auto cond = condition_var_.wait_for(lock, time, [this] { return point_ > 1; });
     if (cond) {
       --point_;
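
Note on the change above: DropPoint now waits for kInferTryTimeout microseconds instead of a hard-coded 1. For orientation, here is a minimal self-contained sketch of the health-point idea (a simplification, not the actual MindSpore class; the member names mirror the diff):

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <mutex>

// Simplified sketch of the health-point bookkeeping (not the real MindSpore
// class): point_ counts infer threads that can currently make progress.
class HealthPointMgrSketch {
 public:
  static constexpr size_t kInferTryTimeout = 3;  // microseconds, as in the diff

  // Block briefly until more than one point is available, then take one.
  void DropPoint() {
    std::unique_lock<std::mutex> lock(lock_);
    auto time = std::chrono::microseconds(kInferTryTimeout);
    // wait_for returns true only if the predicate held before the timeout.
    if (condition_var_.wait_for(lock, time, [this] { return point_ > 1; })) {
      --point_;
    }
  }

  // Register one more running infer thread and wake any waiter.
  void AddPoint() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ++point_;
    }
    condition_var_.notify_all();
  }

  int point() { return point_; }

 private:
  int point_{1};
  std::mutex lock_;
  std::condition_variable condition_var_;
};
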
@@ -61,6 +62,8 @@ class HealthPointMgr {
     condition_var_.notify_all();
   }
 
+  int point() { return point_; }
+
  private:
   HealthPointMgr() = default;
   int point_{1};
@@ -220,8 +223,6 @@ class AsyncResult {
     if (ms == 0) {
       return result_;
     }
-    // Check if enter endless loop
-    HealthPointScopedDrop health_point_check;
     auto time = std::chrono::microseconds(ms);
     // Wait for ms.
     (void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });

@@ -578,12 +578,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
 #endif
 }
 
-void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
+bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
   static std::mutex fg_lock;
   std::lock_guard<std::mutex> infer_lock(fg_lock);
   auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
   if (fg_eval == nullptr) {
-    return;
+    return false;
   }
 
   auto fg = fg_eval->func_graph();
@@ -594,7 +594,9 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
     MS_EXCEPTION_IF_NULL(fg_parent);
     fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
     MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
+    return true;
   }
+  return false;
 }
 
 EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
@@ -733,11 +735,20 @@ bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
 
 void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
                    std::string caller, AsyncAbstractResultPtr async_result_branch,
-                   AsyncAbstractResultPtr async_result_main) {
+                   AsyncAbstractResultPtr async_result_main, bool first, AsyncAbstractResultPtr async_first_Result) {
   AnalysisResultCacheMgr::UpdateCaller(caller);
+  // Wait for the first fg to run
+  if (!first) {
+    (void)async_first_Result->GetResult();
+  }
   try {
     // Acquire GIL
     py::gil_scoped_acquire pyGuard;
+    // Notify the second fg to go
+    if (first) {
+      async_first_Result->JoinResult(std::make_shared<AbstractScalar>(1));
+    }
     trace::ClearTraceStack();
     auto result = eval->Run(engine, args_conf_list, out_conf);
     MS_EXCEPTION_IF_NULL(result);
     MS_EXCEPTION_IF_NULL(result->abstract());
@@ -754,6 +765,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
                   << " value = " << async_result_branch->TryGetResult()->ToString();
   } catch (const std::exception &e) {
     std::ostringstream oss;
+    oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
     trace::GetEvalStackInfo(oss);
     if (!oss.str().empty()) {
       MS_LOG(ERROR) << oss.str();
@@ -786,7 +798,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
                       << " Please check the code if there are recursive functions.";
   }
   if (eval_result->abstract()->isa<AbstractError>()) {
-    MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
+    MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
     StaticAnalysisException::Instance().CheckException();
   }
   return eval_result;
@@ -796,22 +808,23 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
   AsyncAbstractResultPtr asyncResult_main = std::make_shared<AsyncAbstractResult>();
   AsyncAbstractResultPtr asyncResult0 = std::make_shared<AsyncAbstractResult>();
   AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
+  AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();
 
-  SetUndeterminedFlag(evaluators[0]);
-  SetUndeterminedFlag(evaluators[1]);
+  bool firstRun = !SetUndeterminedFlag(evaluators[0]);
+  (void)SetUndeterminedFlag(evaluators[1]);
   std::string threadId = AnalysisResultCacheMgr::GetThreadid();
 
   MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
   // Add point to infer thread
   HealthPointMgr::GetInstance().AddPoint();
   auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
-                            out_conf, threadId, asyncResult0, asyncResult_main);
+                            out_conf, threadId, asyncResult0, asyncResult_main, firstRun, asyncFirstRunResult);
 
   MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
   // Add point to infer thread
   HealthPointMgr::GetInstance().AddPoint();
   auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
-                            out_conf, threadId, asyncResult1, asyncResult_main);
+                            out_conf, threadId, asyncResult1, asyncResult_main, !firstRun, asyncFirstRunResult);
 
   // Wait for async threads to finish.
   AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
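
The wiring above makes the branch whose func graph is already determined start evaluating before the other one: firstRun is derived from SetUndeterminedFlag's new return value, and asyncFirstRunResult is the token the second branch waits on inside ExecEvaluator. Stripped of the analysis-specific types, the handshake looks roughly like the sketch below; std::promise/std::future stand in for AsyncAbstractResult, and ExecBranch/EvalBranch are hypothetical placeholders for the real evaluator call.

#include <future>
#include <iostream>
#include <memory>

// Hypothetical stand-in for the per-branch evaluation work.
int EvalBranch(int branch_id) { return branch_id * 10; }

// Sketch of the ExecEvaluator handshake: the branch that is not "first"
// waits until the first branch has signalled that it is running.
int ExecBranch(int branch_id, bool first, std::shared_ptr<std::promise<void>> first_started,
               std::shared_future<void> first_started_future) {
  if (!first) {
    first_started_future.wait();  // wait for the first func graph to run
  }
  if (first) {
    // In the real code this signal is sent after the GIL has been acquired.
    first_started->set_value();  // notify the second branch to go
  }
  return EvalBranch(branch_id);
}

int main() {
  auto first_started = std::make_shared<std::promise<void>>();
  std::shared_future<void> started = first_started->get_future().share();

  bool first_run = true;  // in the diff: !SetUndeterminedFlag(evaluators[0])
  auto future0 = std::async(std::launch::async, ExecBranch, 0, first_run, first_started, started);
  auto future1 = std::async(std::launch::async, ExecBranch, 1, !first_run, first_started, started);

  std::cout << future0.get() << " " << future1.get() << std::endl;
  return 0;
}
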
@@ -824,7 +837,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
                       << " Please check the code if there are recursive functions.";
   }
   if (branchResult->isa<AbstractError>()) {
-    MS_LOG(ERROR) << "async " << out_conf->node()->ToString() << " threw exception.";
+    MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
     StaticAnalysisException::Instance().CheckException();
   }
   MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
@@ -848,12 +861,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
     }
     out_specs.push_back(result1);
   } else {
-    if (asyncResult0->TryGetResult()) {
+    if (asyncResult0->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
       MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
                     << " value0=" << asyncResult0->GetResult()->ToString();
       out_specs.push_back(asyncResult0->GetResult());
     }
-    if (asyncResult1->TryGetResult()) {
+    if (asyncResult1->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
       MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
                     << " value1=" << asyncResult1->GetResult()->ToString();
       out_specs.push_back(asyncResult1->GetResult());
@@ -879,7 +892,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
     return conf->ObtainEvalResult()->abstract();
   });
   for (const auto &eval : evaluators) {
-    SetUndeterminedFlag(eval);
+    (void)SetUndeterminedFlag(eval);
     const auto current_inf = EvaluatorArgs(eval, args_spec_list);
     MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
     // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

@@ -292,7 +292,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
   static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
 
  private:
-  void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
+  bool SetUndeterminedFlag(const EvaluatorPtr &evaluator);
   EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
                                      const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
                                      bool *continue_flag);

@@ -94,7 +94,6 @@ class StaticAnalysisException {
     std::lock_guard<std::mutex> lock(lock_);
     if (exception_ptr_ != nullptr) {
       auto tmp_exception_ptr = exception_ptr_;
-      exception_ptr_ = nullptr;
       std::rethrow_exception(tmp_exception_ptr);
     }
   }
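
As the hunk above reads, CheckException now keeps the stored exception instead of clearing it, so repeated checks rethrow the same failure. The mechanism underneath is std::exception_ptr: an evaluator thread captures the in-flight exception, and the caller's thread rethrows it later. A minimal sketch with a simplified holder class (not the real StaticAnalysisException singleton):

#include <exception>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>

// Simplified holder for an exception captured on another thread.
class ExceptionHolder {
 public:
  // Call from inside a catch block on the worker thread.
  void SetException() {
    std::lock_guard<std::mutex> lock(lock_);
    exception_ptr_ = std::current_exception();
  }

  // Rethrow on the calling thread if a worker stored an exception.
  // The stored pointer is intentionally left in place, mirroring the diff.
  void CheckException() {
    std::lock_guard<std::mutex> lock(lock_);
    if (exception_ptr_ != nullptr) {
      auto tmp_exception_ptr = exception_ptr_;
      std::rethrow_exception(tmp_exception_ptr);
    }
  }

 private:
  std::mutex lock_;
  std::exception_ptr exception_ptr_{nullptr};
};

int main() {
  ExceptionHolder holder;
  std::thread worker([&holder] {
    try {
      throw std::runtime_error("eval failed in worker");
    } catch (...) {
      holder.SetException();  // capture for the main thread
    }
  });
  worker.join();

  try {
    holder.CheckException();  // rethrows the worker's exception here
  } catch (const std::exception &e) {
    std::cout << "caught: " << e.what() << std::endl;
  }
  return 0;
}
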