suport more evaluators and fix recrusive_fun testcase

This commit is contained in:
lanzhineng 2021-07-06 11:16:38 +08:00
parent 83bcf936b6
commit b706db3119
4 changed files with 63 additions and 87 deletions

View File

@ -76,10 +76,9 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; } std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future0, std::future<void> &&future1) { void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future) {
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(lock_);
waiting_.emplace_back(std::move(future0)); waiting_.emplace_back(std::move(future));
waiting_.emplace_back(std::move(future1));
} }
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) { void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {

View File

@ -287,7 +287,7 @@ class AnalysisResultCacheMgr {
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); } inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
// Wait for async Eval(conf) to finish. // Wait for async Eval(conf) to finish.
void Wait(); void Wait();
void PushTowait(std::future<void> &&future0, std::future<void> &&future1); void PushTowait(std::future<void> &&future);
void PushTodo(const AnfNodeConfigPtr &conf); void PushTodo(const AnfNodeConfigPtr &conf);
void Todo(); void Todo();
static void UpdateCaller(const std::string &caller); static void UpdateCaller(const std::string &caller);

View File

@ -775,7 +775,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
} }
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) { bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
if (abstract->isa<AbstractFunction>()) { if (abstract->isa<AbstractFunction>()) {
return true; return true;
} }
@ -845,109 +845,80 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString(); MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf); AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
} else { } else {
if (eval_result->isa<AbstractTimeOut>()) { if (eval_result->isa<AbstractTimeOut>() || eval_result->isa<AbstractError>()) {
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out." MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
<< " Please check the code if there are recursive functions.";
}
if (eval_result->isa<AbstractError>()) {
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException(); StaticAnalysisException::Instance().CheckException();
} }
return std::make_shared<EvalResult>(eval_result, nullptr); return std::make_shared<EvalResult>(eval_result, nullptr);
} }
// Eval result of the branches and main.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncResult0 = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncResult1 = std::make_shared<AsyncAbstract>();
// Control which thread to run.
AsyncAbstractPtr asyncRun0 = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncRun1 = std::make_shared<AsyncAbstract>();
MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node()); MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph(); auto possible_parent_fg = out_conf->node()->func_graph();
SetUndeterminedFlag(evaluators[0], possible_parent_fg); // Eval result of the branches and main.
SetUndeterminedFlag(evaluators[1], possible_parent_fg); AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
std::string threadId = AnalysisResultCacheMgr::GetThreadid(); std::string threadId = AnalysisResultCacheMgr::GetThreadid();
std::vector<AsyncAbstractPtr> branchAsyncResults;
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString(); for (auto &evaluator : evaluators) {
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
// Control the order to run.
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
SetUndeterminedFlag(evaluator, possible_parent_fg);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
// Add point to infer thread // Add point to infer thread
HealthPointMgr::GetInstance().AddPoint(); HealthPointMgr::GetInstance().AddPoint();
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list, auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
out_conf, threadId, asyncResult0, asyncResult_main, asyncRun0); threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult1, asyncResult_main, asyncRun1);
// Wait for async threads to finish. // Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1)); AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future));
// Push to list of running loop // Push to list of running loop
asyncRun0->JoinResult(std::make_shared<AbstractScalar>(0)); asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
asyncRun1->JoinResult(std::make_shared<AbstractScalar>(0)); HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order
// Run order branchAsyncResults.emplace_back(std::move(branchAsyncResult));
HealthPointMgr::GetInstance().Add2Schedule(asyncRun0); // First order }
HealthPointMgr::GetInstance().Add2Schedule(asyncRun1); // Second order
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString() MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString(); << " or " << evaluators[1]->ToString();
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
auto branchResult = asyncResult_main->GetResult(); auto firstResult = asyncResult_main->GetResult();
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) { if (firstResult == nullptr || firstResult->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString() MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
<< " Please check the code if there are recursive functions."; << " Please check the code if there are recursive functions.";
} }
if (branchResult->isa<AbstractError>()) { if (firstResult->isa<AbstractError>()) {
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception."; MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException(); StaticAnalysisException::Instance().CheckException();
} }
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = " MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
<< branchResult->ToString(); << firstResult->ToString();
AbstractBasePtrList out_specs; AbstractBasePtrList out_specs;
if (NeedWaitForTwoBranches(branchResult)) { size_t len = evaluators.size();
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString(); if (NeedWaitForBranches(firstResult)) {
// The asyncRun0 will eval asyncResult0 for (size_t i = 0; i < len; ++i) {
HealthPointMgr::GetInstance().Add2Schedule(asyncResult0); MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
auto result0 = asyncResult0->GetResult(); HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
if (result0 == nullptr || result0->isa<AbstractTimeOut>()) { auto result = branchAsyncResults[i]->GetResult();
if (result == nullptr || result->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out." MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
<< " Please check the code if there is recursive function."; << " Please check the code if there is recursive function.";
} }
out_specs.push_back(result0); out_specs.push_back(result);
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString();
// The asyncRun1 will eval asyncResult1
HealthPointMgr::GetInstance().Add2Schedule(asyncResult1);
auto result1 = asyncResult1->GetResult();
if (result1 == nullptr || result1->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << " is time out."
<< " Please check the code if there is recursive function.";
} }
out_specs.push_back(result1);
} else { } else {
// Next time to get the result of branches. // Next time to get the result of branches.
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
(void)asyncResult_main->GetResult(); (void)asyncResult_main->GetResult();
// Don't use GetResult for (size_t i = 0; i < len; ++i) {
auto value0 = asyncResult0->TryGetResult(); // Not wait to get the result of branch.
if (value0) { auto result = branchAsyncResults[i]->TryGetResult();
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString() if (result) {
<< " value0=" << value0->ToString(); MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
out_specs.push_back(value0); << " result =" << result->ToString();
out_specs.push_back(result);
} }
// Don't use GetResult
auto value1 = asyncResult1->TryGetResult();
if (value1) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
<< " value1=" << value1->ToString();
out_specs.push_back(value1);
} }
} }
return ProcessEvalResults(out_specs, out_conf->node()); return ProcessEvalResults(out_specs, out_conf->node());

View File

@ -23,13 +23,15 @@ ONE = Tensor([1], mstype.int32)
@ms_function @ms_function
def f(x): def f(x):
y = f(x - 4) y = ZERO
if x < 0: if x < 0:
y = f(x - 3) y = f(x - 3)
elif x < 3: elif x < 3:
y = x * f(x - 1) y = x * f(x - 1)
elif x >= 3: elif x < 5:
y = x * f(x - 2) y = x * f(x - 2)
else:
y = f(x - 4)
z = y + 1 z = y + 1
return z return z
@ -41,8 +43,10 @@ def fr(x):
y = ONE y = ONE
elif x < 3: elif x < 3:
y = x * fr(x - 1) y = x * fr(x - 1)
elif x >= 3: elif x < 5:
y = x * fr(x - 2) y = x * fr(x - 2)
else:
y = fr(x - 4)
z = y + 1 z = y + 1
return z return z
@ -50,18 +54,20 @@ def fr(x):
def test_endless(): def test_endless():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32) x = Tensor([5], mstype.int32)
try:
f(x) f(x)
with pytest.raises(ValueError): except RuntimeError as e:
print("endless.") assert 'endless loop' in str(e)
@pytest.mark.skip(reason="backend is not supported yet")
def test_recrusive_fun(): def test_recrusive_fun():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32) x = Tensor([5], mstype.int32)
ret = fr(x) ret = fr(x)
expect = Tensor([36], mstype.int32) expect = Tensor([3], mstype.int32)
assert ret == expect assert ret == expect
if __name__ == "__main__": if __name__ == "__main__":
test_recrusive_fun() test_endless()