Optimize the infer schedule

This commit is contained in:
lanzhineng 2021-09-04 02:30:41 +08:00
parent 990d561ed9
commit 868553aa92
3 changed files with 172 additions and 112 deletions

View File

@ -1438,6 +1438,7 @@ void ClearResAtexit() {
parse::python_adapter::ResetPythonScope();
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
abstract::AnalysisSchedule::GetInstance().Stop();
#ifdef ENABLE_DEBUGGER
Debugger::GetInstance()->Reset();
#endif

View File

@ -25,6 +25,26 @@ namespace mindspore {
namespace abstract {
AnalysisSchedule AnalysisSchedule::instance_;
// Scheduler loop run by the detached thread created in Start(). Whenever the
// active thread count drops to zero it wakes the next ready async via
// SetNextReady(); the loop exits once Stop() clears notExit_.
void AnalysisSchedule::Schedule() {
// Re-evaluate the wait predicates at least every 3 seconds, even if no
// notification arrives (guards against a missed notify).
const auto checkPeriod = std::chrono::seconds(3);
std::unique_lock<std::mutex> lock(activate_thread_lock_);
while (notExit_) {
// Check Error
if (StaticAnalysisException::Instance().HasException()) {
// Reset
active_thread_count_.store(1);
} else if (active_thread_count_.load() < 0) {
MS_LOG(ERROR) << "There is something wrong. active thread count: " << active_thread_count_;
}
// Block until every worker is parked (count == 0), then try to hand the
// turn to the next async in the schedule list.
auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() == 0; });
if (ok && (!SetNextReady())) {
// If schedule list is empty, wait.
(void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() != 0; });
}
}
}
void AnalysisSchedule::HandleException(const std::exception &ex) {
// Just record the first exception information.
if (!StaticAnalysisException::Instance().HasException()) {
@ -34,9 +54,10 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
try {
MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
trace::GetTraceStackInfo(exceptionStream_);
if (!exceptionStream_.str().empty()) {
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream_.str();
std::ostringstream exceptionStream;
trace::GetTraceStackInfo(exceptionStream);
if (!exceptionStream.str().empty()) {
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
}
} catch (const std::exception &e) {
// Ignored.
@ -44,24 +65,22 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
}
}
// Free all the locks. Let all the threads continue to run.
std::lock_guard<std::mutex> lock(lock_);
for (auto &item : asyncAbstractList_) {
item->SetRunnable();
std::lock_guard<std::mutex> lock(activate_thread_lock_);
for (auto &item : scheduleList_) {
item->SetException();
}
asyncAbstractList_.clear();
scheduleList_.clear();
}
void AnalysisSchedule::Wait() {
py::gil_scoped_release infer_gil_release;
try {
EnterWaiting();
} catch (const std::exception &ex) {
MS_LOG(DEBUG) << ex.what();
HandleException(ex);
}
EnterWaiting();
{
std::unique_lock<std::mutex> lock(lock_);
condition_var_.wait(lock, [this] { return threadNum_ <= 0; });
py::gil_scoped_release infer_gil_release;
std::unique_lock<std::mutex> lock(infer_thread_lock_);
infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
}
if (infer_thread_count_.load() < 0) {
MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
}
LeaveWaiting();
if (IS_OUTPUT_ON(DEBUG)) {
@ -71,30 +90,42 @@ void AnalysisSchedule::Wait() {
StaticAnalysisException::Instance().CheckException();
}
void AnalysisSchedule::SetNextRunnableImpl() {
if (asyncAbstractList_.empty()) {
MS_LOG(DEBUG) << "The Health List is empty. ";
return;
bool AnalysisSchedule::SetNextReady() {
if (scheduleList_.empty()) {
MS_LOG(DEBUG) << "The schedule list is empty. ";
return false;
}
// Check if enter endless loop
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) {
auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
MS_EXCEPTION_IF_NULL(item);
return item->HasResult();
});
if (it == asyncAbstractList_.end()) {
// Add activate thread count.
activeThreadCount_++;
if (it == scheduleList_.end()) {
// Enter endless loop if there is not ready result.
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
active_thread_count_.fetch_add(1);
// Let the first thread to trigger endless loop exception.
MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
<< scheduleList_.front().get() << " The active thread count: " << active_thread_count_;
scheduleList_.front()->SetEndLessLoopException();
scheduleList_.pop_front();
return true;
}
// Push back the not ready async.
(void)asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
(void)asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The active thread count is " << activeThreadCount_
<< " Called times: " << asyncAbstractList_.front()->count();
asyncAbstractList_.front()->SetRunnable();
asyncAbstractList_.pop_front();
// Push back the not ready async.
MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
<< " Before assign, schedule list size: " << scheduleList_.size();
(void)scheduleList_.insert(scheduleList_.end(), scheduleList_.begin(), it);
(void)scheduleList_.erase(scheduleList_.begin(), it);
active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << scheduleList_.front().get() << " The active thread count: " << active_thread_count_
<< " Called times: " << scheduleList_.front()->count();
scheduleList_.front()->SetReady();
scheduleList_.pop_front();
MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
<< " Success to SetNext, schedule list size: " << scheduleList_.size();
return true;
}
// The thread id format is XXXX.YYYY.ZZZZ
thread_local std::string localThreadID;

View File

@ -28,6 +28,7 @@
#include <functional>
#include <list>
#include <fstream>
#include <chrono>
#include "pipeline/jit/static_analysis/static_analysis.h"
@ -38,82 +39,87 @@ class AsyncAbstract;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
class AnalysisSchedule {
public:
~AnalysisSchedule() = default;
~AnalysisSchedule() { Stop(); }
AnalysisSchedule(const AnalysisSchedule &) = delete;
AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
static AnalysisSchedule &GetInstance() { return instance_; }
static void SetThreadID(const std::string &caller);
static std::string &GetThreadID();
void HandleException(const std::exception &ex);
std::string GetExtendException() { return exceptionStream_.str(); }
void Stop() { notExit_ = false; }
void Wait();
void Reset() {
activeThreadCount_ = 1;
threadNum_ = 0;
exceptionStream_.clear();
}
void SetNextRunnable() {
std::lock_guard<std::mutex> lock(lock_);
SetNextRunnableImpl();
}
void Check() {
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
if (activeThreadCount_ == 0) {
SetNextRunnableImpl();
} else if (activeThreadCount_ < 0) {
MS_LOG(ERROR) << "There is something wrong. active thread count: " << activeThreadCount_;
}
active_thread_count_.store(1);
infer_thread_count_.store(0);
}
void EnterWaiting() {
std::lock_guard<std::mutex> lock(lock_);
--activeThreadCount_;
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
Check();
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
active_thread_count_.fetch_sub(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
}
void LeaveWaiting() {
std::lock_guard<std::mutex> lock(lock_);
++activeThreadCount_;
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
}
void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) {
std::lock_guard<std::mutex> lock(lock_);
asyncAbstractList_.push_back(asyncAbastract);
std::lock_guard<std::mutex> lock(activate_thread_lock_);
MS_LOG(DEBUG) << " push async:" << asyncAbastract.get() << " schedule list size:" << scheduleList_.size();
scheduleList_.push_back(asyncAbastract);
}
void IncreaseThreadCount() {
std::lock_guard<std::mutex> lock(lock_);
++threadNum_;
++activeThreadCount_;
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
infer_thread_count_.fetch_add(1);
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
active_thread_count_.fetch_add(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
}
void DecreaseThreadCount() {
{
std::lock_guard<std::mutex> threadNumLock(lock_);
--threadNum_;
std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
infer_thread_count_.fetch_sub(1);
}
condition_var_.notify_one();
infer_thread_cv_.notify_one();
std::lock_guard<std::mutex> activeLock(lock_);
--activeThreadCount_;
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
Check();
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
active_thread_count_.fetch_sub(1);
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
}
activate_thread_cv_.notify_one();
}
private:
void SetNextRunnableImpl();
AnalysisSchedule() = default;
void Schedule();
bool SetNextReady();
void Start() {
auto thread = std::thread([this] { Schedule(); });
thread.detach();
}
AnalysisSchedule() { Start(); }
static AnalysisSchedule instance_;
int activeThreadCount_{1};
int threadNum_{0};
std::mutex lock_;
std::condition_variable condition_var_;
std::list<AsyncAbstractPtr> asyncAbstractList_;
std::ostringstream exceptionStream_;
std::atomic<int> active_thread_count_{1};
std::atomic<int> infer_thread_count_{0};
bool notExit_{true};
std::mutex infer_thread_lock_;
std::condition_variable infer_thread_cv_;
std::mutex activate_thread_lock_;
std::condition_variable activate_thread_cv_;
std::list<AsyncAbstractPtr> scheduleList_;
};
template <typename KeyType, typename ValueType, typename CacheType>
@ -216,57 +222,79 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
~AsyncAbstract() = default;
// Wait
AbstractBasePtr GetResult() {
StaticAnalysisException::Instance().CheckException();
MS_LOG(DEBUG) << this << " begin GetResult.";
std::unique_lock<std::mutex> lock(lock_);
while (true) {
++count_;
// The active thread count should be dropped if it can't run. It will be added when it can run.
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
bool hasEnterWaiting = false;
if (!runnable_) {
// Enter waiting ,and let the other thread to run
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
if (!ready_) {
AnalysisSchedule::GetInstance().EnterWaiting();
hasEnterWaiting = true;
}
MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
{
std::unique_lock<std::mutex> lock(lock_);
condition_var_.wait(lock, [this] { return runnable_; });
}
if (hasEnterWaiting) {
AnalysisSchedule::GetInstance().LeaveWaiting();
}
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
condition_var_.wait(lock, [this] { return ready_; });
ClearReady(); // Clear nomal ready flag
MS_LOG(DEBUG) << this << " can go: " << ready_ << " result: " << (result_ ? result_.get() : 0);
HandleEndLessLoopException();
StaticAnalysisException::Instance().CheckException();
SetUnrunnable();
if (result_ != nullptr) {
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
MS_LOG(DEBUG) << this << " Success to GetResult. Return result: " << (result_ ? result_.get() : 0);
return result_;
}
// Push to list
// wait for result until it is not null.
++count_;
AnalysisSchedule::GetInstance().Add2Schedule(shared_from_this());
// Notify the next asyncAbastract to run.
AnalysisSchedule::GetInstance().SetNextRunnable();
MS_LOG(DEBUG) << this << " SetNextRunnable "
<< " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0)
<< " Enter schedule list to wait.";
}
}
void SetRunnable() {
MS_LOG(DEBUG) << this << " Runnable.";
void SetReady() {
MS_LOG(DEBUG) << this << " want to set ready.";
{
std::lock_guard<std::mutex> lock(lock_);
runnable_ = true;
ready_ = ready_ | 1; // Set the first bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
condition_var_.notify_one();
}
void SetUnrunnable() {
std::lock_guard<std::mutex> lock(lock_);
runnable_ = false;
void SetException() {
MS_LOG(DEBUG) << this << " want to set ready.";
{
std::lock_guard<std::mutex> lock(lock_);
ready_ = ready_ | 2; // Set the second bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
condition_var_.notify_one();
}
void SetEndLessLoopException() {
MS_LOG(DEBUG) << this << " want to set ready.";
{
std::lock_guard<std::mutex> lock(lock_);
ready_ = ready_ | 4; // Set the third bit = 1
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
condition_var_.notify_one();
}
// Drop the "normal ready" flag (bit 1) while keeping the exception (bit 2)
// and endless-loop (bit 3) flags intact.
// NOTE(review): mutates ready_ without taking lock_ — appears to rely on the
// caller (GetResult) already holding the lock; confirm.
void ClearReady() {
  ready_ &= 6;  // 6 == 0b110: keep bits 2-3, clear bit 1
  MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
// Raise the endless-loop error exactly once if SetEndLessLoopException()
// flagged this async (bit 3 of ready_); the flag is cleared before throwing
// so a retry does not re-trigger it.
void HandleEndLessLoopException() {
  if ((ready_ & 4) != 0) {
    ready_ &= 3;  // drop bit 3 so the exception fires only once
    MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
  }
}
// How many times GetResult() has looped waiting for this async's result.
int count() const { return count_; }
// Thread-safe check whether the result has already been produced.
// Fix: the span contained two definitions of HasResult — the old
// unsynchronized one-liner left alongside the new lock-protected version
// (a redefinition); only the synchronized version is kept.
bool HasResult() {
  std::lock_guard<std::mutex> lock(lock_);
  return result_ != nullptr;
}
// Not wait
AbstractBasePtr TryGetResult() {
std::lock_guard<std::mutex> lock(lock_);
@ -287,7 +315,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
private:
std::mutex lock_;
std::condition_variable condition_var_;
bool runnable_{false};
int ready_{0}; // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
int count_{0};
AbstractBasePtr result_{nullptr};
};