lanzhineng 2021-11-30 14:31:02 +08:00
parent 8bdcc68bb7
commit 223e63e414
5 changed files with 117 additions and 194 deletions

View File

@ -23,48 +23,29 @@
namespace mindspore {
namespace abstract {
AbstractBasePtr AsyncAbstract::GetResult() {
auto ret = TryGetResult();
if (ret != nullptr) {
return ret;
}
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
ret = async_task->GetResult();
MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
return ret;
}
thread_local std::string AnalysisSchedule::thread_id_ = "m";
void AnalysisSchedule::Schedule() {
const auto checkPeriod = std::chrono::seconds(3);
while (notExit_ || infer_thread_count_.load() > 0) {
while (run_ || infer_thread_count_.load() > 0) {
std::unique_lock<std::mutex> lock(activate_thread_lock_);
auto ok = activate_thread_cv_.wait_for(lock, checkPeriod,
[this] { return activate_threads_.empty() && !scheduleList_.empty(); });
[this] { return activate_threads_.empty() && !schedule_list_.empty(); });
if (ok) {
SetNextReady();
}
}
MS_LOG(DEBUG) << "Success to exit. The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size();
MS_LOG(DEBUG) << "Success to exit.";
}
void AnalysisSchedule::Yield(const AsyncInferTask *async_infer_task) {
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
// Double check ready()
if (async_infer_task->Ready() == 0) {
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << GetThreadID()
<< " async_infer_task thread id:" << async_infer_task->ThreadID();
(void)activate_threads_.erase(GetThreadID());
if (async_infer_task->ready() == 0) {
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << thread_id()
<< " async_infer_task thread id:" << async_infer_task->thread_id();
(void)activate_threads_.erase(thread_id());
}
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
}
activate_thread_cv_.notify_one();
}
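
Read together with the Schedule() loop above, Yield() completes a small cooperative scheduler: the scheduler thread waits until no worker is marked active and the schedule list is non-empty, then wakes exactly one queued task, while a worker that gives up its turn erases itself from the active set and notifies the scheduler. The standalone sketch below shows the same wait_for-with-predicate pattern in isolation; SimpleScheduler, Task and the plain active counter are illustrative stand-ins, not MindSpore's API.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <list>
#include <memory>
#include <mutex>

// Illustrative stand-in for AsyncInferTask; waking the owning worker is left abstract.
struct Task {
  void SetReady() { /* wake the worker that owns this task */ }
};

class SimpleScheduler {
 public:
  // Like AnalysisSchedule::Add2Schedule(): enqueue and poke the scheduler.
  void Add(const std::shared_ptr<Task> &task) {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push_back(task);
    cv_.notify_one();
  }

  // Like AnalysisSchedule::Schedule(): wake one task only when no worker is
  // active and something is queued; time out periodically to re-check the exit flag.
  void Run() {
    const auto check_period = std::chrono::seconds(3);
    while (running_.load()) {
      std::unique_lock<std::mutex> lock(mutex_);
      bool ok = cv_.wait_for(lock, check_period,
                             [this] { return active_count_ == 0 && !queue_.empty(); });
      if (!ok) {
        continue;  // timed out; loop and re-check running_
      }
      auto task = queue_.front();
      queue_.pop_front();
      ++active_count_;
      task->SetReady();  // hand control to exactly one waiting worker
    }
  }

  // Like AnalysisSchedule::Yield(): a worker gives up its turn.
  void Yield() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (active_count_ > 0) {
        --active_count_;
      }
    }
    cv_.notify_one();
  }

  void Stop() {
    running_.store(false);
    cv_.notify_one();
  }

 private:
  std::atomic<bool> running_{true};
  int active_count_{0};
  std::mutex mutex_;
  std::condition_variable cv_;
  std::list<std::shared_ptr<Task>> queue_;
};
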
@ -90,17 +71,16 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
}
// Free all the locks. Let all the threads continue to run.
std::lock_guard<std::mutex> lock(activate_thread_lock_);
for (auto &item : scheduleList_) {
for (auto &item : schedule_list_) {
item->SetException();
}
scheduleList_.clear();
schedule_list_.clear();
}
void AnalysisSchedule::Stop() {
AsyncInferTaskPtr stopTask = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), "Stop");
Add2Schedule(stopTask);
MS_LOG(DEBUG) << " Set AnalysisSchedule::Exit . The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size();
AsyncInferTaskPtr stop_task = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), kStateStop);
Add2Schedule(stop_task);
MS_LOG(DEBUG) << "Set analysis schedule to stop";
}
void AnalysisSchedule::Wait() {
@ -113,9 +93,6 @@ void AnalysisSchedule::Wait() {
if (infer_thread_count_.load() < 0) {
MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
}
if (IS_OUTPUT_ON(DEBUG)) {
AnalysisResultCacheMgr::GetInstance().Todo();
}
MS_LOG(INFO) << "Infer finished.";
StaticAnalysisException::Instance().CheckException();
}
@ -123,55 +100,64 @@ void AnalysisSchedule::Wait() {
void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
std::lock_guard<std::mutex> lock(activate_thread_lock_);
MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
scheduleList_.push_back(async_infer_task_ptr);
schedule_list_.push_back(async_infer_task_ptr);
activate_thread_cv_.notify_one();
MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->ThreadID() << " address: " << async_infer_task_ptr.get()
MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->thread_id() << " address: " << async_infer_task_ptr.get()
<< " The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size();
<< " schedule list size: " << schedule_list_.size();
}
void AnalysisSchedule::SetNextReady() {
if (scheduleList_.empty()) {
if (schedule_list_.empty()) {
return;
}
// Exit Flag
if (scheduleList_.front()->ThreadID() == "Stop") {
notExit_ = false;
scheduleList_.pop_front();
if (schedule_list_.front()->thread_id() == kStateStop) {
run_ = false;
schedule_list_.pop_front();
return;
}
// Check whether an endless loop has been entered
auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
auto it = std::find_if(schedule_list_.begin(), schedule_list_.end(), [](const auto &item) {
MS_EXCEPTION_IF_NULL(item);
return item->HasResult();
});
if (it == scheduleList_.end()) {
if (IntToSize(infer_thread_count_.load()) >= scheduleList_.size()) {
if (it == schedule_list_.end()) {
if (IntToSize(infer_thread_count_.load()) >= schedule_list_.size()) {
MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
return;
}
// Enter endless loop if there is no ready result.
(void)activate_threads_.insert(scheduleList_.front()->ThreadID());
(void)activate_threads_.insert(schedule_list_.front()->thread_id());
// Let the first thread trigger the endless loop exception.
MS_LOG(DEBUG) << "Enter endless loop if there is no ready result. Set the async to trigger exception:"
<< scheduleList_.front().get() << " The active thread count: " << activate_threads_.size();
scheduleList_.front()->SetEndLessLoopException();
scheduleList_.pop_front();
<< schedule_list_.front().get() << " The active thread count: " << activate_threads_.size();
schedule_list_.front()->SetEndLessLoopException();
schedule_list_.pop_front();
return;
}
auto async_task = *it;
(void)activate_threads_.insert(async_task->ThreadID());
(void)activate_threads_.insert(async_task->thread_id());
async_task->SetReady();
(void)scheduleList_.erase(it);
(void)schedule_list_.erase(it);
MS_LOG(DEBUG) << " Success to SetReady. The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_ << " schedule list size: " << scheduleList_.size()
<< " async: " << async_task->ThreadID() << " address: " << async_task.get();
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << schedule_list_.size() << " async: " << async_task->thread_id()
<< " address: " << async_task.get();
}
// The thread id format is XXXX.YYYY.ZZZZ
thread_local std::string localThreadID = "1";
void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID = threadID; }
std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
AbstractBasePtr AsyncAbstract::GetResult() {
auto ret = TryGetResult();
if (ret != nullptr) {
return ret;
}
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
ret = async_task->GetResult();
MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
return ret;
}
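
AsyncAbstract::GetResult() first tries a non-blocking read and, only when the value is missing, wraps itself in an AsyncInferTask, hands that task to the scheduler, and blocks until the scheduler marks it ready. A minimal future-like value with the same TryGet/blocking-Get split is sketched below; the AsyncValue template is a hypothetical stand-in, not the MindSpore class, and it waits on its own condition variable instead of going through a scheduler.

#include <condition_variable>
#include <memory>
#include <mutex>

// Hypothetical stand-in for AsyncAbstract: a one-shot, thread-safe result slot.
template <typename T>
class AsyncValue {
 public:
  // Non-blocking probe, like AsyncAbstract::TryGetResult().
  std::shared_ptr<T> TryGet() {
    std::lock_guard<std::mutex> lock(mutex_);
    return value_;
  }

  // Blocking read, like AsyncAbstract::GetResult() once the task is scheduled.
  std::shared_ptr<T> Get() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return value_ != nullptr; });
    return value_;
  }

  // Publisher side, like AsyncAbstract::set_result().
  void Set(std::shared_ptr<T> value) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      value_ = std::move(value);
    }
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<T> value_;
};
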
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
if (resolved_ != nullptr) {
@ -223,12 +209,6 @@ void AnalysisResultCacheMgr::Clear() {
cache_.clear();
switch_cache_.clear();
switch_cache_for_check_.clear();
todo_.clear();
}
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
std::lock_guard<std::mutex> lock(todo_lock_);
todo_.push_back(conf);
}
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
@ -240,16 +220,6 @@ void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
}
}
AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
// don't call lock_.lock(). switch_cache is protected. and it waits for result.
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
// Conf has been visited and set value.
if (async_eval_result != nullptr) {
return async_eval_result->TryGetResult();
}
return nullptr;
}
AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
// Don't call lock_.lock(); switch_cache_ is protected, and it waits for the result.
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
@ -270,7 +240,7 @@ void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const A
AsyncAbstractPtr async_eval_result = cache->get(conf);
if (async_eval_result == nullptr) {
async_eval_result = std::make_shared<AsyncAbstract>();
async_eval_result->SetResult(arg);
async_eval_result->set_result(arg);
cache->set(conf, async_eval_result);
} else {
auto ab1 = async_eval_result->TryGetResult();
@ -280,12 +250,9 @@ void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const A
absList.push_back(ab1);
// Join the two branches' results
auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
async_eval_result->SetResult(joined_result->abstract());
if (!(*joined_result == *ab1)) {
PushTodo(conf);
}
async_eval_result->set_result(joined_result->abstract());
} else {
async_eval_result->SetResult(arg);
async_eval_result->set_result(arg);
}
}
}
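
SetCacheValue() keeps one AsyncAbstract slot per node config: the first branch to finish stores its abstract directly, and any later branch is joined with the existing value through AnalysisEngine::ProcessEvalResults before being written back. A minimal sketch of that first-write-stores / later-writes-join policy follows; BranchCache and the toy Join are hypothetical stand-ins for the real cache and abstract-joining logic, and the todo_ bookkeeping removed by this commit is not reproduced.

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Toy "abstract" value and join; the real merging is done by ProcessEvalResults.
using Abstract = std::string;
std::shared_ptr<Abstract> Join(const Abstract &a, const Abstract &b) {
  return std::make_shared<Abstract>(a == b ? a : a + "|" + b);
}

class BranchCache {
 public:
  // First writer stores its value; later writers are joined with it,
  // mirroring the shape of AnalysisResultCacheMgr::SetCacheValue().
  void Set(const std::string &key, const std::shared_ptr<Abstract> &value) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = slots_.find(key);
    if (it == slots_.end() || it->second == nullptr) {
      slots_[key] = value;
      return;
    }
    slots_[key] = Join(*it->second, *value);
  }

  // Non-blocking lookup, like TryGetResult() on the cached slot.
  std::shared_ptr<Abstract> TryGet(const std::string &key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = slots_.find(key);
    return it == slots_.end() ? nullptr : it->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<Abstract>> slots_;
};
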
@ -298,32 +265,6 @@ void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const
SetCacheValue(conf, arg, &switch_cache_);
}
void AnalysisResultCacheMgr::Todo() {
std::lock_guard<std::mutex> lock(todo_lock_);
while (!todo_.empty()) {
AnfNodeConfigPtr conf = todo_.front();
MS_EXCEPTION_IF_NULL(conf);
todo_.pop_front();
if (GetValue(conf) == nullptr) {
MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
continue;
}
if (TryGetSwitchValue(conf) == nullptr) {
MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
continue;
}
auto switch_value = TryGetSwitchValue(conf);
auto abstract = GetValue(conf)->abstract();
MS_EXCEPTION_IF_NULL(switch_value);
MS_EXCEPTION_IF_NULL(abstract);
if (!(*abstract == *switch_value)) {
MS_LOG(WARNING) << " Switch Value is not eq. "
<< " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
<< "\t\tConf: " << conf->ToString();
}
}
}
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
std::ostringstream buffer;
for (const auto &item : args_spec_list) {

View File

@ -49,8 +49,8 @@ class AnalysisSchedule {
static AnalysisSchedule instance;
return instance;
}
static void SetThreadID(const std::string &caller);
static std::string &GetThreadID();
static void set_thread_id(const std::string &thread_id) { thread_id_ = thread_id; }
static std::string &thread_id() { return thread_id_; }
void HandleException(const std::exception &ex);
void Stop();
void Wait();
@ -61,10 +61,7 @@ class AnalysisSchedule {
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
activate_threads_.clear();
MS_LOG(DEBUG) << " Get activate_thread_lock. The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
MS_LOG(DEBUG) << "Infer return to main thread.";
}
activate_thread_cv_.notify_one();
}
@ -73,7 +70,7 @@ class AnalysisSchedule {
infer_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size();
<< " schedule list size: " << schedule_list_.size();
}
void DecreaseThreadCount() {
@ -84,11 +81,11 @@ class AnalysisSchedule {
infer_thread_cv_.notify_one();
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
std::lock_guard<std::mutex> active_lock(activate_thread_lock_);
activate_threads_.clear();
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
<< " The infer_thread_count: " << infer_thread_count_
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
<< " schedule list size: " << schedule_list_.size() << " thread: " << thread_id() + " "
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
}
activate_thread_cv_.notify_one();
@ -103,13 +100,15 @@ class AnalysisSchedule {
}
AnalysisSchedule() { Start(); }
std::atomic<int> infer_thread_count_{0};
bool notExit_{true};
bool run_{true};
std::mutex infer_thread_lock_;
std::condition_variable infer_thread_cv_;
std::mutex activate_thread_lock_;
std::condition_variable activate_thread_cv_;
std::list<AsyncInferTaskPtr> scheduleList_;
std::list<AsyncInferTaskPtr> schedule_list_;
std::set<std::string> activate_threads_;
const std::string kStateStop = "Stop";
static thread_local std::string thread_id_;
};
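
The header now keeps the per-thread identifier as a static thread_local member (thread_id_) with snake_case accessors, replacing the SetThreadID()/GetThreadID() pair and the file-local localThreadID variable. The usual shape of such a member is sketched below with an illustrative Registry class; the class-scope declaration still needs the single out-of-line definition that the .cc file provides (thread_local std::string AnalysisSchedule::thread_id_ = "m";).

#include <iostream>
#include <string>
#include <thread>

class Registry {
 public:
  static void set_thread_id(const std::string &id) { thread_id_ = id; }
  static const std::string &thread_id() { return thread_id_; }

 private:
  // Declared in the class, but still requires exactly one out-of-line definition.
  static thread_local std::string thread_id_;
};

// Out-of-line definition, normally placed in the .cc file.
thread_local std::string Registry::thread_id_ = "m";

int main() {
  Registry::set_thread_id("1.0");
  std::thread worker([] {
    // A fresh thread sees the default value, not the main thread's "1.0".
    std::cout << "worker sees: " << Registry::thread_id() << std::endl;
  });
  worker.join();
  std::cout << "main sees: " << Registry::thread_id() << std::endl;
  return 0;
}
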
template <typename KeyType, typename ValueType, typename CacheType>
@ -210,7 +209,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
public:
AsyncAbstract() = default;
~AsyncAbstract() = default;
AbstractBasePtr GetResult();
AbstractBasePtr TryGetResult() {
std::lock_guard<std::mutex> lock(lock_);
return result_;
@ -219,14 +218,11 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
std::lock_guard<std::mutex> lock(lock_);
return result_ != nullptr;
}
void SetResult(const AbstractBasePtr &result) {
void set_result(const AbstractBasePtr &result) {
MS_EXCEPTION_IF_NULL(result);
std::lock_guard<std::mutex> lock(lock_);
result_ = result;
}
AbstractBasePtr GetResult();
std::string ToString() {
std::ostringstream buffer;
std::lock_guard<std::mutex> lock(lock_);
@ -271,7 +267,7 @@ class AsyncAbstractFuncAtom : public AbstractFuncAtom {
AbstractFunctionPtr GetUnique() override;
std::string ToString() const;
std::string ToString() const override;
private:
// Resolved AbstractFunction after it is fully analyzed.
@ -284,14 +280,14 @@ using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;
class AsyncInferTask {
public:
explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)
: threadId_(threadId), abstract_ptr_(abstract) {}
explicit AsyncInferTask(const std::string &thread_id, const AsyncAbstractPtr &abstract)
: thread_id_(thread_id), abstract_ptr_(abstract) {}
~AsyncInferTask() = default;
static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &threadId = "") {
std::string thread_id = threadId;
static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
std::string thread_id = thread;
if (thread_id == "") {
thread_id = AnalysisSchedule::GetInstance().GetThreadID();
thread_id = AnalysisSchedule::GetInstance().thread_id();
}
MS_EXCEPTION_IF_NULL(abstract);
auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
@ -300,8 +296,8 @@ class AsyncInferTask {
}
bool HasResult() { return abstract_ptr_->HasResult(); }
int Ready() const { return ready_; }
std::string ThreadID() const { return threadId_; }
int ready() const { return ready_; }
std::string thread_id() const { return thread_id_; }
AbstractBasePtr GetResult() {
StaticAnalysisException::Instance().CheckException();
@ -316,8 +312,7 @@ class AsyncInferTask {
lock.lock();
condition_var_.wait(lock, [this] { return ready_; });
MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << threadId_
<< " GetThreadId: " << AnalysisSchedule::GetInstance().GetThreadID();
MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << thread_id_;
ProcessResult();
auto ans = abstract_ptr_->TryGetResult();
MS_EXCEPTION_IF_NULL(ans);
@ -329,7 +324,7 @@ class AsyncInferTask {
std::lock_guard<std::mutex> lock(lock_);
ready_ = ready_ | 0b001; // Set the first bit = 1
MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
<< " threadId: " << threadId_;
<< " thread_id: " << thread_id_;
}
condition_var_.notify_one();
}
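
SetReady() only ORs the lowest bit into ready_ (ready_ | 0b001), so ready_ acts as a small bit mask; SetException()/SetEndLessLoopException() and ProcessResult() presumably use the remaining bits for error states, though their exact values are not visible in this diff. A hedged sketch of that flag style, with assumed bit assignments:

#include <cstdint>

// Assumed bit layout for illustration only; the actual values used by
// AsyncInferTask are not spelled out in this diff.
constexpr uint32_t kReadyBit = 0b001;        // result (or wake-up) is available
constexpr uint32_t kExceptionBit = 0b010;    // a static-analysis exception was raised
constexpr uint32_t kEndlessLoopBit = 0b100;  // scheduler suspects an endless loop

inline bool IsReady(uint32_t flags) { return (flags & kReadyBit) != 0; }
inline bool HasException(uint32_t flags) { return (flags & kExceptionBit) != 0; }
inline bool HasEndlessLoop(uint32_t flags) { return (flags & kEndlessLoopBit) != 0; }

// Waking a task keeps any error bits that were already set, as in SetReady().
inline uint32_t MarkReady(uint32_t flags) { return flags | kReadyBit; }
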
@ -365,11 +360,10 @@ class AsyncInferTask {
void ProcessResult() {
HandleEndLessLoopException();
StaticAnalysisException::Instance().CheckException();
MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " threadId: " << threadId_
<< " GetThreadId:" << AnalysisSchedule::GetInstance().GetThreadID()
MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " thread_id: " << thread_id_
<< " result: " << abstract_ptr_->TryGetResult().get();
}
std::string threadId_;
std::string thread_id_;
AsyncAbstractPtr abstract_ptr_;
std::mutex lock_;
std::condition_variable condition_var_;
@ -413,11 +407,8 @@ class AnalysisResultCacheMgr {
void Clear();
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
void PushTodo(const AnfNodeConfigPtr &conf);
void Todo();
void InitSwitchValue(const AnfNodeConfigPtr &conf);
AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale);
const_iterator begin() { return cache_.begin(); }
const_iterator end() { return cache_.end(); }
@ -432,8 +423,6 @@ class AnalysisResultCacheMgr {
void SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale, AnalysisConfigAsyncResultCache *cache);
std::mutex lock_;
std::mutex todo_lock_;
std::list<AnfNodeConfigPtr> todo_;
AnalysisConfigResultCache cache_;
AnalysisConfigAsyncResultCache switch_cache_;
AnalysisConfigAsyncResultCache switch_cache_for_check_;
@ -441,7 +430,7 @@ class AnalysisResultCacheMgr {
std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::GetThreadID() + ":"; }
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::thread_id() + ":"; }
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_

View File

@ -226,8 +226,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
MS_EXCEPTION_IF_NULL(parent_context_);
MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
if (parent_context_->func_graph() != nullptr) {
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::GetThreadID() << ":"
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::GetThreadID() << ":"
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::thread_id() << ":"
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::thread_id() << ":"
<< fg->ToString() << "();";
}

View File

@ -208,8 +208,10 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
eval_result = EvalCNode(cnode, conf);
trace::TraceEvalCNodeLeave();
} else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString()
<< "(type:" << node->type_name()
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
<< " conf: " << conf->ToString();
}
#ifdef DEBUG
@ -729,6 +731,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
namespace {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
@ -745,25 +748,25 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
}
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
const std::string &threadID, AsyncAbstractPtr async_result_branch,
AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_run_flag,
const std::string &thread_id, AsyncAbstractPtr async_result_branch,
AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_task,
const trace::TraceGraphEvalStack &graph_evals,
const trace::TraceCNodeEvalStack &trace_c_node_evals) {
AnalysisSchedule::SetThreadID(threadID);
AnalysisSchedule::set_thread_id(thread_id);
// Restore the trace stack so it can be dumped when an exception occurs.
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
trace::TraceGraphEvalStackPrepare(graph_evals);
try {
// Wait for Signal to run
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
(void)async_run_flag->GetResult();
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running.";
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
(void)async_task->GetResult();
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
// Acquire the GIL so that eval can call back into Python.
EvalResultPtr result;
{
py::gil_scoped_acquire pyGuard;
py::gil_scoped_acquire py_guard;
result = eval->Run(engine, args_conf_list, out_conf);
}
MS_EXCEPTION_IF_NULL(result);
@ -772,31 +775,21 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
// Check that the branch value is compatible with the other branch's value.
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
// Broaden the result of switch(c,t,f)()
auto broadAbstract = result->abstract()->Broaden();
auto broaden_abstract = result->abstract()->Broaden();
// Notify the thread waiting for the branch value and the main thread to continue.
async_result_branch->SetResult(broadAbstract);
async_result_main->SetResult(broadAbstract);
// Thread number will be drop when thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
async_result_branch->set_result(broaden_abstract);
async_result_main->set_result(broaden_abstract);
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &e1) {
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
async_result_main->SetResult(abstractErrPtr);
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
AnalysisSchedule::GetInstance().HandleException(e1);
try {
// Thread number will be drop when thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
} catch (const std::exception &e2) {
MS_LOG(DEBUG) << "AnalysisSchedule::GetInstance().DecreaseThreadCount() threw exception.";
}
AnalysisSchedule::GetInstance().HandleException(ex);
}
// Thread count will be dropped when the thread exits.
AnalysisSchedule::GetInstance().DecreaseThreadCount();
}
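
ExecEvaluator() runs each branch evaluator on a detached thread: the thread blocks on a control token (async_task) until the scheduler lets it run, publishes the broadened result to both the per-branch slot and the shared main slot, and decreases the thread count on the way out, whether or not an exception was handled. A rough standalone analogue of that hand-off using std::promise/std::future in place of AsyncAbstract is sketched below; all names here are illustrative.

#include <atomic>
#include <future>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

std::atomic<int> infer_thread_count{0};

// Rough analogue of ExecEvaluator(): block on the "go" token until the scheduler
// releases the branch, publish to the branch slot and the main slot, then drop the count.
void Worker(std::shared_future<void> go, std::promise<std::string> branch_result,
            std::shared_ptr<std::promise<std::string>> main_result, std::string thread_id) {
  go.wait();  // like (void)async_task->GetResult();
  try {
    std::string value = "result-from-" + thread_id;
    branch_result.set_value(value);
    main_result->set_value(value);  // the first finished branch wakes the main thread
  } catch (const std::exception &ex) {
    std::cerr << "branch " << thread_id << " failed: " << ex.what() << std::endl;
  }
  infer_thread_count.fetch_sub(1);  // like AnalysisSchedule::DecreaseThreadCount()
}

int main() {
  std::promise<void> go;
  std::shared_future<void> go_token = go.get_future().share();
  auto main_slot = std::make_shared<std::promise<std::string>>();
  auto main_future = main_slot->get_future();

  std::promise<std::string> branch_slot;
  auto branch_future = branch_slot.get_future();

  infer_thread_count.fetch_add(1);  // like IncreaseThreadCount()
  std::thread(Worker, go_token, std::move(branch_slot), main_slot, std::string("1.0")).detach();

  go.set_value();  // like Add2Schedule(async_task): let the branch start running
  std::cout << "main slot:   " << main_future.get() << std::endl;
  std::cout << "branch slot: " << branch_future.get() << std::endl;

  while (infer_thread_count.load() > 0) {  // crude stand-in for AnalysisSchedule::Wait()
    std::this_thread::yield();
  }
  return 0;
}
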
namespace {
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
AbstractBasePtrList *out_specs) {
@ -871,49 +864,49 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
auto possible_parent_fg = out_conf->node()->func_graph();
// Eval result of the main thread.
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr async_result_main = std::make_shared<AsyncAbstract>();
// Eval result of the branches
std::vector<AsyncAbstractPtr> branchAsyncResults;
std::vector<AsyncAbstractPtr> async_result_branches;
for (auto &evaluator : evaluators) {
static std::atomic<int> idCount{0};
std::string threadId = AnalysisSchedule::GetThreadID() + "." + std::to_string(idCount.fetch_add(1));
static std::atomic<int> id_count{0};
std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
MS_EXCEPTION_IF_NULL(evaluator);
SetUndeterminedFlag(evaluator, possible_parent_fg);
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
// Control the order to run.
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
AsyncInferTaskPtr asyncTask = AsyncInferTask::MakeShared(asyncRunOrder, threadId);
// Add a control point for the async thread.
AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
control_run_order->set_result(std::make_shared<AbstractScalar>(1));
AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order, thread_id);
AnalysisSchedule::GetInstance().IncreaseThreadCount();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
auto thread =
std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult,
asyncResult_main, asyncTask, trace::GetCurrentGraphEvalStack(), trace::GetCNodeDebugStack());
auto thread = std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, thread_id,
async_result_branch, async_result_main, async_task, trace::GetCurrentGraphEvalStack(),
trace::GetCNodeDebugStack());
thread.detach();
// Push to list of running loop
asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
MS_LOG(DEBUG) << " add to schedule: " << asyncTask.get();
AnalysisSchedule::GetInstance().Add2Schedule(asyncTask); // Activate order witch child thread.
(void)branchAsyncResults.emplace_back(std::move(branchAsyncResult));
MS_LOG(DEBUG) << " add to schedule: " << async_task.get();
AnalysisSchedule::GetInstance().Add2Schedule(async_task);  // Activate the run order with the child thread.
(void)async_result_branches.emplace_back(std::move(async_result_branch));
}
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString() << "...";
auto firstResult = asyncResult_main->GetResult();
MS_EXCEPTION_IF_NULL(firstResult);
auto first_result = async_result_main->GetResult();
MS_EXCEPTION_IF_NULL(first_result);
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
<< firstResult->ToString();
<< first_result->ToString();
AbstractBasePtrList out_specs;
size_t len = evaluators.size();
if (NeedWaitForBranches(firstResult)) {
BuildPossibleSpecs(firstResult, branchAsyncResults, &out_specs);
if (NeedWaitForBranches(first_result)) {
BuildPossibleSpecs(first_result, async_result_branches, &out_specs);
} else {
for (size_t i = 0; i < len; ++i) {
// Do not wait for the branch result.
auto result = branchAsyncResults[i]->TryGetResult();
auto result = async_result_branches[i]->TryGetResult();
if (result) {
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
<< " result: " << result->ToString();

View File

@ -133,7 +133,7 @@ class AnfNodeConfig : public Config {
AnalysisEnginePtr engine() const { return engine_.lock(); }
size_t hash() const {
size_t hash() const override {
std::size_t node_hash = PointerHash<AnfNodePtr>{}(node_);
if (context_->IsDummyContext()) {
return node_hash;