optimize the infer schedule

This commit is contained in:
lanzhineng 2021-09-04 02:30:41 +08:00
parent 990d561ed9
commit 868553aa92
3 changed files with 172 additions and 112 deletions

View File

@ -1438,6 +1438,7 @@ void ClearResAtexit() {
parse::python_adapter::ResetPythonScope(); parse::python_adapter::ResetPythonScope();
abstract::AnalysisResultCacheMgr::GetInstance().Clear(); abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext(); abstract::AnalysisContext::ClearContext();
abstract::AnalysisSchedule::GetInstance().Stop();
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
Debugger::GetInstance()->Reset(); Debugger::GetInstance()->Reset();
#endif #endif

View File

@ -25,6 +25,26 @@ namespace mindspore {
namespace abstract { namespace abstract {
AnalysisSchedule AnalysisSchedule::instance_; AnalysisSchedule AnalysisSchedule::instance_;
// Scheduler loop. Repeatedly waits until the active thread count drops to zero and then calls
// SetNextReady() to hand the turn to an entry of the schedule list.
// The loop runs while notExit_ is true. activate_thread_lock_ is held for the whole loop body and is
// only released while the thread is blocked inside the condition-variable waits below.
// NOTE(review): nothing in this function is woken specifically when notExit_ changes, so a stop
// request is presumably noticed only after a wait returns (at most checkPeriod later) — confirm.
void AnalysisSchedule::Schedule() {
// Upper bound for every wait, so the loop conditions are re-evaluated periodically.
const auto checkPeriod = std::chrono::seconds(3);
std::unique_lock<std::mutex> lock(activate_thread_lock_);
while (notExit_) {
// Check whether an analysis exception has already been recorded.
if (StaticAnalysisException::Instance().HasException()) {
// Force the counter back to 1 so the "count == 0" wait below is not satisfied by stale state.
// NOTE(review): presumably this stops further scheduling while the error propagates — confirm.
active_thread_count_.store(1);
} else if (active_thread_count_.load() < 0) {
// A negative count is only reported here; it is not corrected.
MS_LOG(ERROR) << "There is something wrong. active thread count: " << active_thread_count_;
}
// Block until no thread is active or the period elapses. `ok` is the predicate result at return:
// true when the count is 0, false when the wait timed out with a non-zero count.
auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() == 0; });
// SetNextReady() is invoked with activate_thread_lock_ held (the wait re-acquires it before returning).
if (ok && (!SetNextReady())) {
// SetNextReady() returned false, i.e. the schedule list is empty: wait (bounded by checkPeriod)
// until some thread becomes active again instead of spinning on the first wait.
(void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() != 0; });
}
}
}
void AnalysisSchedule::HandleException(const std::exception &ex) { void AnalysisSchedule::HandleException(const std::exception &ex) {
// Just record the first exception information. // Just record the first exception information.
if (!StaticAnalysisException::Instance().HasException()) { if (!StaticAnalysisException::Instance().HasException()) {
@ -34,9 +54,10 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) { if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
try { try {
MS_LOG(DEBUG) << "Python exception happened, check the information as below."; MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
trace::GetTraceStackInfo(exceptionStream_); std::ostringstream exceptionStream;
if (!exceptionStream_.str().empty()) { trace::GetTraceStackInfo(exceptionStream);
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream_.str(); if (!exceptionStream.str().empty()) {
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
} }
} catch (const std::exception &e) { } catch (const std::exception &e) {
// Ignored. // Ignored.
@ -44,24 +65,22 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
} }
} }
// Free all the locks. Let all the threads continue to run. // Free all the locks. Let all the threads continue to run.
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(activate_thread_lock_);
for (auto &item : asyncAbstractList_) { for (auto &item : scheduleList_) {
item->SetRunnable(); item->SetException();
} }
asyncAbstractList_.clear(); scheduleList_.clear();
} }
void AnalysisSchedule::Wait() { void AnalysisSchedule::Wait() {
py::gil_scoped_release infer_gil_release; EnterWaiting();
try {
EnterWaiting();
} catch (const std::exception &ex) {
MS_LOG(DEBUG) << ex.what();
HandleException(ex);
}
{ {
std::unique_lock<std::mutex> lock(lock_); py::gil_scoped_release infer_gil_release;
condition_var_.wait(lock, [this] { return threadNum_ <= 0; }); std::unique_lock<std::mutex> lock(infer_thread_lock_);
infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
}
if (infer_thread_count_.load() < 0) {
MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
} }
LeaveWaiting(); LeaveWaiting();
if (IS_OUTPUT_ON(DEBUG)) { if (IS_OUTPUT_ON(DEBUG)) {
@ -71,30 +90,42 @@ void AnalysisSchedule::Wait() {
StaticAnalysisException::Instance().CheckException(); StaticAnalysisException::Instance().CheckException();
} }
void AnalysisSchedule::SetNextRunnableImpl() { bool AnalysisSchedule::SetNextReady() {
if (asyncAbstractList_.empty()) { if (scheduleList_.empty()) {
MS_LOG(DEBUG) << "The Health List is empty. "; MS_LOG(DEBUG) << "The schedule list is empty. ";
return; return false;
} }
// Check if enter endless loop // Check if enter endless loop
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) { auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
return item->HasResult(); return item->HasResult();
}); });
if (it == asyncAbstractList_.end()) { if (it == scheduleList_.end()) {
// Add activate thread count.
activeThreadCount_++;
// Enter endless loop if there is not ready result. // Enter endless loop if there is not ready result.
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code."; active_thread_count_.fetch_add(1);
// Let the first thread trigger the endless loop exception.
MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
<< scheduleList_.front().get() << " The active thread count: " << active_thread_count_;
scheduleList_.front()->SetEndLessLoopException();
scheduleList_.pop_front();
return true;
} }
// Push back the not ready async.
(void)asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
(void)asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The active thread count is " << activeThreadCount_ // Push back the not ready async.
<< " Called times: " << asyncAbstractList_.front()->count(); MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
asyncAbstractList_.front()->SetRunnable(); << " Before assign, schedule list size: " << scheduleList_.size();
asyncAbstractList_.pop_front(); (void)scheduleList_.insert(scheduleList_.end(), scheduleList_.begin(), it);
(void)scheduleList_.erase(scheduleList_.begin(), it);
active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << scheduleList_.front().get() << " The active thread count: " << active_thread_count_
<< " Called times: " << scheduleList_.front()->count();
scheduleList_.front()->SetReady();
scheduleList_.pop_front();
MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
<< " Success to SetNext, schedule list size: " << scheduleList_.size();
return true;
} }
// The thread id format is XXXX.YYYY.ZZZZ // The thread id format is XXXX.YYYY.ZZZZ
thread_local std::string localThreadID; thread_local std::string localThreadID;

View File

@ -28,6 +28,7 @@
#include <functional> #include <functional>
#include <list> #include <list>
#include <fstream> #include <fstream>
#include <chrono>
#include "pipeline/jit/static_analysis/static_analysis.h" #include "pipeline/jit/static_analysis/static_analysis.h"
@ -38,82 +39,87 @@ class AsyncAbstract;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>; using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
class AnalysisSchedule { class AnalysisSchedule {
public: public:
~AnalysisSchedule() = default; ~AnalysisSchedule() { Stop(); }
AnalysisSchedule(const AnalysisSchedule &) = delete; AnalysisSchedule(const AnalysisSchedule &) = delete;
AnalysisSchedule &operator=(const AnalysisSchedule &) = delete; AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
static AnalysisSchedule &GetInstance() { return instance_; } static AnalysisSchedule &GetInstance() { return instance_; }
static void SetThreadID(const std::string &caller); static void SetThreadID(const std::string &caller);
static std::string &GetThreadID(); static std::string &GetThreadID();
void HandleException(const std::exception &ex); void HandleException(const std::exception &ex);
std::string GetExtendException() { return exceptionStream_.str(); } void Stop() { notExit_ = false; }
void Wait(); void Wait();
void Reset() { void Reset() {
activeThreadCount_ = 1; active_thread_count_.store(1);
threadNum_ = 0; infer_thread_count_.store(0);
exceptionStream_.clear();
}
void SetNextRunnable() {
std::lock_guard<std::mutex> lock(lock_);
SetNextRunnableImpl();
}
void Check() {
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
if (activeThreadCount_ == 0) {
SetNextRunnableImpl();
} else if (activeThreadCount_ < 0) {
MS_LOG(ERROR) << "There is something wrong. active thread count: " << activeThreadCount_;
}
} }
void EnterWaiting() { void EnterWaiting() {
std::lock_guard<std::mutex> lock(lock_); {
--activeThreadCount_; std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_; active_thread_count_.fetch_sub(1);
Check(); MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
} }
void LeaveWaiting() { void LeaveWaiting() {
std::lock_guard<std::mutex> lock(lock_); {
++activeThreadCount_; std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_; active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
} }
void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) { void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) {
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(activate_thread_lock_);
asyncAbstractList_.push_back(asyncAbastract); MS_LOG(DEBUG) << " push async:" << asyncAbastract.get() << " schedule list size:" << scheduleList_.size();
scheduleList_.push_back(asyncAbastract);
} }
void IncreaseThreadCount() { void IncreaseThreadCount() {
std::lock_guard<std::mutex> lock(lock_); infer_thread_count_.fetch_add(1);
++threadNum_; {
++activeThreadCount_; std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_; active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
} }
void DecreaseThreadCount() { void DecreaseThreadCount() {
{ {
std::lock_guard<std::mutex> threadNumLock(lock_); std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
--threadNum_; infer_thread_count_.fetch_sub(1);
} }
condition_var_.notify_one(); infer_thread_cv_.notify_one();
std::lock_guard<std::mutex> activeLock(lock_); {
--activeThreadCount_; std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_; active_thread_count_.fetch_sub(1);
Check(); MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
} }
private: private:
void SetNextRunnableImpl(); void Schedule();
AnalysisSchedule() = default; bool SetNextReady();
void Start() {
auto thread = std::thread([this] { Schedule(); });
thread.detach();
}
AnalysisSchedule() { Start(); }
static AnalysisSchedule instance_; static AnalysisSchedule instance_;
int activeThreadCount_{1}; std::atomic<int> active_thread_count_{1};
int threadNum_{0}; std::atomic<int> infer_thread_count_{0};
std::mutex lock_; bool notExit_{true};
std::condition_variable condition_var_; std::mutex infer_thread_lock_;
std::list<AsyncAbstractPtr> asyncAbstractList_; std::condition_variable infer_thread_cv_;
std::ostringstream exceptionStream_; std::mutex activate_thread_lock_;
std::condition_variable activate_thread_cv_;
std::list<AsyncAbstractPtr> scheduleList_;
}; };
template <typename KeyType, typename ValueType, typename CacheType> template <typename KeyType, typename ValueType, typename CacheType>
@ -216,57 +222,79 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
~AsyncAbstract() = default; ~AsyncAbstract() = default;
// Wait // Wait
AbstractBasePtr GetResult() { AbstractBasePtr GetResult() {
StaticAnalysisException::Instance().CheckException(); MS_LOG(DEBUG) << this << " begin GetResult.";
std::unique_lock<std::mutex> lock(lock_);
while (true) { while (true) {
++count_; // Enter waiting, and let the other threads run
// The active thread count should be dropped if it can't run. It will be added when it can run. MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0); if (!ready_) {
bool hasEnterWaiting = false;
if (!runnable_) {
AnalysisSchedule::GetInstance().EnterWaiting(); AnalysisSchedule::GetInstance().EnterWaiting();
hasEnterWaiting = true;
} }
MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0); condition_var_.wait(lock, [this] { return ready_; });
{ ClearReady(); // Clear normal ready flag
std::unique_lock<std::mutex> lock(lock_); MS_LOG(DEBUG) << this << " can go: " << ready_ << " result: " << (result_ ? result_.get() : 0);
condition_var_.wait(lock, [this] { return runnable_; }); HandleEndLessLoopException();
}
if (hasEnterWaiting) {
AnalysisSchedule::GetInstance().LeaveWaiting();
}
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
StaticAnalysisException::Instance().CheckException(); StaticAnalysisException::Instance().CheckException();
SetUnrunnable();
if (result_ != nullptr) { if (result_ != nullptr) {
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0); MS_LOG(DEBUG) << this << " Success to GetResult. Return result: " << (result_ ? result_.get() : 0);
return result_; return result_;
} }
// Push to list // wait for result until it is not null.
++count_;
AnalysisSchedule::GetInstance().Add2Schedule(shared_from_this()); AnalysisSchedule::GetInstance().Add2Schedule(shared_from_this());
// Notify the next asyncAbastract to run. MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0)
AnalysisSchedule::GetInstance().SetNextRunnable(); << " Enter schedule list to wait.";
MS_LOG(DEBUG) << this << " SetNextRunnable "
<< " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
} }
} }
void SetRunnable() { void SetReady() {
MS_LOG(DEBUG) << this << " Runnable."; MS_LOG(DEBUG) << this << " want to set ready.";
{ {
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(lock_);
runnable_ = true; ready_ = ready_ | 1; // Set the first bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
} }
condition_var_.notify_one(); condition_var_.notify_one();
} }
void SetUnrunnable() {
std::lock_guard<std::mutex> lock(lock_); void SetException() {
runnable_ = false; MS_LOG(DEBUG) << this << " want to set ready.";
{
std::lock_guard<std::mutex> lock(lock_);
ready_ = ready_ | 2; // Set the second bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
condition_var_.notify_one();
}
void SetEndLessLoopException() {
MS_LOG(DEBUG) << this << " want to set ready.";
{
std::lock_guard<std::mutex> lock(lock_);
ready_ = ready_ | 4; // Set the third bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
condition_var_.notify_one();
}
void ClearReady() {
ready_ = ready_ & 6; // Set first bit = 0
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
void HandleEndLessLoopException() {
// Get third bit
if (ready_ & 4) {
ready_ = ready_ & 3; // Set the third bit = 0 , Only trigger once.
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
}
} }
int count() const { return count_; } int count() const { return count_; }
bool HasResult() {
bool HasResult() { return result_ != nullptr; } std::lock_guard<std::mutex> lock(lock_);
return result_ != nullptr;
}
// Not wait // Not wait
AbstractBasePtr TryGetResult() { AbstractBasePtr TryGetResult() {
std::lock_guard<std::mutex> lock(lock_); std::lock_guard<std::mutex> lock(lock_);
@ -287,7 +315,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
private: private:
std::mutex lock_; std::mutex lock_;
std::condition_variable condition_var_; std::condition_variable condition_var_;
bool runnable_{false}; int ready_{0}; // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
int count_{0}; int count_{0};
AbstractBasePtr result_{nullptr}; AbstractBasePtr result_{nullptr};
}; };