codex
This commit is contained in:
parent
8bdcc68bb7
commit
223e63e414
|
@ -23,48 +23,29 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
||||
AbstractBasePtr AsyncAbstract::GetResult() {
|
||||
auto ret = TryGetResult();
|
||||
if (ret != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
|
||||
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
|
||||
ret = async_task->GetResult();
|
||||
MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
|
||||
return ret;
|
||||
}
|
||||
thread_local std::string AnalysisSchedule::thread_id_ = "m";
|
||||
|
||||
void AnalysisSchedule::Schedule() {
|
||||
const auto checkPeriod = std::chrono::seconds(3);
|
||||
while (notExit_ || infer_thread_count_.load() > 0) {
|
||||
while (run_ || infer_thread_count_.load() > 0) {
|
||||
std::unique_lock<std::mutex> lock(activate_thread_lock_);
|
||||
auto ok = activate_thread_cv_.wait_for(lock, checkPeriod,
|
||||
[this] { return activate_threads_.empty() && !scheduleList_.empty(); });
|
||||
[this] { return activate_threads_.empty() && !schedule_list_.empty(); });
|
||||
if (ok) {
|
||||
SetNextReady();
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Success to exit. The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size();
|
||||
MS_LOG(DEBUG) << "Success to exit.";
|
||||
}
|
||||
|
||||
void AnalysisSchedule::Yield(const AsyncInferTask *async_infer_task) {
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
// Double check ready()
|
||||
if (async_infer_task->Ready() == 0) {
|
||||
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << GetThreadID()
|
||||
<< " async_infer_task thread id:" << async_infer_task->ThreadID();
|
||||
(void)activate_threads_.erase(GetThreadID());
|
||||
if (async_infer_task->ready() == 0) {
|
||||
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << thread_id()
|
||||
<< " async_infer_task thread id:" << async_infer_task->thread_id();
|
||||
(void)activate_threads_.erase(thread_id());
|
||||
}
|
||||
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
|
||||
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
@ -90,17 +71,16 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
|
|||
}
|
||||
// Free all the locks. Let all the threads continue to run.
|
||||
std::lock_guard<std::mutex> lock(activate_thread_lock_);
|
||||
for (auto &item : scheduleList_) {
|
||||
for (auto &item : schedule_list_) {
|
||||
item->SetException();
|
||||
}
|
||||
scheduleList_.clear();
|
||||
schedule_list_.clear();
|
||||
}
|
||||
|
||||
void AnalysisSchedule::Stop() {
|
||||
AsyncInferTaskPtr stopTask = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), "Stop");
|
||||
Add2Schedule(stopTask);
|
||||
MS_LOG(DEBUG) << " Set AnalysisSchedule::Exit . The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size();
|
||||
AsyncInferTaskPtr stop_task = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), kStateStop);
|
||||
Add2Schedule(stop_task);
|
||||
MS_LOG(DEBUG) << "Set analysis schedule to stop";
|
||||
}
|
||||
|
||||
void AnalysisSchedule::Wait() {
|
||||
|
@ -113,9 +93,6 @@ void AnalysisSchedule::Wait() {
|
|||
if (infer_thread_count_.load() < 0) {
|
||||
MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
|
||||
}
|
||||
if (IS_OUTPUT_ON(DEBUG)) {
|
||||
AnalysisResultCacheMgr::GetInstance().Todo();
|
||||
}
|
||||
MS_LOG(INFO) << "Infer finished.";
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
}
|
||||
|
@ -123,55 +100,64 @@ void AnalysisSchedule::Wait() {
|
|||
void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
|
||||
std::lock_guard<std::mutex> lock(activate_thread_lock_);
|
||||
MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
|
||||
scheduleList_.push_back(async_infer_task_ptr);
|
||||
schedule_list_.push_back(async_infer_task_ptr);
|
||||
activate_thread_cv_.notify_one();
|
||||
MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->ThreadID() << " address: " << async_infer_task_ptr.get()
|
||||
MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->thread_id() << " address: " << async_infer_task_ptr.get()
|
||||
<< " The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size();
|
||||
<< " schedule list size: " << schedule_list_.size();
|
||||
}
|
||||
void AnalysisSchedule::SetNextReady() {
|
||||
if (scheduleList_.empty()) {
|
||||
if (schedule_list_.empty()) {
|
||||
return;
|
||||
}
|
||||
// Exit Flag
|
||||
if (scheduleList_.front()->ThreadID() == "Stop") {
|
||||
notExit_ = false;
|
||||
scheduleList_.pop_front();
|
||||
if (schedule_list_.front()->thread_id() == kStateStop) {
|
||||
run_ = false;
|
||||
schedule_list_.pop_front();
|
||||
return;
|
||||
}
|
||||
// Check if enter endless loop
|
||||
auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
|
||||
auto it = std::find_if(schedule_list_.begin(), schedule_list_.end(), [](const auto &item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return item->HasResult();
|
||||
});
|
||||
if (it == scheduleList_.end()) {
|
||||
if (IntToSize(infer_thread_count_.load()) >= scheduleList_.size()) {
|
||||
if (it == schedule_list_.end()) {
|
||||
if (IntToSize(infer_thread_count_.load()) >= schedule_list_.size()) {
|
||||
MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
|
||||
return;
|
||||
}
|
||||
// Enter endless loop if there is not ready result.
|
||||
(void)activate_threads_.insert(scheduleList_.front()->ThreadID());
|
||||
(void)activate_threads_.insert(schedule_list_.front()->thread_id());
|
||||
// Let the first thread to trigger endless loop exception.
|
||||
MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
|
||||
<< scheduleList_.front().get() << " The active thread count: " << activate_threads_.size();
|
||||
scheduleList_.front()->SetEndLessLoopException();
|
||||
scheduleList_.pop_front();
|
||||
<< schedule_list_.front().get() << " The active thread count: " << activate_threads_.size();
|
||||
schedule_list_.front()->SetEndLessLoopException();
|
||||
schedule_list_.pop_front();
|
||||
return;
|
||||
}
|
||||
auto async_task = *it;
|
||||
(void)activate_threads_.insert(async_task->ThreadID());
|
||||
(void)activate_threads_.insert(async_task->thread_id());
|
||||
async_task->SetReady();
|
||||
(void)scheduleList_.erase(it);
|
||||
(void)schedule_list_.erase(it);
|
||||
MS_LOG(DEBUG) << " Success to SetReady. The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_ << " schedule list size: " << scheduleList_.size()
|
||||
<< " async: " << async_task->ThreadID() << " address: " << async_task.get();
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << schedule_list_.size() << " async: " << async_task->thread_id()
|
||||
<< " address: " << async_task.get();
|
||||
}
|
||||
// The thread id format is XXXX.YYYY.ZZZZ
|
||||
thread_local std::string localThreadID = "1";
|
||||
void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID = threadID; }
|
||||
|
||||
std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
|
||||
AbstractBasePtr AsyncAbstract::GetResult() {
|
||||
auto ret = TryGetResult();
|
||||
if (ret != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
|
||||
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
|
||||
ret = async_task->GetResult();
|
||||
MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
|
||||
if (resolved_ != nullptr) {
|
||||
|
@ -223,12 +209,6 @@ void AnalysisResultCacheMgr::Clear() {
|
|||
cache_.clear();
|
||||
switch_cache_.clear();
|
||||
switch_cache_for_check_.clear();
|
||||
todo_.clear();
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
||||
std::lock_guard<std::mutex> lock(todo_lock_);
|
||||
todo_.push_back(conf);
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
|
@ -240,16 +220,6 @@ void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
|
|||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
// don't call lock_.lock(). switch_cache is protected. and it waits for result.
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
// Conf has been visited and set value.
|
||||
if (async_eval_result != nullptr) {
|
||||
return async_eval_result->TryGetResult();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
// don't call lock_.lock(). switch_cache is protected. and it waits for result.
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
|
@ -270,7 +240,7 @@ void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const A
|
|||
AsyncAbstractPtr async_eval_result = cache->get(conf);
|
||||
if (async_eval_result == nullptr) {
|
||||
async_eval_result = std::make_shared<AsyncAbstract>();
|
||||
async_eval_result->SetResult(arg);
|
||||
async_eval_result->set_result(arg);
|
||||
cache->set(conf, async_eval_result);
|
||||
} else {
|
||||
auto ab1 = async_eval_result->TryGetResult();
|
||||
|
@ -280,12 +250,9 @@ void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const A
|
|||
absList.push_back(ab1);
|
||||
// Join two branches's result
|
||||
auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
|
||||
async_eval_result->SetResult(joined_result->abstract());
|
||||
if (!(*joined_result == *ab1)) {
|
||||
PushTodo(conf);
|
||||
}
|
||||
async_eval_result->set_result(joined_result->abstract());
|
||||
} else {
|
||||
async_eval_result->SetResult(arg);
|
||||
async_eval_result->set_result(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -298,32 +265,6 @@ void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const
|
|||
SetCacheValue(conf, arg, &switch_cache_);
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::Todo() {
|
||||
std::lock_guard<std::mutex> lock(todo_lock_);
|
||||
while (!todo_.empty()) {
|
||||
AnfNodeConfigPtr conf = todo_.front();
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
todo_.pop_front();
|
||||
if (GetValue(conf) == nullptr) {
|
||||
MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
|
||||
continue;
|
||||
}
|
||||
if (TryGetSwitchValue(conf) == nullptr) {
|
||||
MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
|
||||
continue;
|
||||
}
|
||||
auto switch_value = TryGetSwitchValue(conf);
|
||||
auto abstract = GetValue(conf)->abstract();
|
||||
MS_EXCEPTION_IF_NULL(switch_value);
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (!(*abstract == *switch_value)) {
|
||||
MS_LOG(WARNING) << " Switch Value is not eq. "
|
||||
<< " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
|
||||
<< "\t\tConf: " << conf->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
|
||||
std::ostringstream buffer;
|
||||
for (const auto &item : args_spec_list) {
|
||||
|
|
|
@ -49,8 +49,8 @@ class AnalysisSchedule {
|
|||
static AnalysisSchedule instance;
|
||||
return instance;
|
||||
}
|
||||
static void SetThreadID(const std::string &caller);
|
||||
static std::string &GetThreadID();
|
||||
static void set_thread_id(const std::string &thread_id) { thread_id_ = thread_id; }
|
||||
static std::string &thread_id() { return thread_id_; }
|
||||
void HandleException(const std::exception &ex);
|
||||
void Stop();
|
||||
void Wait();
|
||||
|
@ -61,10 +61,7 @@ class AnalysisSchedule {
|
|||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
activate_threads_.clear();
|
||||
MS_LOG(DEBUG) << " Get activate_thread_lock. The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
|
||||
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
|
||||
MS_LOG(DEBUG) << "Infer return to main thread.";
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
@ -73,7 +70,7 @@ class AnalysisSchedule {
|
|||
infer_thread_count_.fetch_add(1);
|
||||
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size();
|
||||
<< " schedule list size: " << schedule_list_.size();
|
||||
}
|
||||
|
||||
void DecreaseThreadCount() {
|
||||
|
@ -84,11 +81,11 @@ class AnalysisSchedule {
|
|||
infer_thread_cv_.notify_one();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
std::lock_guard<std::mutex> active_lock(activate_thread_lock_);
|
||||
activate_threads_.clear();
|
||||
MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
|
||||
<< " The infer_thread_count: " << infer_thread_count_
|
||||
<< " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
|
||||
<< " schedule list size: " << schedule_list_.size() << " thread: " << thread_id() + " "
|
||||
<< (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
|
@ -103,13 +100,15 @@ class AnalysisSchedule {
|
|||
}
|
||||
AnalysisSchedule() { Start(); }
|
||||
std::atomic<int> infer_thread_count_{0};
|
||||
bool notExit_{true};
|
||||
bool run_{true};
|
||||
std::mutex infer_thread_lock_;
|
||||
std::condition_variable infer_thread_cv_;
|
||||
std::mutex activate_thread_lock_;
|
||||
std::condition_variable activate_thread_cv_;
|
||||
std::list<AsyncInferTaskPtr> scheduleList_;
|
||||
std::list<AsyncInferTaskPtr> schedule_list_;
|
||||
std::set<std::string> activate_threads_;
|
||||
const std::string kStateStop = "Stop";
|
||||
static thread_local std::string thread_id_;
|
||||
};
|
||||
|
||||
template <typename KeyType, typename ValueType, typename CacheType>
|
||||
|
@ -210,7 +209,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
|||
public:
|
||||
AsyncAbstract() = default;
|
||||
~AsyncAbstract() = default;
|
||||
|
||||
AbstractBasePtr GetResult();
|
||||
AbstractBasePtr TryGetResult() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
return result_;
|
||||
|
@ -219,14 +218,11 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
|||
std::lock_guard<std::mutex> lock(lock_);
|
||||
return result_ != nullptr;
|
||||
}
|
||||
void SetResult(const AbstractBasePtr &result) {
|
||||
void set_result(const AbstractBasePtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
result_ = result;
|
||||
}
|
||||
|
||||
AbstractBasePtr GetResult();
|
||||
|
||||
std::string ToString() {
|
||||
std::ostringstream buffer;
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
|
@ -271,7 +267,7 @@ class AsyncAbstractFuncAtom : public AbstractFuncAtom {
|
|||
|
||||
AbstractFunctionPtr GetUnique() override;
|
||||
|
||||
std::string ToString() const;
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
// Resolved AbstractFunction after fully analyzed.
|
||||
|
@ -284,14 +280,14 @@ using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;
|
|||
|
||||
class AsyncInferTask {
|
||||
public:
|
||||
explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)
|
||||
: threadId_(threadId), abstract_ptr_(abstract) {}
|
||||
explicit AsyncInferTask(const std::string &thread_id, const AsyncAbstractPtr &abstract)
|
||||
: thread_id_(thread_id), abstract_ptr_(abstract) {}
|
||||
~AsyncInferTask() = default;
|
||||
|
||||
static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &threadId = "") {
|
||||
std::string thread_id = threadId;
|
||||
static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
|
||||
std::string thread_id = thread;
|
||||
if (thread_id == "") {
|
||||
thread_id = AnalysisSchedule::GetInstance().GetThreadID();
|
||||
thread_id = AnalysisSchedule::GetInstance().thread_id();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
|
||||
|
@ -300,8 +296,8 @@ class AsyncInferTask {
|
|||
}
|
||||
|
||||
bool HasResult() { return abstract_ptr_->HasResult(); }
|
||||
int Ready() const { return ready_; }
|
||||
std::string ThreadID() const { return threadId_; }
|
||||
int ready() const { return ready_; }
|
||||
std::string thread_id() const { return thread_id_; }
|
||||
|
||||
AbstractBasePtr GetResult() {
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
|
@ -316,8 +312,7 @@ class AsyncInferTask {
|
|||
|
||||
lock.lock();
|
||||
condition_var_.wait(lock, [this] { return ready_; });
|
||||
MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << threadId_
|
||||
<< " GetThreadId: " << AnalysisSchedule::GetInstance().GetThreadID();
|
||||
MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << thread_id_;
|
||||
ProcessResult();
|
||||
auto ans = abstract_ptr_->TryGetResult();
|
||||
MS_EXCEPTION_IF_NULL(ans);
|
||||
|
@ -329,7 +324,7 @@ class AsyncInferTask {
|
|||
std::lock_guard<std::mutex> lock(lock_);
|
||||
ready_ = ready_ | 0b001; // Set the first bit = 1
|
||||
MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
|
||||
<< " threadId: " << threadId_;
|
||||
<< " thread_id: " << thread_id_;
|
||||
}
|
||||
condition_var_.notify_one();
|
||||
}
|
||||
|
@ -365,11 +360,10 @@ class AsyncInferTask {
|
|||
void ProcessResult() {
|
||||
HandleEndLessLoopException();
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " threadId: " << threadId_
|
||||
<< " GetThreadId:" << AnalysisSchedule::GetInstance().GetThreadID()
|
||||
MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " thread_id: " << thread_id_
|
||||
<< " result: " << abstract_ptr_->TryGetResult().get();
|
||||
}
|
||||
std::string threadId_;
|
||||
std::string thread_id_;
|
||||
AsyncAbstractPtr abstract_ptr_;
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
|
@ -413,11 +407,8 @@ class AnalysisResultCacheMgr {
|
|||
void Clear();
|
||||
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||
void PushTodo(const AnfNodeConfigPtr &conf);
|
||||
void Todo();
|
||||
void InitSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale);
|
||||
const_iterator begin() { return cache_.begin(); }
|
||||
const_iterator end() { return cache_.end(); }
|
||||
|
@ -432,8 +423,6 @@ class AnalysisResultCacheMgr {
|
|||
void SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale, AnalysisConfigAsyncResultCache *cache);
|
||||
|
||||
std::mutex lock_;
|
||||
std::mutex todo_lock_;
|
||||
std::list<AnfNodeConfigPtr> todo_;
|
||||
AnalysisConfigResultCache cache_;
|
||||
AnalysisConfigAsyncResultCache switch_cache_;
|
||||
AnalysisConfigAsyncResultCache switch_cache_for_check_;
|
||||
|
@ -441,7 +430,7 @@ class AnalysisResultCacheMgr {
|
|||
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::GetThreadID() + ":"; }
|
||||
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::thread_id() + ":"; }
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
|
||||
|
|
|
@ -226,8 +226,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
|
||||
if (parent_context_->func_graph() != nullptr) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::GetThreadID() << ":"
|
||||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::GetThreadID() << ":"
|
||||
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::thread_id() << ":"
|
||||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::thread_id() << ":"
|
||||
<< fg->ToString() << "();";
|
||||
}
|
||||
|
||||
|
|
|
@ -208,8 +208,10 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|||
eval_result = EvalCNode(cnode, conf);
|
||||
trace::TraceEvalCNodeLeave();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
|
||||
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
|
||||
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString()
|
||||
<< "(type:" << node->type_name()
|
||||
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
|
||||
<< " conf: " << conf->ToString();
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
|
@ -729,6 +731,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
|
@ -745,25 +748,25 @@ bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
|||
}
|
||||
|
||||
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
|
||||
const std::string &threadID, AsyncAbstractPtr async_result_branch,
|
||||
AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_run_flag,
|
||||
const std::string &thread_id, AsyncAbstractPtr async_result_branch,
|
||||
AsyncAbstractPtr async_result_main, AsyncInferTaskPtr async_task,
|
||||
const trace::TraceGraphEvalStack &graph_evals,
|
||||
const trace::TraceCNodeEvalStack &trace_c_node_evals) {
|
||||
AnalysisSchedule::SetThreadID(threadID);
|
||||
AnalysisSchedule::set_thread_id(thread_id);
|
||||
// Restore trace stack for dump stack when there is exception.
|
||||
trace::TraceEvalCNodeStackPrepare(trace_c_node_evals);
|
||||
trace::TraceGraphEvalStackPrepare(graph_evals);
|
||||
|
||||
try {
|
||||
// Wait for Signal to run
|
||||
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
|
||||
(void)async_run_flag->GetResult();
|
||||
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running.";
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " waiting.";
|
||||
(void)async_task->GetResult();
|
||||
MS_LOG(DEBUG) << async_task.get() << " " << eval->ToString() << " running.";
|
||||
|
||||
// Acquire GIL for eval to callback python.
|
||||
EvalResultPtr result;
|
||||
{
|
||||
py::gil_scoped_acquire pyGuard;
|
||||
py::gil_scoped_acquire py_guard;
|
||||
result = eval->Run(engine, args_conf_list, out_conf);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
@ -772,31 +775,21 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|||
// Check the branch value to be compatible with the other branch value.
|
||||
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
|
||||
// Broaden the result of switch(c,t,f)()
|
||||
auto broadAbstract = result->abstract()->Broaden();
|
||||
auto broaden_abstract = result->abstract()->Broaden();
|
||||
// Notify the thread of waiting for branch value and the main thread to continue.
|
||||
async_result_branch->SetResult(broadAbstract);
|
||||
async_result_main->SetResult(broadAbstract);
|
||||
// Thread number will be drop when thread exits.
|
||||
AnalysisSchedule::GetInstance().DecreaseThreadCount();
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
|
||||
async_result_branch->set_result(broaden_abstract);
|
||||
async_result_main->set_result(broaden_abstract);
|
||||
MS_LOG(DEBUG) << GetInferThread() << " async :" << eval->ToString()
|
||||
<< " asyncResult address = " << async_result_branch.get()
|
||||
<< " value = " << async_result_branch->TryGetResult()->ToString();
|
||||
} catch (const std::exception &e1) {
|
||||
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
|
||||
async_result_main->SetResult(abstractErrPtr);
|
||||
} catch (const std::exception &ex) {
|
||||
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
|
||||
AnalysisSchedule::GetInstance().HandleException(e1);
|
||||
try {
|
||||
// Thread number will be drop when thread exits.
|
||||
AnalysisSchedule::GetInstance().DecreaseThreadCount();
|
||||
} catch (const std::exception &e2) {
|
||||
MS_LOG(DEBUG) << "AnalysisSchedule::GetInstance().DecreaseThreadCount() threw exception.";
|
||||
}
|
||||
AnalysisSchedule::GetInstance().HandleException(ex);
|
||||
}
|
||||
// Thread number will be drop when thread exits.
|
||||
AnalysisSchedule::GetInstance().DecreaseThreadCount();
|
||||
}
|
||||
|
||||
namespace {
|
||||
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
|
||||
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
|
||||
AbstractBasePtrList *out_specs) {
|
||||
|
@ -871,49 +864,49 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
|
||||
// Eval result of the main.
|
||||
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr async_result_main = std::make_shared<AsyncAbstract>();
|
||||
// Eval result of the branches
|
||||
std::vector<AsyncAbstractPtr> branchAsyncResults;
|
||||
std::vector<AsyncAbstractPtr> async_result_branches;
|
||||
|
||||
for (auto &evaluator : evaluators) {
|
||||
static std::atomic<int> idCount{0};
|
||||
std::string threadId = AnalysisSchedule::GetThreadID() + "." + std::to_string(idCount.fetch_add(1));
|
||||
static std::atomic<int> id_count{0};
|
||||
std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
SetUndeterminedFlag(evaluator, possible_parent_fg);
|
||||
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
|
||||
// Control the order to run.
|
||||
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
|
||||
AsyncInferTaskPtr asyncTask = AsyncInferTask::MakeShared(asyncRunOrder, threadId);
|
||||
// Add point to the async thread.
|
||||
AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
|
||||
control_run_order->set_result(std::make_shared<AbstractScalar>(1));
|
||||
AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order, thread_id);
|
||||
|
||||
AnalysisSchedule::GetInstance().IncreaseThreadCount();
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
|
||||
auto thread =
|
||||
std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult,
|
||||
asyncResult_main, asyncTask, trace::GetCurrentGraphEvalStack(), trace::GetCNodeDebugStack());
|
||||
auto thread = std::thread(ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, thread_id,
|
||||
async_result_branch, async_result_main, async_task, trace::GetCurrentGraphEvalStack(),
|
||||
trace::GetCNodeDebugStack());
|
||||
thread.detach();
|
||||
// Push to list of running loop
|
||||
asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
|
||||
MS_LOG(DEBUG) << " add to schedule: " << asyncTask.get();
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(asyncTask); // Activate order witch child thread.
|
||||
(void)branchAsyncResults.emplace_back(std::move(branchAsyncResult));
|
||||
MS_LOG(DEBUG) << " add to schedule: " << async_task.get();
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(async_task); // Activate order witch child thread.
|
||||
(void)async_result_branches.emplace_back(std::move(async_result_branch));
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
||||
<< " or " << evaluators[1]->ToString() << "...";
|
||||
|
||||
auto firstResult = asyncResult_main->GetResult();
|
||||
MS_EXCEPTION_IF_NULL(firstResult);
|
||||
auto first_result = async_result_main->GetResult();
|
||||
MS_EXCEPTION_IF_NULL(first_result);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
|
||||
<< firstResult->ToString();
|
||||
<< first_result->ToString();
|
||||
|
||||
AbstractBasePtrList out_specs;
|
||||
size_t len = evaluators.size();
|
||||
if (NeedWaitForBranches(firstResult)) {
|
||||
BuildPossibleSpecs(firstResult, branchAsyncResults, &out_specs);
|
||||
if (NeedWaitForBranches(first_result)) {
|
||||
BuildPossibleSpecs(first_result, async_result_branches, &out_specs);
|
||||
} else {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
// Not wait to get the result of branch.
|
||||
auto result = branchAsyncResults[i]->TryGetResult();
|
||||
auto result = async_result_branches[i]->TryGetResult();
|
||||
if (result) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
|
||||
<< " result: " << result->ToString();
|
||||
|
|
|
@ -133,7 +133,7 @@ class AnfNodeConfig : public Config {
|
|||
|
||||
AnalysisEnginePtr engine() const { return engine_.lock(); }
|
||||
|
||||
size_t hash() const {
|
||||
size_t hash() const override {
|
||||
std::size_t node_hash = PointerHash<AnfNodePtr>{}(node_);
|
||||
if (context_->IsDummyContext()) {
|
||||
return node_hash;
|
||||
|
|
Loading…
Reference in New Issue