forked from mindspore-Ecosystem/mindspore

optimize executor task run

parent c85590c9f5
commit 11750cd869
@@ -157,7 +157,7 @@ void Executor::WorkerJoin() {
   // Avoid worker thread join itself which will cause deadlock
   if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
     {
-      std::unique_lock<std::mutex> lock(task_mutex_);
+      std::lock_guard<std::mutex> lock(task_mutex_);
       auto task = std::make_shared<ExitTask>();
       ready_tasks_.push(task);
       task_cond_var_.notify_all();
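Note: the recurring change in this commit swaps std::unique_lock for std::lock_guard wherever a mutex is held for one whole scope and never handed to a condition variable. A minimal sketch of the distinction, using generic names rather than anything from this patch:

    #include <mutex>
    #include <queue>

    std::mutex m;
    std::queue<int> q;

    void Enqueue(int v) {
      // lock_guard acquires on construction and releases only at scope exit.
      // It keeps no "owns the lock" state, so it is marginally cheaper and
      // cannot be unlocked early by mistake.
      std::lock_guard<std::mutex> lock(m);
      q.push(v);
    }

std::unique_lock stays where it is genuinely needed, i.e. where the guard must be passed to condition_variable::wait, as in the RunTask hunk further down.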
@@ -186,10 +186,11 @@ void Executor::WorkerLoop() {
       MsException::Instance().SetException();
     }
     {
-      std::unique_lock<std::mutex> lock(task_mutex_);
+      std::lock_guard<std::mutex> lock(done_task_mutex_);
       done_tasks_.emplace_back(task);
     }
     if (task->type_ != kRunGraph || task->sync_run_) {
+      sync_run_task_finished_ = true;
       sync_cond_var_.notify_all();
     }
   }
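Note: done_tasks_ is now guarded by its own done_task_mutex_ instead of reusing task_mutex_, so recording a finished task no longer contends with threads feeding ready_tasks_. A sketch of the split-mutex pattern with hypothetical names:

    #include <mutex>
    #include <queue>
    #include <vector>

    struct Queues {
      std::mutex ready_mutex;   // guards ready only
      std::queue<int> ready;
      std::mutex done_mutex;    // guards done only
      std::vector<int> done;

      void Finish(int task) {
        // Takes only done_mutex; producers locking ready_mutex to feed
        // the worker are never blocked by finished-task bookkeeping.
        std::lock_guard<std::mutex> lock(done_mutex);
        done.push_back(task);
      }
    };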
@@ -197,7 +198,7 @@ void Executor::WorkerLoop() {
 
 std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
   std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
-  std::unique_lock<std::mutex> lock(pending_task_mutex_);
+  std::lock_guard<std::mutex> lock(pending_task_mutex_);
   for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
     auto task = *iter;
     if (IsTaskReady(task)) {
@@ -216,26 +217,35 @@ void Executor::OnEvent(const ExecutorEvent &event) {
   } else if (event == ExecutorEvent::kClear) {
     WorkerJoin();
   } else if (event == ExecutorEvent::kException) {
-    {
-      std::unique_lock<std::mutex> lock(task_mutex_);
-      while (!ready_tasks_.empty()) {
-        done_tasks_.emplace_back(ready_tasks_.front());
-        ready_tasks_.pop();
-      }
-      {
-        std::unique_lock<std::mutex> lock(pending_task_mutex_);
-        for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); iter++) {
-          done_tasks_.emplace_back(*iter);
-        }
-        pending_tasks_.clear();
-      }
-    }
+    OnException();
   }
 }
 
+void Executor::OnException() {
+  std::vector<std::shared_ptr<Task>> new_done_tasks;
+  {
+    std::lock_guard<std::mutex> lock(task_mutex_);
+    while (!ready_tasks_.empty()) {
+      new_done_tasks.emplace_back(ready_tasks_.front());
+      ready_tasks_.pop();
+    }
+  }
+  {
+    std::lock_guard<std::mutex> lock(pending_task_mutex_);
+    for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); ++iter) {
+      new_done_tasks.emplace_back(*iter);
+    }
+    pending_tasks_.clear();
+  }
+  {
+    std::lock_guard<std::mutex> lock(done_task_mutex_);
+    (void)done_tasks_.insert(done_tasks_.end(), new_done_tasks.begin(), new_done_tasks.end());
+  }
+}
+
 void Executor::OnRunGraphFinished() {
   auto new_ready_tasks = GetNewReadyTasks();
-  std::unique_lock<std::mutex> lock(task_mutex_);
+  std::lock_guard<std::mutex> lock(task_mutex_);
   for (auto &task : new_ready_tasks) {
     ready_tasks_.push(task);
   }
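Note: the extracted OnException never holds two executor mutexes at once. Each queue is drained under its own lock into a local vector, and the result is spliced into done_tasks_ under done_task_mutex_ in a third, separate critical section. A generic sketch of this drain-then-splice pattern:

    #include <mutex>
    #include <queue>
    #include <vector>

    std::mutex src_mutex, dst_mutex;
    std::queue<int> src;
    std::vector<int> dst;

    void DrainInto() {
      std::vector<int> local;
      {
        std::lock_guard<std::mutex> lock(src_mutex);  // first lock, alone
        while (!src.empty()) {
          local.push_back(src.front());
          src.pop();
        }
      }
      // second lock, alone; no lock ordering between the two to get wrong
      std::lock_guard<std::mutex> lock(dst_mutex);
      dst.insert(dst.end(), local.begin(), local.end());
    }

Because no thread nests these locks, the exception path cannot lock-order deadlock against WorkerLoop, which also takes them one at a time.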
@@ -262,15 +272,31 @@ bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
   return true;
 }
 
-void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
-  std::unique_lock<std::mutex> lock(task_mutex_);
-  ready_tasks_.push(task);
-  done_tasks_.clear();
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  MsException::Instance().CheckException();
-}
+void Executor::ClearDoneTasks() {
+  std::lock_guard<std::mutex> lock(done_task_mutex_);
+  done_tasks_.clear();
+}
+
+void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync) {
+  {
+    std::lock_guard<std::mutex> lock(task_mutex_);
+    sync_run_task_finished_ = false;
+    ready_tasks_.push(task);
+  }
+  task_cond_var_.notify_all();
+  ClearDoneTasks();
+  if (sync && !sync_run_task_finished_) {
+    std::unique_lock<std::mutex> lock(task_mutex_);
+    sync_cond_var_.wait(lock, [this] {
+      bool finished = sync_run_task_finished_;
+      return finished;
+    });
+  }
+  MsException::Instance().CheckException();
+}
+
+void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { RunTask(task, true); }
 
 GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
                                const AnfNodePtrList &outputs) {
   auto task = std::make_shared<CompileNodesTask>();
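Note: the old SyncRunTask called sync_cond_var_.wait(lock) with no predicate, so a notification that fired before the waiter reached wait(), or a spurious wakeup, could hang it or release it too early. The new RunTask pairs the wait with the sync_run_task_finished_ flag. A compilable sketch of the same handshake, with generic names:

    #include <condition_variable>
    #include <mutex>

    std::mutex m;
    std::condition_variable cv;
    bool finished = false;

    void Waiter() {
      std::unique_lock<std::mutex> lock(m);
      // Re-checks the flag on every wakeup and returns immediately if the
      // signal already happened before we got here.
      cv.wait(lock, [] { return finished; });
    }

    void Signaler() {
      {
        std::lock_guard<std::mutex> lock(m);  // orders the write with the wait
        finished = true;
      }
      cv.notify_all();
    }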
@@ -311,6 +337,41 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
   SyncRunTask(task);
 }
 
+void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
+  bool need_lock = false;
+  for (auto &tensor : task->input_tensors_) {
+    if (tensor->NeedWait()) {
+      if (tensor->IsGraphOutput()) {
+        task->input_need_wait_tensors_.emplace_back(tensor);
+      } else {
+        need_lock = true;
+      }
+    }
+  }
+  if (need_lock) {
+    ClearDoneTasks();
+    mindspore::ScopedLongRunning long_running;
+    for (auto &tensor : task->input_tensors_) {
+      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
+        tensor->Wait();
+      }
+    }
+    MsException::Instance().CheckException();
+  }
+  // need lock input parameters for optimizer
+  for (auto &tensor : task->input_need_lock_tensors_) {
+    tensor->SetNeedWait(true);
+  }
+  auto graph = session->GetGraph(task->graph_id_);
+  if (graph != nullptr && !graph->IsPostGraphFinished()) {
+    ClearDoneTasks();
+    mindspore::ScopedLongRunning long_running;
+    std::unique_lock<std::mutex> lock(reenter_mutex_);
+    reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
+    MsException::Instance().CheckException();
+  }
+}
+
 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(session);
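Note: WaitTaskGraphAvailable gathers every potentially long block (waiting on input tensors, waiting for the previous graph's post-processing) and runs it before the task is enqueued, each time inside a mindspore::ScopedLongRunning guard, which appears to mark the thread as deliberately blocked so the front end is not stalled along with it. If such a guard had to be sketched from scratch (a hypothetical ScopedBusyFlag, not MindSpore's implementation):

    #include <atomic>

    std::atomic_bool in_long_running_section{false};

    // RAII stand-in: flips a flag for the lifetime of a blocking region and
    // restores it on every exit path, because the destructor always runs.
    class ScopedBusyFlag {
     public:
      ScopedBusyFlag() { in_long_running_section = true; }
      ~ScopedBusyFlag() { in_long_running_section = false; }
      ScopedBusyFlag(const ScopedBusyFlag &) = delete;
      ScopedBusyFlag &operator=(const ScopedBusyFlag &) = delete;
    };

Both blocking branches also call ClearDoneTasks() first, so finished-task bookkeeping is flushed before the thread can sleep for a long time.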
@@ -320,24 +381,9 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
   task->graph_id_ = graph_id;
   task->input_tensors_ = inputs;
   task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
-  for (auto &tensor : inputs) {
-    if (tensor->NeedWait()) {
-      if (tensor->IsGraphOutput()) {
-        task->input_need_wait_tensors_.emplace_back(tensor);
-      } else {
-        mindspore::ScopedLongRunning long_running;
-        tensor->Wait();
-      }
-    }
-  }
-  MsException::Instance().CheckException();
-  for (auto &tensor : task->input_need_lock_tensors_) {
-    tensor->SetNeedWait(true);
-  }
-
   session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
   // maintain a copy of output vector
   task->outputs_ = *outputs;
   // sync run graph without output tensor(in dataset graph)
   if (!TensorInVector(outputs)) {
     task->sync_run_ = true;
@@ -345,26 +391,13 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
     SyncRunTask(task);
     return;
   }
-  auto graph = session->GetGraph(task->graph_id_);
-  if (graph != nullptr) {
-    if (!graph->IsPostGraphFinished()) {
-      mindspore::ScopedLongRunning long_running;
-      std::unique_lock<std::mutex> lock(reenter_mutex_);
-      reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
-      MsException::Instance().CheckException();
-    }
-  }
-
-  bool ready = IsTaskReady(task);
-  if (!ready) {
-    std::unique_lock<std::mutex> lock(pending_task_mutex_);
-    pending_tasks_.push_back(task);
-    return;
-  }
-  std::unique_lock<std::mutex> lock(task_mutex_);
-  ready_tasks_.push(task);
-  done_tasks_.clear();
-  task_cond_var_.notify_all();
+  WaitTaskGraphAvailable(session, task);
+  if (!IsTaskReady(task)) {
+    std::lock_guard<std::mutex> lock(pending_task_mutex_);
+    pending_tasks_.push_back(task);
+    return;
+  }
+  RunTask(task, false);
 }
 
 void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
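Note: after this hunk RunGraphAsync itself never sleeps on task state: WaitTaskGraphAvailable does the blocking up front, a not-yet-ready task is parked in pending_tasks_, and OnRunGraphFinished / GetNewReadyTasks later promote it. A self-contained sketch of that promotion hand-off, with stand-in types:

    #include <list>
    #include <mutex>
    #include <queue>

    std::mutex pending_mutex, ready_mutex;
    std::list<int> pending;
    std::queue<int> ready;

    bool IsReady(int task) { return task % 2 == 0; }  // stand-in dependency check

    // Mirrors GetNewReadyTasks + OnRunGraphFinished: pull ready entries out
    // of the pending list under its lock, then feed the ready queue under
    // the other lock.
    void PromotePending() {
      std::list<int> now_ready;
      {
        std::lock_guard<std::mutex> lock(pending_mutex);
        for (auto it = pending.begin(); it != pending.end();) {
          if (IsReady(*it)) {
            now_ready.push_back(*it);
            it = pending.erase(it);
          } else {
            ++it;
          }
        }
      }
      std::lock_guard<std::mutex> lock(ready_mutex);
      for (int task : now_ready) {
        ready.push(task);
      }
    }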
@@ -21,6 +21,7 @@
 #include <map>
 #include <memory>
 #include <mutex>
+#include <atomic>
 #include <queue>
 #include <string>
 #include <thread>
@@ -171,18 +172,23 @@ class Executor {
   void OnEvent(const ExecutorEvent &event);
 
  private:
+  void RunTask(const std::shared_ptr<Task> &task, bool sync);
   void SyncRunTask(const std::shared_ptr<Task> &task);
   void UpdateOutputTensors(VectorRef *outputs,
                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
   std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
   bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
+  void WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task);
   void CheckException();
   void OnWorkerExit();
   void OnRunGraphFinished();
+  void OnException();
+  void ClearDoneTasks();
 
   uint32_t device_id_;
   std::string device_name_;
   std::mutex task_mutex_;
+  std::mutex done_task_mutex_;
   std::mutex pending_task_mutex_;
   std::mutex reenter_mutex_;
   std::condition_variable task_cond_var_;
@@ -192,6 +198,7 @@ class Executor {
   std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
   std::vector<std::shared_ptr<Task>> done_tasks_;
   std::shared_ptr<std::thread> worker_;
+  std::atomic_bool sync_run_task_finished_{false};
 };
 }  // namespace session
 }  // namespace mindspore
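Note: sync_run_task_finished_ is declared std::atomic_bool, but atomicity alone is not what makes the RunTask handshake safe. The standard caveat is that the writer should still update the flag while owning the mutex the waiter uses (RunTask clears it under task_mutex_ before notifying), or a notify can fire in the gap between the waiter's predicate check and its sleep. A minimal illustration of that gap:

    #include <atomic>
    #include <condition_variable>
    #include <mutex>

    std::mutex m;
    std::condition_variable cv;
    std::atomic_bool done{false};

    void RacySignaler() {
      done = true;      // not under m: the waiter may have just seen 'done'
      cv.notify_all();  // as false and not yet entered wait(); this notify
    }                   // is then lost and the waiter sleeps indefinitely

    void SafeSignaler() {
      {
        std::lock_guard<std::mutex> lock(m);
        done = true;
      }
      cv.notify_all();  // cannot land inside the waiter's check-then-sleep gap
    }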