From 11750cd8691586f894c996cc4d69ca989187a599 Mon Sep 17 00:00:00 2001 From: kswang Date: Tue, 29 Dec 2020 20:46:13 +0800 Subject: [PATCH] optimize executor task run --- mindspore/ccsrc/backend/session/executor.cc | 137 ++++++++++++-------- mindspore/ccsrc/backend/session/executor.h | 7 + 2 files changed, 92 insertions(+), 52 deletions(-) diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 631e38bd966..ff46cf7ed00 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -157,7 +157,7 @@ void Executor::WorkerJoin() { // Avoid worker thread join itself which will cause deadlock if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) { { - std::unique_lock lock(task_mutex_); + std::lock_guard lock(task_mutex_); auto task = std::make_shared(); ready_tasks_.push(task); task_cond_var_.notify_all(); @@ -186,10 +186,11 @@ void Executor::WorkerLoop() { MsException::Instance().SetException(); } { - std::unique_lock lock(task_mutex_); + std::lock_guard lock(done_task_mutex_); done_tasks_.emplace_back(task); } if (task->type_ != kRunGraph || task->sync_run_) { + sync_run_task_finished_ = true; sync_cond_var_.notify_all(); } } @@ -197,7 +198,7 @@ void Executor::WorkerLoop() { std::vector> Executor::GetNewReadyTasks() { std::vector> new_ready_tasks; - std::unique_lock lock(pending_task_mutex_); + std::lock_guard lock(pending_task_mutex_); for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { auto task = *iter; if (IsTaskReady(task)) { @@ -216,26 +217,35 @@ void Executor::OnEvent(const ExecutorEvent &event) { } else if (event == ExecutorEvent::kClear) { WorkerJoin(); } else if (event == ExecutorEvent::kException) { - { - std::unique_lock lock(task_mutex_); - while (!ready_tasks_.empty()) { - done_tasks_.emplace_back(ready_tasks_.front()); - ready_tasks_.pop(); - } + OnException(); + } +} + +void Executor::OnException() { + std::vector> new_done_tasks; + { + std::lock_guard lock(task_mutex_); + while (!ready_tasks_.empty()) { + new_done_tasks.emplace_back(ready_tasks_.front()); + ready_tasks_.pop(); } - { - std::unique_lock lock(pending_task_mutex_); - for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); iter++) { - done_tasks_.emplace_back(*iter); - } - pending_tasks_.clear(); + } + { + std::lock_guard lock(pending_task_mutex_); + for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); ++iter) { + new_done_tasks.emplace_back(*iter); } + pending_tasks_.clear(); + } + { + std::lock_guard lock(done_task_mutex_); + (void)done_tasks_.insert(done_tasks_.end(), new_done_tasks.begin(), new_done_tasks.end()); } } void Executor::OnRunGraphFinished() { auto new_ready_tasks = GetNewReadyTasks(); - std::unique_lock lock(task_mutex_); + std::lock_guard lock(task_mutex_); for (auto &task : new_ready_tasks) { ready_tasks_.push(task); } @@ -262,15 +272,31 @@ bool Executor::IsTaskReady(const std::shared_ptr &task) { return true; } -void Executor::SyncRunTask(const std::shared_ptr &task) { - std::unique_lock lock(task_mutex_); - ready_tasks_.push(task); +void Executor::ClearDoneTasks() { + std::lock_guard lock(done_task_mutex_); done_tasks_.clear(); +} + +void Executor::RunTask(const std::shared_ptr &task, bool sync) { + { + std::lock_guard lock(task_mutex_); + ready_tasks_.push(task); + } + sync_run_task_finished_ = false; task_cond_var_.notify_all(); - sync_cond_var_.wait(lock); + ClearDoneTasks(); + if (sync && !sync_run_task_finished_) { + std::unique_lock lock(task_mutex_); + sync_cond_var_.wait(lock, [this] { + bool finished = sync_run_task_finished_; + return finished; + }); + } MsException::Instance().CheckException(); } +void Executor::SyncRunTask(const std::shared_ptr &task) { RunTask(task, true); } + GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) { auto task = std::make_shared(); @@ -311,6 +337,41 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, SyncRunTask(task); } +void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr &task) { + bool need_lock = false; + for (auto &tensor : task->input_tensors_) { + if (tensor->NeedWait()) { + if (tensor->IsGraphOutput()) { + task->input_need_wait_tensors_.emplace_back(tensor); + } else { + need_lock = true; + } + } + } + if (need_lock) { + ClearDoneTasks(); + mindspore::ScopedLongRunning long_running; + for (auto &tensor : task->input_tensors_) { + if (tensor->NeedWait() && !tensor->IsGraphOutput()) { + tensor->Wait(); + } + } + MsException::Instance().CheckException(); + } + // need lock input parameters for optimizer + for (auto &tensor : task->input_need_lock_tensors_) { + tensor->SetNeedWait(true); + } + auto graph = session->GetGraph(task->graph_id_); + if (graph != nullptr && !graph->IsPostGraphFinished()) { + ClearDoneTasks(); + mindspore::ScopedLongRunning long_running; + std::unique_lock lock(reenter_mutex_); + reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); }); + MsException::Instance().CheckException(); + } +} + void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { MS_EXCEPTION_IF_NULL(session); @@ -320,24 +381,9 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, task->graph_id_ = graph_id; task->input_tensors_ = inputs; task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs); - for (auto &tensor : inputs) { - if (tensor->NeedWait()) { - if (tensor->IsGraphOutput()) { - task->input_need_wait_tensors_.emplace_back(tensor); - } else { - mindspore::ScopedLongRunning long_running; - tensor->Wait(); - } - } - } - MsException::Instance().CheckException(); - for (auto &tensor : task->input_need_lock_tensors_) { - tensor->SetNeedWait(true); - } session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); // maintain a copy of output vector task->outputs_ = *outputs; - // sync run graph without output tensor(int dataset graph) if (!TensorInVector(outputs)) { task->sync_run_ = true; @@ -345,26 +391,13 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, SyncRunTask(task); return; } - auto graph = session->GetGraph(task->graph_id_); - if (graph != nullptr) { - if (!graph->IsPostGraphFinished()) { - mindspore::ScopedLongRunning long_running; - std::unique_lock lock(reenter_mutex_); - reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); - MsException::Instance().CheckException(); - } - } - - bool ready = IsTaskReady(task); - if (!ready) { - std::unique_lock lock(pending_task_mutex_); + WaitTaskGraphAvailable(session, task); + if (!IsTaskReady(task)) { + std::lock_guard lock(pending_task_mutex_); pending_tasks_.push_back(task); return; } - std::unique_lock lock(task_mutex_); - ready_tasks_.push(task); - done_tasks_.clear(); - task_cond_var_.notify_all(); + RunTask(task, false); } void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index af501433d3a..8b1d0c150b6 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -171,18 +172,23 @@ class Executor { void OnEvent(const ExecutorEvent &event); private: + void RunTask(const std::shared_ptr &task, bool sync); void SyncRunTask(const std::shared_ptr &task); void UpdateOutputTensors(VectorRef *outputs, const std::map &tensor_to_node); std::vector> GetNewReadyTasks(); bool IsTaskReady(const std::shared_ptr &task); + void WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr &task); void CheckException(); void OnWorkerExit(); void OnRunGraphFinished(); + void OnException(); + void ClearDoneTasks(); uint32_t device_id_; std::string device_name_; std::mutex task_mutex_; + std::mutex done_task_mutex_; std::mutex pending_task_mutex_; std::mutex reenter_mutex_; std::condition_variable task_cond_var_; @@ -192,6 +198,7 @@ class Executor { std::list> pending_tasks_; std::vector> done_tasks_; std::shared_ptr worker_; + std::atomic_bool sync_run_task_finished_{false}; }; } // namespace session } // namespace mindspore