forked from mindspore-Ecosystem/mindspore
optimize the infer schedule
This commit is contained in:
parent
990d561ed9
commit
868553aa92
|
@ -1438,6 +1438,7 @@ void ClearResAtexit() {
|
|||
parse::python_adapter::ResetPythonScope();
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
abstract::AnalysisSchedule::GetInstance().Stop();
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
Debugger::GetInstance()->Reset();
|
||||
#endif
|
||||
|
|
|
@ -25,6 +25,26 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
AnalysisSchedule AnalysisSchedule::instance_;
|
||||
|
||||
void AnalysisSchedule::Schedule() {
|
||||
const auto checkPeriod = std::chrono::seconds(3);
|
||||
std::unique_lock<std::mutex> lock(activate_thread_lock_);
|
||||
while (notExit_) {
|
||||
// Check Error
|
||||
if (StaticAnalysisException::Instance().HasException()) {
|
||||
// Reset
|
||||
active_thread_count_.store(1);
|
||||
} else if (active_thread_count_.load() < 0) {
|
||||
MS_LOG(ERROR) << "There is something wrong. active thread count: " << active_thread_count_;
|
||||
}
|
||||
|
||||
auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() == 0; });
|
||||
if (ok && (!SetNextReady())) {
|
||||
// If schedule list is empty, wait.
|
||||
(void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() != 0; });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AnalysisSchedule::HandleException(const std::exception &ex) {
|
||||
// Just record the first exception information.
|
||||
if (!StaticAnalysisException::Instance().HasException()) {
|
||||
|
@ -34,9 +54,10 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
|
|||
if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
|
||||
try {
|
||||
MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
|
||||
trace::GetTraceStackInfo(exceptionStream_);
|
||||
if (!exceptionStream_.str().empty()) {
|
||||
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream_.str();
|
||||
std::ostringstream exceptionStream;
|
||||
trace::GetTraceStackInfo(exceptionStream);
|
||||
if (!exceptionStream.str().empty()) {
|
||||
MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
// Ignored.
|
||||
|
@ -44,24 +65,22 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
|
|||
}
|
||||
}
|
||||
// Free all the locks. Let all the threads continue to run.
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
for (auto &item : asyncAbstractList_) {
|
||||
item->SetRunnable();
|
||||
std::lock_guard<std::mutex> lock(activate_thread_lock_);
|
||||
for (auto &item : scheduleList_) {
|
||||
item->SetException();
|
||||
}
|
||||
asyncAbstractList_.clear();
|
||||
scheduleList_.clear();
|
||||
}
|
||||
|
||||
void AnalysisSchedule::Wait() {
|
||||
py::gil_scoped_release infer_gil_release;
|
||||
try {
|
||||
EnterWaiting();
|
||||
} catch (const std::exception &ex) {
|
||||
MS_LOG(DEBUG) << ex.what();
|
||||
HandleException(ex);
|
||||
}
|
||||
EnterWaiting();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
condition_var_.wait(lock, [this] { return threadNum_ <= 0; });
|
||||
py::gil_scoped_release infer_gil_release;
|
||||
std::unique_lock<std::mutex> lock(infer_thread_lock_);
|
||||
infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
|
||||
}
|
||||
if (infer_thread_count_.load() < 0) {
|
||||
MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
|
||||
}
|
||||
LeaveWaiting();
|
||||
if (IS_OUTPUT_ON(DEBUG)) {
|
||||
|
@ -71,30 +90,42 @@ void AnalysisSchedule::Wait() {
|
|||
StaticAnalysisException::Instance().CheckException();
|
||||
}
|
||||
|
||||
void AnalysisSchedule::SetNextRunnableImpl() {
|
||||
if (asyncAbstractList_.empty()) {
|
||||
MS_LOG(DEBUG) << "The Health List is empty. ";
|
||||
return;
|
||||
bool AnalysisSchedule::SetNextReady() {
|
||||
if (scheduleList_.empty()) {
|
||||
MS_LOG(DEBUG) << "The schedule list is empty. ";
|
||||
return false;
|
||||
}
|
||||
// Check if enter endless loop
|
||||
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) {
|
||||
auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return item->HasResult();
|
||||
});
|
||||
if (it == asyncAbstractList_.end()) {
|
||||
// Add activate thread count.
|
||||
activeThreadCount_++;
|
||||
if (it == scheduleList_.end()) {
|
||||
// Enter endless loop if there is not ready result.
|
||||
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
|
||||
active_thread_count_.fetch_add(1);
|
||||
// Let the first thread to trigger endless loop exception.
|
||||
MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
|
||||
<< scheduleList_.front().get() << " The active thread count: " << active_thread_count_;
|
||||
scheduleList_.front()->SetEndLessLoopException();
|
||||
scheduleList_.pop_front();
|
||||
return true;
|
||||
}
|
||||
// Push back the not ready async.
|
||||
(void)asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
|
||||
(void)asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
|
||||
|
||||
MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The active thread count is " << activeThreadCount_
|
||||
<< " Called times: " << asyncAbstractList_.front()->count();
|
||||
asyncAbstractList_.front()->SetRunnable();
|
||||
asyncAbstractList_.pop_front();
|
||||
// Push back the not ready async.
|
||||
MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
|
||||
<< " Before assign, schedule list size: " << scheduleList_.size();
|
||||
(void)scheduleList_.insert(scheduleList_.end(), scheduleList_.begin(), it);
|
||||
(void)scheduleList_.erase(scheduleList_.begin(), it);
|
||||
|
||||
active_thread_count_.fetch_add(1);
|
||||
MS_LOG(DEBUG) << scheduleList_.front().get() << " The active thread count: " << active_thread_count_
|
||||
<< " Called times: " << scheduleList_.front()->count();
|
||||
scheduleList_.front()->SetReady();
|
||||
scheduleList_.pop_front();
|
||||
MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
|
||||
<< " Success to SetNext, schedule list size: " << scheduleList_.size();
|
||||
|
||||
return true;
|
||||
}
|
||||
// The thread id format is XXXX.YYYY.ZZZZ
|
||||
thread_local std::string localThreadID;
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <functional>
|
||||
#include <list>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
|
||||
|
@ -38,82 +39,87 @@ class AsyncAbstract;
|
|||
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
|
||||
class AnalysisSchedule {
|
||||
public:
|
||||
~AnalysisSchedule() = default;
|
||||
~AnalysisSchedule() { Stop(); }
|
||||
AnalysisSchedule(const AnalysisSchedule &) = delete;
|
||||
AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
|
||||
static AnalysisSchedule &GetInstance() { return instance_; }
|
||||
static void SetThreadID(const std::string &caller);
|
||||
static std::string &GetThreadID();
|
||||
void HandleException(const std::exception &ex);
|
||||
std::string GetExtendException() { return exceptionStream_.str(); }
|
||||
void Stop() { notExit_ = false; }
|
||||
void Wait();
|
||||
|
||||
void Reset() {
|
||||
activeThreadCount_ = 1;
|
||||
threadNum_ = 0;
|
||||
exceptionStream_.clear();
|
||||
}
|
||||
|
||||
void SetNextRunnable() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
SetNextRunnableImpl();
|
||||
}
|
||||
|
||||
void Check() {
|
||||
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
|
||||
if (activeThreadCount_ == 0) {
|
||||
SetNextRunnableImpl();
|
||||
} else if (activeThreadCount_ < 0) {
|
||||
MS_LOG(ERROR) << "There is something wrong. active thread count: " << activeThreadCount_;
|
||||
}
|
||||
active_thread_count_.store(1);
|
||||
infer_thread_count_.store(0);
|
||||
}
|
||||
|
||||
void EnterWaiting() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
--activeThreadCount_;
|
||||
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
|
||||
Check();
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
active_thread_count_.fetch_sub(1);
|
||||
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
||||
void LeaveWaiting() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
++activeThreadCount_;
|
||||
MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
active_thread_count_.fetch_add(1);
|
||||
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
||||
void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
asyncAbstractList_.push_back(asyncAbastract);
|
||||
std::lock_guard<std::mutex> lock(activate_thread_lock_);
|
||||
MS_LOG(DEBUG) << " push async:" << asyncAbastract.get() << " schedule list size:" << scheduleList_.size();
|
||||
scheduleList_.push_back(asyncAbastract);
|
||||
}
|
||||
|
||||
void IncreaseThreadCount() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
++threadNum_;
|
||||
++activeThreadCount_;
|
||||
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
|
||||
infer_thread_count_.fetch_add(1);
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
active_thread_count_.fetch_add(1);
|
||||
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
||||
void DecreaseThreadCount() {
|
||||
{
|
||||
std::lock_guard<std::mutex> threadNumLock(lock_);
|
||||
--threadNum_;
|
||||
std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
|
||||
infer_thread_count_.fetch_sub(1);
|
||||
}
|
||||
condition_var_.notify_one();
|
||||
infer_thread_cv_.notify_one();
|
||||
|
||||
std::lock_guard<std::mutex> activeLock(lock_);
|
||||
--activeThreadCount_;
|
||||
MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
|
||||
Check();
|
||||
{
|
||||
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
|
||||
active_thread_count_.fetch_sub(1);
|
||||
MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
|
||||
}
|
||||
activate_thread_cv_.notify_one();
|
||||
}
|
||||
|
||||
private:
|
||||
void SetNextRunnableImpl();
|
||||
AnalysisSchedule() = default;
|
||||
void Schedule();
|
||||
bool SetNextReady();
|
||||
void Start() {
|
||||
auto thread = std::thread([this] { Schedule(); });
|
||||
thread.detach();
|
||||
}
|
||||
AnalysisSchedule() { Start(); }
|
||||
static AnalysisSchedule instance_;
|
||||
int activeThreadCount_{1};
|
||||
int threadNum_{0};
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
std::list<AsyncAbstractPtr> asyncAbstractList_;
|
||||
std::ostringstream exceptionStream_;
|
||||
std::atomic<int> active_thread_count_{1};
|
||||
std::atomic<int> infer_thread_count_{0};
|
||||
bool notExit_{true};
|
||||
std::mutex infer_thread_lock_;
|
||||
std::condition_variable infer_thread_cv_;
|
||||
std::mutex activate_thread_lock_;
|
||||
std::condition_variable activate_thread_cv_;
|
||||
std::list<AsyncAbstractPtr> scheduleList_;
|
||||
};
|
||||
|
||||
template <typename KeyType, typename ValueType, typename CacheType>
|
||||
|
@ -216,57 +222,79 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
|||
~AsyncAbstract() = default;
|
||||
// Wait
|
||||
AbstractBasePtr GetResult() {
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
MS_LOG(DEBUG) << this << " begin GetResult.";
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
while (true) {
|
||||
++count_;
|
||||
// The active thread count should be dropped if it can't run. It will be added when it can run.
|
||||
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
bool hasEnterWaiting = false;
|
||||
if (!runnable_) {
|
||||
// Enter waiting ,and let the other thread to run
|
||||
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
|
||||
if (!ready_) {
|
||||
AnalysisSchedule::GetInstance().EnterWaiting();
|
||||
hasEnterWaiting = true;
|
||||
}
|
||||
MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
condition_var_.wait(lock, [this] { return runnable_; });
|
||||
}
|
||||
if (hasEnterWaiting) {
|
||||
AnalysisSchedule::GetInstance().LeaveWaiting();
|
||||
}
|
||||
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
|
||||
condition_var_.wait(lock, [this] { return ready_; });
|
||||
ClearReady(); // Clear nomal ready flag
|
||||
MS_LOG(DEBUG) << this << " can go: " << ready_ << " result: " << (result_ ? result_.get() : 0);
|
||||
HandleEndLessLoopException();
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
SetUnrunnable();
|
||||
if (result_ != nullptr) {
|
||||
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
|
||||
MS_LOG(DEBUG) << this << " Success to GetResult. Return result: " << (result_ ? result_.get() : 0);
|
||||
return result_;
|
||||
}
|
||||
// Push to list
|
||||
// wait for result until it is not null.
|
||||
++count_;
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(shared_from_this());
|
||||
// Notify the next asyncAbastract to run.
|
||||
AnalysisSchedule::GetInstance().SetNextRunnable();
|
||||
MS_LOG(DEBUG) << this << " SetNextRunnable "
|
||||
<< " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0)
|
||||
<< " Enter schedule list to wait.";
|
||||
}
|
||||
}
|
||||
|
||||
void SetRunnable() {
|
||||
MS_LOG(DEBUG) << this << " Runnable.";
|
||||
void SetReady() {
|
||||
MS_LOG(DEBUG) << this << " want to set ready.";
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
runnable_ = true;
|
||||
ready_ = ready_ | 1; // Set the first bit = 1
|
||||
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
|
||||
}
|
||||
condition_var_.notify_one();
|
||||
}
|
||||
void SetUnrunnable() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
runnable_ = false;
|
||||
|
||||
void SetException() {
|
||||
MS_LOG(DEBUG) << this << " want to set ready.";
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
ready_ = ready_ | 2; // Set the second bit = 1
|
||||
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
|
||||
}
|
||||
condition_var_.notify_one();
|
||||
}
|
||||
|
||||
void SetEndLessLoopException() {
|
||||
MS_LOG(DEBUG) << this << " want to set ready.";
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
ready_ = ready_ | 4; // Set the third bit = 1
|
||||
MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
|
||||
}
|
||||
condition_var_.notify_one();
|
||||
}
|
||||
|
||||
// Clears the first ("normal ready") bit of ready_, leaving the exception and
// endless-loop bits as they are. Takes no lock itself.
// NOTE(review): the visible call site in GetResult() appears to hold lock_ when
// calling this - confirm before calling from elsewhere.
void ClearReady() {
  constexpr int kKeepExceptionBits = 6;  // second and third bits
  ready_ &= kKeepExceptionBits;
  MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
}
|
||||
|
||||
void HandleEndLessLoopException() {
|
||||
// Get third bit
|
||||
if (ready_ & 4) {
|
||||
ready_ = ready_ & 3; // Set the third bit = 0 , Only trigger once.
|
||||
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns count_, the counter that GetResult() increments.
int count() const { return count_; }
|
||||
|
||||
bool HasResult() { return result_ != nullptr; }
|
||||
bool HasResult() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
return result_ != nullptr;
|
||||
}
|
||||
// Not wait
|
||||
AbstractBasePtr TryGetResult() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
|
@ -287,7 +315,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
|||
private:
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
bool runnable_{false};
|
||||
int ready_{0}; // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
|
||||
int count_{0};
|
||||
AbstractBasePtr result_{nullptr};
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue