From 2c42665e909abdc4c5dc897c1b65171adb9085e0 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Fri, 29 May 2020 07:14:06 +0800 Subject: [PATCH] fix lenet hang problem on windows --- mindspore/ccsrc/dataset/util/task.cc | 2 ++ mindspore/ccsrc/dataset/util/task_manager.cc | 24 ++++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/dataset/util/task.cc b/mindspore/ccsrc/dataset/util/task.cc index 0d02ad8317b..cb146a652fa 100644 --- a/mindspore/ccsrc/dataset/util/task.cc +++ b/mindspore/ccsrc/dataset/util/task.cc @@ -24,7 +24,9 @@ namespace dataset { thread_local Task *gMyTask = nullptr; void Task::operator()() { +#if !defined(_WIN32) && !defined(_WIN64) gMyTask = this; +#endif id_ = this_thread::get_id(); std::stringstream ss; ss << id_; diff --git a/mindspore/ccsrc/dataset/util/task_manager.cc b/mindspore/ccsrc/dataset/util/task_manager.cc index 31b0eedd704..0483364863b 100644 --- a/mindspore/ccsrc/dataset/util/task_manager.cc +++ b/mindspore/ccsrc/dataset/util/task_manager.cc @@ -87,7 +87,27 @@ void TaskManager::interrupt_all() noexcept { (void)master_->Interrupt(); } -Task *TaskManager::FindMe() { return gMyTask; } +Task *TaskManager::FindMe() { +#if !defined(_WIN32) && !defined(_WIN64) + return gMyTask; +#else + TaskManager &tm = TaskManager::GetInstance(); + SharedLock lock(&tm.lru_lock_); + auto id = this_thread::get_id(); + auto tk = std::find_if(tm.lru_.begin(), tm.lru_.end(), [id](const Task &tk) { return tk.id_ == id; }); + if (tk != tm.lru_.end()) { + return &(*tk); + } + // If we get here, either I am the watchdog or the master thread. + if (tm.master_->id_ == id) { + return tm.master_.get(); + } else if (tm.watchdog_ != nullptr && tm.watchdog_->id_ == id) { + return tm.watchdog_; + } + MS_LOG(ERROR) << "Task not found."; + return nullptr; +#endif +} TaskManager::TaskManager() try : global_interrupt_(0), lru_(&Task::node), @@ -101,8 +121,8 @@ TaskManager::TaskManager() try : global_interrupt_(0), master_->id_ = this_thread::get_id(); master_->running_ = true; master_->is_master_ = true; - gMyTask = master_.get(); #if !defined(_WIN32) && !defined(_WIN64) + gMyTask = master_.get(); // Initialize the semaphore for the watchdog errno_t rc = sem_init(&sem_, 0, 0); if (rc == -1) {