forked from mindspore-Ecosystem/mindspore
!1406 Simplify CondVar class
Merge pull request !1406 from JesseKLee/CondVar
This commit is contained in:
commit
e8980ed298
|
@ -14,35 +14,34 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/util/cond_var.h"
|
||||
#include <exception>
|
||||
#include <utility>
|
||||
#include "dataset/util/services.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {}
|
||||
CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}
|
||||
|
||||
Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
|
||||
// Append an additional condition on top of the given predicate.
|
||||
// We will also bail out if this cv got interrupted.
|
||||
auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); };
|
||||
// If we have interrupt service, just wait on the cv unconditionally.
|
||||
// Otherwise fall back to the old way of checking interrupt.
|
||||
if (svc_) {
|
||||
cv_.wait(*lck, f);
|
||||
if (CurState() == State::kInterrupted) {
|
||||
Task *my_task = TaskManager::FindMe();
|
||||
if (my_task->IsMasterThread() && my_task->CaughtSevereException()) {
|
||||
return TaskManager::GetMasterThreadRc();
|
||||
} else {
|
||||
return Status(StatusCode::kInterrupted);
|
||||
try {
|
||||
if (svc_ != nullptr) {
|
||||
// If this cv registers with a global resource tracking, then wait unconditionally.
|
||||
auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
|
||||
cv_.wait(*lck, f);
|
||||
// If we are interrupted, override the return value if this is the master thread.
|
||||
// The master thread is mostly interrupted because some thread is reporting an error.
|
||||
RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
|
||||
} else {
|
||||
// Otherwise we wake up once in a while to check for an interrupt (for this thread).
|
||||
auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); };
|
||||
while (!f()) {
|
||||
(void)cv_.wait_for(*lck, std::chrono::milliseconds(1));
|
||||
}
|
||||
RETURN_IF_INTERRUPTED();
|
||||
}
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred));
|
||||
if (CurState() == State::kInterrupted) {
|
||||
return Status(StatusCode::kInterrupted);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
|
|||
return rc;
|
||||
}
|
||||
|
||||
Status CondVar::Interrupt() {
|
||||
RETURN_IF_NOT_OK(IntrpResource::Interrupt());
|
||||
void CondVar::Interrupt() {
|
||||
IntrpResource::Interrupt();
|
||||
cv_.notify_all();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Accessor for this condition variable's name; the caller receives a copy.
std::string CondVar::my_name() const {
  return my_name_;
}
|
||||
|
|
|
@ -35,7 +35,7 @@ class CondVar : public IntrpResource {
|
|||
|
||||
Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);
|
||||
|
||||
Status Interrupt() override;
|
||||
void Interrupt() override;
|
||||
|
||||
void NotifyOne() noexcept;
|
||||
|
||||
|
|
|
@ -29,10 +29,7 @@ class IntrpResource {
|
|||
|
||||
virtual ~IntrpResource() = default;
|
||||
|
||||
virtual Status Interrupt() {
|
||||
st_ = State::kInterrupted;
|
||||
return Status::OK();
|
||||
}
|
||||
virtual void Interrupt() { st_ = State::kInterrupted; }
|
||||
|
||||
// Puts the resource back into the running state.
virtual void ResetIntrpState() {
  st_.store(State::kRunning);
}
|
||||
|
||||
|
@ -40,6 +37,13 @@ class IntrpResource {
|
|||
|
||||
// True when the current state of this resource is State::kInterrupted.
bool Interrupted() const {
  return State::kInterrupted == CurState();
}
|
||||
|
||||
// Translates the interrupt flag into a Status value.
// Returns Status::OK() while the resource is not interrupted, and a
// Status carrying StatusCode::kInterrupted once it is.
virtual Status GetInterruptStatus() const {
  if (!Interrupted()) {
    return Status::OK();
  }
  return Status(StatusCode::kInterrupted);
}
|
||||
|
||||
protected:
|
||||
std::atomic<State> st_;
|
||||
};
|
||||
|
|
|
@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept {
|
|||
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";
|
||||
if (!all_intrp_resources_.empty()) {
|
||||
try {
|
||||
(void)InterruptAll();
|
||||
InterruptAll();
|
||||
} catch (const std::exception &e) {
|
||||
// Ignore all errors, since we can't throw in the destructor.
|
||||
}
|
||||
|
@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
|
|||
std::ostringstream ss;
|
||||
ss << this_thread::get_id();
|
||||
MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << ".";
|
||||
auto it = all_intrp_resources_.find(name);
|
||||
if (it != all_intrp_resources_.end()) {
|
||||
(void)all_intrp_resources_.erase(it);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Key " << name << " not found.";
|
||||
auto n = all_intrp_resources_.erase(name);
|
||||
if (n == 0) {
|
||||
MS_LOG(INFO) << "Key " << name << " not found.";
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
|
@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IntrpService::InterruptAll() noexcept {
|
||||
void IntrpService::InterruptAll() noexcept {
|
||||
std::lock_guard<std::mutex> lck(mutex_);
|
||||
Status rc;
|
||||
for (auto const &it : all_intrp_resources_) {
|
||||
std::string kName = it.first;
|
||||
try {
|
||||
Status rc2 = it.second->Interrupt();
|
||||
if (rc2.IsError()) {
|
||||
rc = rc2;
|
||||
}
|
||||
it.second->Interrupt();
|
||||
} catch (const std::exception &e) {
|
||||
// continue the clean up.
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,7 +47,7 @@ class IntrpService : public Service {
|
|||
|
||||
Status Deregister(const std::string &name) noexcept;
|
||||
|
||||
Status InterruptAll() noexcept;
|
||||
void InterruptAll() noexcept;
|
||||
|
||||
// Start hook override: does no work and always reports success.
Status DoServiceStart() override {
  return Status::OK();
}
|
||||
|
||||
|
|
|
@ -110,7 +110,7 @@ class Queue {
|
|||
empty_cv_.NotifyAll();
|
||||
_lock.unlock();
|
||||
} else {
|
||||
(void)empty_cv_.Interrupt();
|
||||
empty_cv_.Interrupt();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ class Queue {
|
|||
empty_cv_.NotifyAll();
|
||||
_lock.unlock();
|
||||
} else {
|
||||
(void)empty_cv_.Interrupt();
|
||||
empty_cv_.Interrupt();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ class Queue {
|
|||
empty_cv_.NotifyAll();
|
||||
_lock.unlock();
|
||||
} else {
|
||||
(void)empty_cv_.Interrupt();
|
||||
empty_cv_.Interrupt();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ class Queue {
|
|||
full_cv_.NotifyAll();
|
||||
_lock.unlock();
|
||||
} else {
|
||||
(void)full_cv_.Interrupt();
|
||||
full_cv_.Interrupt();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <mutex>
|
||||
#include <string>
|
||||
#include "dataset/util/memory_pool.h"
|
||||
#include "dataset/util/allocator.h"
|
||||
#include "dataset/util/service.h"
|
||||
|
||||
#define UNIQUEID_LEN 36
|
||||
|
@ -72,6 +73,11 @@ class Services {
|
|||
|
||||
static std::string GetUniqueID();
|
||||
|
||||
template <typename T>
|
||||
static Allocator<T> GetAllocator() {
|
||||
return Allocator<T>(Services::GetInstance().GetServiceMemPool());
|
||||
}
|
||||
|
||||
private:
|
||||
static std::once_flag init_instance_flag_;
|
||||
static std::unique_ptr<Services> instance_;
|
||||
|
|
|
@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
|
|||
}
|
||||
}
|
||||
|
||||
Status Task::GetTaskErrorIfAny() {
|
||||
Status Task::GetTaskErrorIfAny() const {
|
||||
std::lock_guard<std::mutex> lk(mux_);
|
||||
if (caught_severe_exception_) {
|
||||
return rc_;
|
||||
|
@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; }
|
|||
// Records the task group pointer for this task.
void Task::set_task_group(TaskGroup *vg) {
  task_group_ = vg;
}
|
||||
|
||||
// Destructor: only resets the task-group pointer; nothing is deleted here.
Task::~Task() {
  task_group_ = nullptr;
}
|
||||
// Decides which status an interrupted caller should report.
//
// rc: the status observed by the caller.
// Returns rc unchanged unless rc is an interrupt AND the calling thread is the
// master thread; in that case the master thread's recorded status is returned
// instead, because the master is mostly interrupted on behalf of some thread
// that is reporting an error.
Status Task::OverrideInterruptRc(const Status &rc) {
  // Same short-circuit order as before: the thread check only runs for interrupts.
  if (!rc.IsInterrupted() || !this_thread::is_master_thread()) {
    return rc;
  }
  return TaskManager::GetMasterThreadRc();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,7 +60,7 @@ class Task : public IntrpResource {
|
|||
|
||||
Task &operator=(Task &&) = delete;
|
||||
|
||||
Status GetTaskErrorIfAny();
|
||||
Status GetTaskErrorIfAny() const;
|
||||
|
||||
// Replaces the stored task name with a copy of newName.
void ChangeName(const std::string &newName) {
  my_name_ = newName;
}
|
||||
|
||||
|
@ -95,10 +95,10 @@ class Task : public IntrpResource {
|
|||
|
||||
// Forwards to the Wait() of the task's wp_ member and returns its status.
Status Wait() {
  return wp_.Wait();
}
|
||||
|
||||
void set_task_group(TaskGroup *vg);
|
||||
static Status OverrideInterruptRc(const Status &rc);
|
||||
|
||||
private:
|
||||
std::mutex mux_;
|
||||
mutable std::mutex mux_;
|
||||
std::string my_name_;
|
||||
Status rc_;
|
||||
WaitPost wp_;
|
||||
|
@ -115,6 +115,7 @@ class Task : public IntrpResource {
|
|||
|
||||
void ShutdownGroup();
|
||||
TaskGroup *MyTaskGroup();
|
||||
void set_task_group(TaskGroup *vg);
|
||||
};
|
||||
|
||||
extern thread_local Task *gMyTask;
|
||||
|
|
|
@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept {
|
|||
svc->InterruptAll();
|
||||
}
|
||||
}
|
||||
(void)master_->Interrupt();
|
||||
master_->Interrupt();
|
||||
}
|
||||
|
||||
// Returns the value of gMyTask as seen by the calling thread.
Task *TaskManager::FindMe() {
  return gMyTask;
}
|
||||
|
@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0),
|
|||
free_lst_(&Task::free),
|
||||
watchdog_grp_(nullptr),
|
||||
watchdog_(nullptr) {
|
||||
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
|
||||
Allocator<Task> alloc(mp);
|
||||
auto alloc = Services::GetAllocator<Task>();
|
||||
// Create a dummy Task for the master thread (this thread)
|
||||
master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); });
|
||||
master_->id_ = this_thread::get_id();
|
||||
|
@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) {
|
|||
TaskManager &tm = TaskManager::GetInstance();
|
||||
std::shared_ptr<Task> master = tm.master_;
|
||||
std::lock_guard<std::mutex> lck(master->mux_);
|
||||
(void)master->Interrupt();
|
||||
master->Interrupt();
|
||||
if (rc.IsError() && master->rc_.IsOk()) {
|
||||
master->rc_ = rc;
|
||||
master->caught_severe_exception_ = true;
|
||||
|
@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); }
|
||||
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
|
||||
|
||||
Status TaskGroup::join_all() {
|
||||
Status rc;
|
||||
|
@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() {
|
|||
}
|
||||
|
||||
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
|
||||
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
|
||||
Allocator<IntrpService> alloc(mp);
|
||||
auto alloc = Services::GetAllocator<IntrpService>();
|
||||
intrp_svc_ = std::allocate_shared<IntrpService>(alloc);
|
||||
(void)Service::ServiceStart();
|
||||
}
|
||||
|
|
|
@ -154,37 +154,27 @@ inline bool is_interrupted() {
|
|||
return true;
|
||||
}
|
||||
Task *my_task = TaskManager::FindMe();
|
||||
return (my_task != nullptr) ? my_task->Interrupted() : false;
|
||||
return my_task->Interrupted();
|
||||
}
|
||||
|
||||
// Reports whether the calling thread is the master thread.
//
// The task pointer comes from TaskManager::FindMe(). A null pointer is
// answered with false instead of being dereferenced.
inline bool is_master_thread() {
  Task *my_task = TaskManager::FindMe();
  return (my_task != nullptr) && my_task->IsMasterThread();
}
|
||||
|
||||
// Returns the interrupt status of the task bound to the calling thread.
//
// The task pointer comes from TaskManager::FindMe(). When it is null there is
// no task to query, so Status::OK() is returned rather than dereferencing it.
inline Status GetInterruptStatus() {
  Task *my_task = TaskManager::FindMe();
  if (my_task == nullptr) {
    return Status::OK();
  }
  return my_task->GetInterruptStatus();
}
|
||||
} // namespace this_thread
|
||||
|
||||
#define RETURN_IF_INTERRUPTED() \
|
||||
do { \
|
||||
if (mindspore::dataset::this_thread::is_interrupted()) { \
|
||||
Task *myTask = TaskManager::FindMe(); \
|
||||
if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \
|
||||
return TaskManager::GetMasterThreadRc(); \
|
||||
} else { \
|
||||
return Status(StatusCode::kInterrupted); \
|
||||
} \
|
||||
} \
|
||||
#define RETURN_IF_INTERRUPTED() \
|
||||
do { \
|
||||
if (mindspore::dataset::this_thread::is_interrupted()) { \
|
||||
return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
// Waits on *cv until pred() becomes true, polling for interruption.
//
// cv:   condition variable to wait on.
// lk:   lock handed to cv->wait_for(); the caller supplies it.
// pred: wake-up condition, evaluated before the first wait and after each one.
//
// Returns Status::OK() once pred() is true. Returns from RETURN_IF_INTERRUPTED()
// if that macro detects an interrupt before a wait. Any std::exception raised
// while evaluating pred, checking for interruption, or waiting is reported
// through RETURN_STATUS_UNEXPECTED.
//
// The function is noexcept, so the whole loop (not just wait_for) sits inside
// the try block: an exception escaping from pred() would otherwise terminate
// the process.
inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock<std::mutex> *lk,
                                 const std::function<bool()> &pred) noexcept {
  try {
    while (!pred()) {
      RETURN_IF_INTERRUPTED();
      // Wake up every millisecond so that an interrupt is noticed promptly.
      (void)cv->wait_for(*lk, std::chrono::milliseconds(1));
    }
  } catch (const std::exception &e) {
    // Anything thrown here is considered a system error.
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() {
|
|||
10); // capacity of each queue
|
||||
DS_ASSERT(my_conn != nullptr);
|
||||
|
||||
rc = my_conn->Register(tg_.get());
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
|
||||
// Spawn a thread to read input_ vector and put it in my_conn
|
||||
rc = tg_->CreateAsyncTask("Worker Push",
|
||||
std::bind(&MindDataTestConnector::FirstWorkerPush,
|
||||
|
@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() {
|
|||
l3_threads,
|
||||
conn2_qcap);
|
||||
|
||||
rc = conn1->Register(tg_.get());
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
rc = conn2->Register(tg_.get());
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
|
||||
// Instantiating the threads in the first layer
|
||||
for (int i = 0; i < l1_threads; i++) {
|
||||
rc = tg_->CreateAsyncTask("First Worker Push",
|
||||
|
|
Loading…
Reference in New Issue