From 868553aa928a34bfee23eaf7b7717cccebe55359 Mon Sep 17 00:00:00 2001
From: lanzhineng
Date: Sat, 4 Sep 2021 02:30:41 +0800
Subject: [PATCH] optimize the infer schedule

---
 mindspore/ccsrc/pipeline/jit/pipeline.cc      |   1 +
 .../jit/static_analysis/async_eval_result.cc  |  95 ++++++---
 .../jit/static_analysis/async_eval_result.h   | 188 ++++++++++--------
 3 files changed, 172 insertions(+), 112 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 9e1f44436ed..25542b0f5a4 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -1438,6 +1438,7 @@ void ClearResAtexit() {
   parse::python_adapter::ResetPythonScope();
   abstract::AnalysisResultCacheMgr::GetInstance().Clear();
   abstract::AnalysisContext::ClearContext();
+  abstract::AnalysisSchedule::GetInstance().Stop();
 #ifdef ENABLE_DEBUGGER
   Debugger::GetInstance()->Reset();
 #endif
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc
index 4e03a3a7b06..e7014bda3c1 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc
@@ -25,6 +25,26 @@
 namespace mindspore {
 namespace abstract {
 AnalysisSchedule AnalysisSchedule::instance_;
+void AnalysisSchedule::Schedule() {
+  const auto checkPeriod = std::chrono::seconds(3);
+  std::unique_lock<std::mutex> lock(activate_thread_lock_);
+  while (notExit_) {
+    // Check Error
+    if (StaticAnalysisException::Instance().HasException()) {
+      // Reset
+      active_thread_count_.store(1);
+    } else if (active_thread_count_.load() < 0) {
+      MS_LOG(ERROR) << "There is something wrong. active thread count: " << active_thread_count_;
+    }
+
+    auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() == 0; });
+    if (ok && (!SetNextReady())) {
+      // If schedule list is empty, wait.
+      (void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return active_thread_count_.load() != 0; });
+    }
+  }
+}
+
 void AnalysisSchedule::HandleException(const std::exception &ex) {
   // Just record the first exception information.
   if (!StaticAnalysisException::Instance().HasException()) {
@@ -34,9 +54,10 @@ void AnalysisSchedule::HandleException(const std::exception &ex) {
     if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
       try {
         MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
-        trace::GetTraceStackInfo(exceptionStream_);
-        if (!exceptionStream_.str().empty()) {
-          MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream_.str();
+        std::ostringstream exceptionStream;
+        trace::GetTraceStackInfo(exceptionStream);
+        if (!exceptionStream.str().empty()) {
+          MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
         }
       } catch (const std::exception &e) {
         // Ignored.
       }
     }
   }
   // Free all the locks. Let all the threads continue to run.
-  std::lock_guard<std::mutex> lock(lock_);
-  for (auto &item : asyncAbstractList_) {
-    item->SetRunnable();
+  std::lock_guard<std::mutex> lock(activate_thread_lock_);
+  for (auto &item : scheduleList_) {
+    item->SetException();
   }
-  asyncAbstractList_.clear();
+  scheduleList_.clear();
 }
 void AnalysisSchedule::Wait() {
-  py::gil_scoped_release infer_gil_release;
-  try {
-    EnterWaiting();
-  } catch (const std::exception &ex) {
-    MS_LOG(DEBUG) << ex.what();
-    HandleException(ex);
-  }
+  EnterWaiting();
   {
-    std::unique_lock<std::mutex> lock(lock_);
-    condition_var_.wait(lock, [this] { return threadNum_ <= 0; });
+    py::gil_scoped_release infer_gil_release;
+    std::unique_lock<std::mutex> lock(infer_thread_lock_);
+    infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
+  }
+  if (infer_thread_count_.load() < 0) {
+    MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
   }
   LeaveWaiting();
   if (IS_OUTPUT_ON(DEBUG)) {
@@ -71,30 +90,42 @@
-void AnalysisSchedule::SetNextRunnableImpl() {
-  if (asyncAbstractList_.empty()) {
-    MS_LOG(DEBUG) << "The Health List is empty. ";
-    return;
+bool AnalysisSchedule::SetNextReady() {
+  if (scheduleList_.empty()) {
+    MS_LOG(DEBUG) << "The schedule list is empty. ";
+    return false;
   }
   // Check if enter endless loop
-  auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) {
+  auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
     MS_EXCEPTION_IF_NULL(item);
     return item->HasResult();
   });
-  if (it == asyncAbstractList_.end()) {
-    // Add activate thread count.
-    activeThreadCount_++;
+  if (it == scheduleList_.end()) {
     // Enter endless loop if there is not ready result.
-    MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
+    active_thread_count_.fetch_add(1);
+    // Let the first thread trigger the endless loop exception.
+    MS_LOG(DEBUG) << "Enter endless loop if there is no ready result. Set the async to trigger exception: "
+                  << scheduleList_.front().get() << " The active thread count: " << active_thread_count_;
+    scheduleList_.front()->SetEndLessLoopException();
+    scheduleList_.pop_front();
+    return true;
   }
-  // Push back the not ready async.
-  (void)asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
-  (void)asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
-  MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The active thread count is " << activeThreadCount_
-                << " Called times: " << asyncAbstractList_.front()->count();
-  asyncAbstractList_.front()->SetRunnable();
-  asyncAbstractList_.pop_front();
+  // Push back the not ready async.
+  MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
+                << " Before assign, schedule list size: " << scheduleList_.size();
+  (void)scheduleList_.insert(scheduleList_.end(), scheduleList_.begin(), it);
+  (void)scheduleList_.erase(scheduleList_.begin(), it);
+
+  active_thread_count_.fetch_add(1);
+  MS_LOG(DEBUG) << scheduleList_.front().get() << " The active thread count: " << active_thread_count_
+                << " Called times: " << scheduleList_.front()->count();
+  scheduleList_.front()->SetReady();
+  scheduleList_.pop_front();
+  MS_LOG(DEBUG) << " The active thread count: " << active_thread_count_
+                << " Success to SetNext, schedule list size: " << scheduleList_.size();
+
+  return true;
 }
 // The thread id format is XXXX.YYYY.ZZZZ
 thread_local std::string localThreadID;
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h
index d261df72e25..fc4f9bbde14 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include
 #include "pipeline/jit/static_analysis/static_analysis.h"
@@ -38,82 +39,87 @@
 class AsyncAbstract;
 using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
 class AnalysisSchedule {
  public:
-  ~AnalysisSchedule() = default;
+  ~AnalysisSchedule() { Stop(); }
   AnalysisSchedule(const AnalysisSchedule &) = delete;
   AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
   static AnalysisSchedule &GetInstance() { return instance_; }
   static void SetThreadID(const std::string &caller);
   static std::string &GetThreadID();
   void HandleException(const std::exception &ex);
-  std::string GetExtendException() { return exceptionStream_.str(); }
+  void Stop() { notExit_ = false; }
   void Wait();
   void Reset() {
-    activeThreadCount_ = 1;
-    threadNum_ = 0;
-    exceptionStream_.clear();
-  }
-
-  void SetNextRunnable() {
-    std::lock_guard<std::mutex> lock(lock_);
-    SetNextRunnableImpl();
-  }
-
-  void Check() {
-    MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
-    if (activeThreadCount_ == 0) {
-      SetNextRunnableImpl();
-    } else if (activeThreadCount_ < 0) {
-      MS_LOG(ERROR) << "There is something wrong. active thread count: " << activeThreadCount_;
-    }
+    active_thread_count_.store(1);
+    infer_thread_count_.store(0);
   }
 
   void EnterWaiting() {
-    std::lock_guard<std::mutex> lock(lock_);
-    --activeThreadCount_;
-    MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
-    Check();
+    {
+      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
+      active_thread_count_.fetch_sub(1);
+      MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
+    }
+    activate_thread_cv_.notify_one();
   }
 
   void LeaveWaiting() {
-    std::lock_guard<std::mutex> lock(lock_);
-    ++activeThreadCount_;
-    MS_LOG(DEBUG) << this << " The active thread count: " << activeThreadCount_;
+    {
+      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
+      active_thread_count_.fetch_add(1);
+      MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
+    }
+    activate_thread_cv_.notify_one();
   }
 
   void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) {
-    std::lock_guard<std::mutex> lock(lock_);
-    asyncAbstractList_.push_back(asyncAbastract);
+    std::lock_guard<std::mutex> lock(activate_thread_lock_);
+    MS_LOG(DEBUG) << " push async:" << asyncAbastract.get() << " schedule list size:" << scheduleList_.size();
+    scheduleList_.push_back(asyncAbastract);
   }
+
   void IncreaseThreadCount() {
-    std::lock_guard<std::mutex> lock(lock_);
-    ++threadNum_;
-    ++activeThreadCount_;
-    MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
+    infer_thread_count_.fetch_add(1);
+    {
+      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
+      active_thread_count_.fetch_add(1);
+      MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
+    }
+    activate_thread_cv_.notify_one();
   }
+
   void DecreaseThreadCount() {
     {
-      std::lock_guard<std::mutex> threadNumLock(lock_);
-      --threadNum_;
+      std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
+      infer_thread_count_.fetch_sub(1);
     }
-    condition_var_.notify_one();
+    infer_thread_cv_.notify_one();
 
-    std::lock_guard<std::mutex> activeLock(lock_);
-    --activeThreadCount_;
-    MS_LOG(DEBUG) << "The active thread count: " << activeThreadCount_;
-    Check();
+    {
+      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
+      active_thread_count_.fetch_sub(1);
+      MS_LOG(DEBUG) << "The active thread count: " << active_thread_count_;
+    }
+    activate_thread_cv_.notify_one();
   }
 
  private:
-  void SetNextRunnableImpl();
-  AnalysisSchedule() = default;
+  void Schedule();
+  bool SetNextReady();
+  void Start() {
+    auto thread = std::thread([this] { Schedule(); });
+    thread.detach();
+  }
+  AnalysisSchedule() { Start(); }
   static AnalysisSchedule instance_;
-  int activeThreadCount_{1};
-  int threadNum_{0};
-  std::mutex lock_;
-  std::condition_variable condition_var_;
-  std::list<AsyncAbstractPtr> asyncAbstractList_;
-  std::ostringstream exceptionStream_;
+  std::atomic<int> active_thread_count_{1};
+  std::atomic<int> infer_thread_count_{0};
+  bool notExit_{true};
+  std::mutex infer_thread_lock_;
+  std::condition_variable infer_thread_cv_;
+  std::mutex activate_thread_lock_;
+  std::condition_variable activate_thread_cv_;
+  std::list<AsyncAbstractPtr> scheduleList_;
 };
 template
@@ -216,57 +222,79 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
   ~AsyncAbstract() = default;
   // Wait
   AbstractBasePtr GetResult() {
-    StaticAnalysisException::Instance().CheckException();
+    MS_LOG(DEBUG) << this << " begin GetResult.";
+    std::unique_lock<std::mutex> lock(lock_);
     while (true) {
-      ++count_;
-      // The active thread count should be dropped if it can't run. It will be added when it can run.
-      MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
-      bool hasEnterWaiting = false;
-      if (!runnable_) {
+      // Enter waiting, and let the other thread run.
+      MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
+      if (!ready_) {
         AnalysisSchedule::GetInstance().EnterWaiting();
-        hasEnterWaiting = true;
       }
-      MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
-      {
-        std::unique_lock<std::mutex> lock(lock_);
-        condition_var_.wait(lock, [this] { return runnable_; });
-      }
-      if (hasEnterWaiting) {
-        AnalysisSchedule::GetInstance().LeaveWaiting();
-      }
-      MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
-
+      condition_var_.wait(lock, [this] { return ready_; });
+      ClearReady();  // Clear the normal ready flag
+      MS_LOG(DEBUG) << this << " can go: " << ready_ << " result: " << (result_ ? result_.get() : 0);
+      HandleEndLessLoopException();
      StaticAnalysisException::Instance().CheckException();
-      SetUnrunnable();
       if (result_ != nullptr) {
-        MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
+        MS_LOG(DEBUG) << this << " Success to GetResult. Return result: " << (result_ ? result_.get() : 0);
         return result_;
       }
-      // Push to list
+      // Wait for result until it is not null.
+      ++count_;
       AnalysisSchedule::GetInstance().Add2Schedule(shared_from_this());
-      // Notify the next asyncAbastract to run.
-      AnalysisSchedule::GetInstance().SetNextRunnable();
-      MS_LOG(DEBUG) << this << " SetNextRunnable "
-                    << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
+      MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0)
+                    << " Enter schedule list to wait.";
     }
   }
-  void SetRunnable() {
-    MS_LOG(DEBUG) << this << " Runnable.";
+  void SetReady() {
+    MS_LOG(DEBUG) << this << " want to set ready.";
     {
       std::lock_guard<std::mutex> lock(lock_);
-      runnable_ = true;
+      ready_ = ready_ | 1;  // Set the first bit = 1
+      MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
     }
     condition_var_.notify_one();
   }
-  void SetUnrunnable() {
-    std::lock_guard<std::mutex> lock(lock_);
-    runnable_ = false;
+
+  void SetException() {
+    MS_LOG(DEBUG) << this << " want to set ready.";
+    {
+      std::lock_guard<std::mutex> lock(lock_);
+      ready_ = ready_ | 2;  // Set the second bit = 1
+      MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
+    }
+    condition_var_.notify_one();
+  }
+
+  void SetEndLessLoopException() {
+    MS_LOG(DEBUG) << this << " want to set ready.";
+    {
+      std::lock_guard<std::mutex> lock(lock_);
+      ready_ = ready_ | 4;  // Set the third bit = 1
+      MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
+    }
+    condition_var_.notify_one();
+  }
+
+  void ClearReady() {
+    ready_ = ready_ & 6;  // Set the first bit = 0
+    MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << (result_ ? result_.get() : 0);
+  }
+
+  void HandleEndLessLoopException() {
+    // Get the third bit
+    if (ready_ & 4) {
+      ready_ = ready_ & 3;  // Set the third bit = 0, only trigger once.
+      MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can be evaluated. Please check the code.";
+    }
   }
   int count() const { return count_; }
-
-  bool HasResult() { return result_ != nullptr; }
+  bool HasResult() {
+    std::lock_guard<std::mutex> lock(lock_);
+    return result_ != nullptr;
+  }
   // Not wait
   AbstractBasePtr TryGetResult() {
     std::lock_guard<std::mutex> lock(lock_);
@@ -287,7 +315,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
  private:
   std::mutex lock_;
   std::condition_variable condition_var_;
-  bool runnable_{false};
+  int ready_{0};  // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
   int count_{0};
   AbstractBasePtr result_{nullptr};
 };
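
Note (illustration only, not part of the patch): the new ready_ field packs three flags into a single int. SetReady() sets bit 1 (value 1), SetException() sets bit 2 (value 2), and SetEndLessLoopException() sets bit 3 (value 4). The waiter in GetResult() wakes on any non-zero value, ClearReady() drops only the normal-ready bit, and HandleEndLessLoopException() raises the endless-loop error at most once. A minimal standalone C++ sketch of that bit arithmetic, with plain ints and no locking:

  #include <cassert>

  int main() {
    int ready = 0;
    ready = ready | 1;   // SetReady: any non-zero value satisfies the wait predicate `return ready_;`.
    ready = ready | 4;   // SetEndLessLoopException: would also wake the waiter.
    ready = ready & 6;   // ClearReady: clears only the normal-ready bit; exception bits survive the wake-up.
    assert(ready == 4);  // HandleEndLessLoopException sees bit 3 set ...
    ready = ready & 3;   // ... and clears it, so the endless-loop exception triggers only once.
    assert(ready == 0);
    return 0;
  }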