support more evaluators and fix recrusive_fun testcase

This commit is contained in:
lanzhineng 2021-07-06 11:16:38 +08:00
parent 83bcf936b6
commit b706db3119
4 changed files with 63 additions and 87 deletions

View File

@ -76,10 +76,9 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future0, std::future<void> &&future1) {
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future) {
std::lock_guard<std::mutex> lock(lock_);
waiting_.emplace_back(std::move(future0));
waiting_.emplace_back(std::move(future1));
waiting_.emplace_back(std::move(future));
}
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {

View File

@ -287,7 +287,7 @@ class AnalysisResultCacheMgr {
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
// Wait for async Eval(conf) to finish.
void Wait();
void PushTowait(std::future<void> &&future0, std::future<void> &&future1);
void PushTowait(std::future<void> &&future);
void PushTodo(const AnfNodeConfigPtr &conf);
void Todo();
static void UpdateCaller(const std::string &caller);

View File

@ -775,7 +775,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
if (abstract->isa<AbstractFunction>()) {
return true;
}
@ -845,109 +845,80 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
} else {
if (eval_result->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
<< " Please check the code if there are recursive functions.";
}
if (eval_result->isa<AbstractError>()) {
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
if (eval_result->isa<AbstractTimeOut>() || eval_result->isa<AbstractError>()) {
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
return std::make_shared<EvalResult>(eval_result, nullptr);
}
// Eval result of the branches and main.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncResult0 = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncResult1 = std::make_shared<AsyncAbstract>();
// Control which thread to run.
AsyncAbstractPtr asyncRun0 = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr asyncRun1 = std::make_shared<AsyncAbstract>();
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph();
SetUndeterminedFlag(evaluators[0], possible_parent_fg);
SetUndeterminedFlag(evaluators[1], possible_parent_fg);
// Eval result of the branches and main.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
std::vector<AsyncAbstractPtr> branchAsyncResults;
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult0, asyncResult_main, asyncRun0);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult1, asyncResult_main, asyncRun1);
// Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
// Push to list of running loop
asyncRun0->JoinResult(std::make_shared<AbstractScalar>(0));
asyncRun1->JoinResult(std::make_shared<AbstractScalar>(0));
// Run order
HealthPointMgr::GetInstance().Add2Schedule(asyncRun0); // First order
HealthPointMgr::GetInstance().Add2Schedule(asyncRun1); // Second order
for (auto &evaluator : evaluators) {
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
// Control the order to run.
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
SetUndeterminedFlag(evaluator, possible_parent_fg);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
// Add point to infer thread
HealthPointMgr::GetInstance().AddPoint();
auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
// Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future));
// Push to list of running loop
asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order
branchAsyncResults.emplace_back(std::move(branchAsyncResult));
}
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString();
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
auto branchResult = asyncResult_main->GetResult();
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
auto firstResult = asyncResult_main->GetResult();
if (firstResult == nullptr || firstResult->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
<< " Please check the code if there are recursive functions.";
}
if (branchResult->isa<AbstractError>()) {
if (firstResult->isa<AbstractError>()) {
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
StaticAnalysisException::Instance().CheckException();
}
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
<< branchResult->ToString();
<< firstResult->ToString();
AbstractBasePtrList out_specs;
if (NeedWaitForTwoBranches(branchResult)) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString();
// The asyncRun0 will eval asyncResult0
HealthPointMgr::GetInstance().Add2Schedule(asyncResult0);
auto result0 = asyncResult0->GetResult();
if (result0 == nullptr || result0->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
<< " Please check the code if there is recursive function.";
size_t len = evaluators.size();
if (NeedWaitForBranches(firstResult)) {
for (size_t i = 0; i < len; ++i) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
auto result = branchAsyncResults[i]->GetResult();
if (result == nullptr || result->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
<< " Please check the code if there is recursive function.";
}
out_specs.push_back(result);
}
out_specs.push_back(result0);
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString();
// The asyncRun1 will eval asyncResult1
HealthPointMgr::GetInstance().Add2Schedule(asyncResult1);
auto result1 = asyncResult1->GetResult();
if (result1 == nullptr || result1->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << " is time out."
<< " Please check the code if there is recursive function.";
}
out_specs.push_back(result1);
} else {
// Next time to get the result of branches.
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
(void)asyncResult_main->GetResult();
// Don't use GetResult
auto value0 = asyncResult0->TryGetResult();
if (value0) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
<< " value0=" << value0->ToString();
out_specs.push_back(value0);
}
// Don't use GetResult
auto value1 = asyncResult1->TryGetResult();
if (value1) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
<< " value1=" << value1->ToString();
out_specs.push_back(value1);
for (size_t i = 0; i < len; ++i) {
// Not wait to get the result of branch.
auto result = branchAsyncResults[i]->TryGetResult();
if (result) {
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
<< " result =" << result->ToString();
out_specs.push_back(result);
}
}
}
return ProcessEvalResults(out_specs, out_conf->node());

View File

@ -23,13 +23,15 @@ ONE = Tensor([1], mstype.int32)
@ms_function
def f(x):
y = f(x - 4)
y = ZERO
if x < 0:
y = f(x - 3)
elif x < 3:
y = x * f(x - 1)
elif x >= 3:
elif x < 5:
y = x * f(x - 2)
else:
y = f(x - 4)
z = y + 1
return z
@ -41,8 +43,10 @@ def fr(x):
y = ONE
elif x < 3:
y = x * fr(x - 1)
elif x >= 3:
elif x < 5:
y = x * fr(x - 2)
else:
y = fr(x - 4)
z = y + 1
return z
@ -50,18 +54,20 @@ def fr(x):
def test_endless():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32)
f(x)
with pytest.raises(ValueError):
print("endless.")
try:
f(x)
except RuntimeError as e:
assert 'endless loop' in str(e)
@pytest.mark.skip(reason="backend is not supported yet")
def test_recrusive_fun():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32)
ret = fr(x)
expect = Tensor([36], mstype.int32)
expect = Tensor([3], mstype.int32)
assert ret == expect
if __name__ == "__main__":
test_recrusive_fun()
test_endless()