suport more evaluators and fix recrusive_fun testcase
This commit is contained in:
parent
83bcf936b6
commit
b706db3119
|
@ -76,10 +76,9 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
|
||||||
|
|
||||||
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
|
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
|
||||||
|
|
||||||
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future0, std::future<void> &&future1) {
|
void AnalysisResultCacheMgr::PushTowait(std::future<void> &&future) {
|
||||||
std::lock_guard<std::mutex> lock(lock_);
|
std::lock_guard<std::mutex> lock(lock_);
|
||||||
waiting_.emplace_back(std::move(future0));
|
waiting_.emplace_back(std::move(future));
|
||||||
waiting_.emplace_back(std::move(future1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
||||||
|
|
|
@ -287,7 +287,7 @@ class AnalysisResultCacheMgr {
|
||||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||||
// Wait for async Eval(conf) to finish.
|
// Wait for async Eval(conf) to finish.
|
||||||
void Wait();
|
void Wait();
|
||||||
void PushTowait(std::future<void> &&future0, std::future<void> &&future1);
|
void PushTowait(std::future<void> &&future);
|
||||||
void PushTodo(const AnfNodeConfigPtr &conf);
|
void PushTodo(const AnfNodeConfigPtr &conf);
|
||||||
void Todo();
|
void Todo();
|
||||||
static void UpdateCaller(const std::string &caller);
|
static void UpdateCaller(const std::string &caller);
|
||||||
|
|
|
@ -775,7 +775,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
||||||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
|
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||||
if (abstract->isa<AbstractFunction>()) {
|
if (abstract->isa<AbstractFunction>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -845,109 +845,80 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
||||||
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
|
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
|
||||||
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
|
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
|
||||||
} else {
|
} else {
|
||||||
if (eval_result->isa<AbstractTimeOut>()) {
|
if (eval_result->isa<AbstractTimeOut>() || eval_result->isa<AbstractError>()) {
|
||||||
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
|
MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
||||||
<< " Please check the code if there are recursive functions.";
|
|
||||||
}
|
|
||||||
if (eval_result->isa<AbstractError>()) {
|
|
||||||
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
|
||||||
StaticAnalysisException::Instance().CheckException();
|
StaticAnalysisException::Instance().CheckException();
|
||||||
}
|
}
|
||||||
return std::make_shared<EvalResult>(eval_result, nullptr);
|
return std::make_shared<EvalResult>(eval_result, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Eval result of the branches and main.
|
|
||||||
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
|
||||||
AsyncAbstractPtr asyncResult0 = std::make_shared<AsyncAbstract>();
|
|
||||||
AsyncAbstractPtr asyncResult1 = std::make_shared<AsyncAbstract>();
|
|
||||||
|
|
||||||
// Control which thread to run.
|
|
||||||
AsyncAbstractPtr asyncRun0 = std::make_shared<AsyncAbstract>();
|
|
||||||
AsyncAbstractPtr asyncRun1 = std::make_shared<AsyncAbstract>();
|
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(out_conf);
|
MS_EXCEPTION_IF_NULL(out_conf);
|
||||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||||
auto possible_parent_fg = out_conf->node()->func_graph();
|
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||||
SetUndeterminedFlag(evaluators[0], possible_parent_fg);
|
// Eval result of the branches and main.
|
||||||
SetUndeterminedFlag(evaluators[1], possible_parent_fg);
|
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
||||||
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
||||||
|
std::vector<AsyncAbstractPtr> branchAsyncResults;
|
||||||
|
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
|
for (auto &evaluator : evaluators) {
|
||||||
|
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
|
||||||
|
// Control the order to run.
|
||||||
|
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
|
||||||
|
SetUndeterminedFlag(evaluator, possible_parent_fg);
|
||||||
|
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
|
||||||
// Add point to infer thread
|
// Add point to infer thread
|
||||||
HealthPointMgr::GetInstance().AddPoint();
|
HealthPointMgr::GetInstance().AddPoint();
|
||||||
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
|
auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
|
||||||
out_conf, threadId, asyncResult0, asyncResult_main, asyncRun0);
|
threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
|
||||||
|
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
|
|
||||||
// Add point to infer thread
|
|
||||||
HealthPointMgr::GetInstance().AddPoint();
|
|
||||||
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
|
|
||||||
out_conf, threadId, asyncResult1, asyncResult_main, asyncRun1);
|
|
||||||
|
|
||||||
// Wait for async threads to finish.
|
// Wait for async threads to finish.
|
||||||
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
|
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future));
|
||||||
// Push to list of running loop
|
// Push to list of running loop
|
||||||
asyncRun0->JoinResult(std::make_shared<AbstractScalar>(0));
|
asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
|
||||||
asyncRun1->JoinResult(std::make_shared<AbstractScalar>(0));
|
HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order
|
||||||
// Run order
|
branchAsyncResults.emplace_back(std::move(branchAsyncResult));
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncRun0); // First order
|
}
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncRun1); // Second order
|
|
||||||
|
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
||||||
<< " or " << evaluators[1]->ToString();
|
<< " or " << evaluators[1]->ToString();
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
|
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
|
||||||
auto branchResult = asyncResult_main->GetResult();
|
auto firstResult = asyncResult_main->GetResult();
|
||||||
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
|
if (firstResult == nullptr || firstResult->isa<AbstractTimeOut>()) {
|
||||||
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
|
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
|
||||||
<< " Please check the code if there are recursive functions.";
|
<< " Please check the code if there are recursive functions.";
|
||||||
}
|
}
|
||||||
if (branchResult->isa<AbstractError>()) {
|
if (firstResult->isa<AbstractError>()) {
|
||||||
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
|
MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception.";
|
||||||
StaticAnalysisException::Instance().CheckException();
|
StaticAnalysisException::Instance().CheckException();
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
|
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
|
||||||
<< branchResult->ToString();
|
<< firstResult->ToString();
|
||||||
|
|
||||||
AbstractBasePtrList out_specs;
|
AbstractBasePtrList out_specs;
|
||||||
if (NeedWaitForTwoBranches(branchResult)) {
|
size_t len = evaluators.size();
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString();
|
if (NeedWaitForBranches(firstResult)) {
|
||||||
// The asyncRun0 will eval asyncResult0
|
for (size_t i = 0; i < len; ++i) {
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult0);
|
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
|
||||||
auto result0 = asyncResult0->GetResult();
|
HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
|
||||||
if (result0 == nullptr || result0->isa<AbstractTimeOut>()) {
|
auto result = branchAsyncResults[i]->GetResult();
|
||||||
|
if (result == nullptr || result->isa<AbstractTimeOut>()) {
|
||||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
|
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
|
||||||
<< " Please check the code if there is recursive function.";
|
<< " Please check the code if there is recursive function.";
|
||||||
}
|
}
|
||||||
out_specs.push_back(result0);
|
out_specs.push_back(result);
|
||||||
|
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString();
|
|
||||||
// The asyncRun1 will eval asyncResult1
|
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult1);
|
|
||||||
auto result1 = asyncResult1->GetResult();
|
|
||||||
if (result1 == nullptr || result1->isa<AbstractTimeOut>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << " is time out."
|
|
||||||
<< " Please check the code if there is recursive function.";
|
|
||||||
}
|
}
|
||||||
out_specs.push_back(result1);
|
|
||||||
} else {
|
} else {
|
||||||
// Next time to get the result of branches.
|
// Next time to get the result of branches.
|
||||||
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
|
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
|
||||||
(void)asyncResult_main->GetResult();
|
(void)asyncResult_main->GetResult();
|
||||||
|
|
||||||
// Don't use GetResult
|
for (size_t i = 0; i < len; ++i) {
|
||||||
auto value0 = asyncResult0->TryGetResult();
|
// Not wait to get the result of branch.
|
||||||
if (value0) {
|
auto result = branchAsyncResults[i]->TryGetResult();
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
|
if (result) {
|
||||||
<< " value0=" << value0->ToString();
|
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
|
||||||
out_specs.push_back(value0);
|
<< " result =" << result->ToString();
|
||||||
|
out_specs.push_back(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't use GetResult
|
|
||||||
auto value1 = asyncResult1->TryGetResult();
|
|
||||||
if (value1) {
|
|
||||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
|
|
||||||
<< " value1=" << value1->ToString();
|
|
||||||
out_specs.push_back(value1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ProcessEvalResults(out_specs, out_conf->node());
|
return ProcessEvalResults(out_specs, out_conf->node());
|
||||||
|
|
|
@ -23,13 +23,15 @@ ONE = Tensor([1], mstype.int32)
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def f(x):
|
def f(x):
|
||||||
y = f(x - 4)
|
y = ZERO
|
||||||
if x < 0:
|
if x < 0:
|
||||||
y = f(x - 3)
|
y = f(x - 3)
|
||||||
elif x < 3:
|
elif x < 3:
|
||||||
y = x * f(x - 1)
|
y = x * f(x - 1)
|
||||||
elif x >= 3:
|
elif x < 5:
|
||||||
y = x * f(x - 2)
|
y = x * f(x - 2)
|
||||||
|
else:
|
||||||
|
y = f(x - 4)
|
||||||
z = y + 1
|
z = y + 1
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
@ -41,8 +43,10 @@ def fr(x):
|
||||||
y = ONE
|
y = ONE
|
||||||
elif x < 3:
|
elif x < 3:
|
||||||
y = x * fr(x - 1)
|
y = x * fr(x - 1)
|
||||||
elif x >= 3:
|
elif x < 5:
|
||||||
y = x * fr(x - 2)
|
y = x * fr(x - 2)
|
||||||
|
else:
|
||||||
|
y = fr(x - 4)
|
||||||
z = y + 1
|
z = y + 1
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
@ -50,18 +54,20 @@ def fr(x):
|
||||||
def test_endless():
|
def test_endless():
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
x = Tensor([5], mstype.int32)
|
x = Tensor([5], mstype.int32)
|
||||||
|
try:
|
||||||
f(x)
|
f(x)
|
||||||
with pytest.raises(ValueError):
|
except RuntimeError as e:
|
||||||
print("endless.")
|
assert 'endless loop' in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="backend is not supported yet")
|
||||||
def test_recrusive_fun():
|
def test_recrusive_fun():
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
x = Tensor([5], mstype.int32)
|
x = Tensor([5], mstype.int32)
|
||||||
ret = fr(x)
|
ret = fr(x)
|
||||||
expect = Tensor([36], mstype.int32)
|
expect = Tensor([3], mstype.int32)
|
||||||
assert ret == expect
|
assert ret == expect
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_recrusive_fun()
|
test_endless()
|
||||||
|
|
Loading…
Reference in New Issue