suport more evaluators and fix recrusive_fun testcase
This commit is contained in:
parent
83bcf936b6
commit
b706db3119
|
@ -76,10 +76,9 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
|
|||
|
||||
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
|
||||
|
||||
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future0, std::future<void> &&future1) {
|
||||
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future) {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
waiting_.emplace_back(std::move(future0));
|
||||
waiting_.emplace_back(std::move(future1));
|
||||
waiting_.emplace_back(std::move(future));
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
||||
|
|
|
@ -287,7 +287,7 @@ class AnalysisResultCacheMgr {
|
|||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||
// Wait for async Eval(conf) to finish.
|
||||
void Wait();
|
||||
void PushTowait(std::future<void> &&future0, std::future<void> &&future1);
|
||||
void PushTowait(std::future<void> &&future);
|
||||
void PushTodo(const AnfNodeConfigPtr &conf);
|
||||
void Todo();
|
||||
static void UpdateCaller(const std::string &caller);
|
||||
|
|
|
@ -775,7 +775,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
|
||||
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -845,109 +845,80 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
|
||||
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
|
||||
} else {
|
||||
if (eval_result->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
|
||||
<< " Please check the code if there are recursive functions.";
|
||||
}
|
||||
if (eval_result->isa<AbstractError>()) {
|
||||
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
||||
if (eval_result->isa<AbstractTimeOut>() || eval_result->isa<AbstractError>()) {
|
||||
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
}
|
||||
return std::make_shared<EvalResult>(eval_result, nullptr);
|
||||
}
|
||||
|
||||
// Eval result of the branches and main.
|
||||
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncResult0 = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncResult1 = std::make_shared<AsyncAbstract>();
|
||||
|
||||
// Control which thread to run.
|
||||
AsyncAbstractPtr asyncRun0 = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncRun1 = std::make_shared<AsyncAbstract>();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
SetUndeterminedFlag(evaluators[0], possible_parent_fg);
|
||||
SetUndeterminedFlag(evaluators[1], possible_parent_fg);
|
||||
// Eval result of the branches and main.
|
||||
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
||||
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
||||
std::vector<AsyncAbstractPtr> branchAsyncResults;
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
|
||||
// Add point to infer thread
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult0, asyncResult_main, asyncRun0);
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
|
||||
// Add point to infer thread
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult1, asyncResult_main, asyncRun1);
|
||||
|
||||
// Wait for async threads to finish.
|
||||
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
|
||||
// Push to list of running loop
|
||||
asyncRun0->JoinResult(std::make_shared<AbstractScalar>(0));
|
||||
asyncRun1->JoinResult(std::make_shared<AbstractScalar>(0));
|
||||
// Run order
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncRun0); // First order
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncRun1); // Second order
|
||||
for (auto &evaluator : evaluators) {
|
||||
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
|
||||
// Control the order to run.
|
||||
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
|
||||
SetUndeterminedFlag(evaluator, possible_parent_fg);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
|
||||
// Add point to infer thread
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
|
||||
threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
|
||||
// Wait for async threads to finish.
|
||||
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future));
|
||||
// Push to list of running loop
|
||||
asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order
|
||||
branchAsyncResults.emplace_back(std::move(branchAsyncResult));
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
||||
<< " or " << evaluators[1]->ToString();
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
|
||||
auto branchResult = asyncResult_main->GetResult();
|
||||
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
|
||||
auto firstResult = asyncResult_main->GetResult();
|
||||
if (firstResult == nullptr || firstResult->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
|
||||
<< " Please check the code if there are recursive functions.";
|
||||
}
|
||||
if (branchResult->isa<AbstractError>()) {
|
||||
if (firstResult->isa<AbstractError>()) {
|
||||
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
}
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
|
||||
<< branchResult->ToString();
|
||||
<< firstResult->ToString();
|
||||
|
||||
AbstractBasePtrList out_specs;
|
||||
if (NeedWaitForTwoBranches(branchResult)) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString();
|
||||
// The asyncRun0 will eval asyncResult0
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult0);
|
||||
auto result0 = asyncResult0->GetResult();
|
||||
if (result0 == nullptr || result0->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
|
||||
<< " Please check the code if there is recursive function.";
|
||||
size_t len = evaluators.size();
|
||||
if (NeedWaitForBranches(firstResult)) {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
|
||||
HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
|
||||
auto result = branchAsyncResults[i]->GetResult();
|
||||
if (result == nullptr || result->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
|
||||
<< " Please check the code if there is recursive function.";
|
||||
}
|
||||
out_specs.push_back(result);
|
||||
}
|
||||
out_specs.push_back(result0);
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString();
|
||||
// The asyncRun1 will eval asyncResult1
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult1);
|
||||
auto result1 = asyncResult1->GetResult();
|
||||
if (result1 == nullptr || result1->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << " is time out."
|
||||
<< " Please check the code if there is recursive function.";
|
||||
}
|
||||
out_specs.push_back(result1);
|
||||
} else {
|
||||
// Next time to get the result of branches.
|
||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
|
||||
(void)asyncResult_main->GetResult();
|
||||
|
||||
// Don't use GetResult
|
||||
auto value0 = asyncResult0->TryGetResult();
|
||||
if (value0) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
|
||||
<< " value0=" << value0->ToString();
|
||||
out_specs.push_back(value0);
|
||||
}
|
||||
|
||||
// Don't use GetResult
|
||||
auto value1 = asyncResult1->TryGetResult();
|
||||
if (value1) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
|
||||
<< " value1=" << value1->ToString();
|
||||
out_specs.push_back(value1);
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
// Not wait to get the result of branch.
|
||||
auto result = branchAsyncResults[i]->TryGetResult();
|
||||
if (result) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
|
||||
<< " result =" << result->ToString();
|
||||
out_specs.push_back(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ProcessEvalResults(out_specs, out_conf->node());
|
||||
|
|
|
@ -23,13 +23,15 @@ ONE = Tensor([1], mstype.int32)
|
|||
|
||||
@ms_function
|
||||
def f(x):
|
||||
y = f(x - 4)
|
||||
y = ZERO
|
||||
if x < 0:
|
||||
y = f(x - 3)
|
||||
elif x < 3:
|
||||
y = x * f(x - 1)
|
||||
elif x >= 3:
|
||||
elif x < 5:
|
||||
y = x * f(x - 2)
|
||||
else:
|
||||
y = f(x - 4)
|
||||
z = y + 1
|
||||
return z
|
||||
|
||||
|
@ -41,8 +43,10 @@ def fr(x):
|
|||
y = ONE
|
||||
elif x < 3:
|
||||
y = x * fr(x - 1)
|
||||
elif x >= 3:
|
||||
elif x < 5:
|
||||
y = x * fr(x - 2)
|
||||
else:
|
||||
y = fr(x - 4)
|
||||
z = y + 1
|
||||
return z
|
||||
|
||||
|
@ -50,18 +54,20 @@ def fr(x):
|
|||
def test_endless():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor([5], mstype.int32)
|
||||
f(x)
|
||||
with pytest.raises(ValueError):
|
||||
print("endless.")
|
||||
try:
|
||||
f(x)
|
||||
except RuntimeError as e:
|
||||
assert 'endless loop' in str(e)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="backend is not supported yet")
|
||||
def test_recrusive_fun():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor([5], mstype.int32)
|
||||
ret = fr(x)
|
||||
expect = Tensor([36], mstype.int32)
|
||||
expect = Tensor([3], mstype.int32)
|
||||
assert ret == expect
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_recrusive_fun()
|
||||
test_endless()
|
||||
|
|
Loading…
Reference in New Issue