forked from mindspore-Ecosystem/mindspore
!25963 [MD][Autotune] Add WaitFor
Merge pull request !25963 from harshvardhangupta/hesh_tree_mod
This commit is contained in:
commit
2e174427c9
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -45,6 +45,30 @@ Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CondVar::WaitFor(std::unique_lock<std::mutex> *lck, int64_t duration) {
|
||||||
|
try {
|
||||||
|
if (svc_ != nullptr) {
|
||||||
|
// If this cv registers with a global resource tracking, then wait unconditionally.
|
||||||
|
auto f = [this]() -> bool { return this->Interrupted(); };
|
||||||
|
cv_.wait_for(*lck, std::chrono::milliseconds(duration), f);
|
||||||
|
// If we are interrupted, override the return value if this is the master thread.
|
||||||
|
// Master thread is being interrupted mostly because of some thread is reporting error.
|
||||||
|
RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
|
||||||
|
} else {
|
||||||
|
// Otherwise we wake up once a while to check for interrupt (for this thread).
|
||||||
|
auto f = []() -> bool { return this_thread::is_interrupted(); };
|
||||||
|
int64_t ctr = 0;
|
||||||
|
while (!f() && ctr++ < duration) {
|
||||||
|
(void)cv_.wait_for(*lck, std::chrono::milliseconds(1), f);
|
||||||
|
}
|
||||||
|
RETURN_IF_INTERRUPTED();
|
||||||
|
}
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
RETURN_STATUS_UNEXPECTED(e.what());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
CondVar::~CondVar() noexcept {
|
CondVar::~CondVar() noexcept {
|
||||||
if (svc_ != nullptr) {
|
if (svc_ != nullptr) {
|
||||||
(void)svc_->Deregister(my_name_);
|
(void)svc_->Deregister(my_name_);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -35,6 +35,12 @@ class CondVar : public IntrpResource {
|
||||||
|
|
||||||
Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);
|
Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);
|
||||||
|
|
||||||
|
/// Timed sleep.
|
||||||
|
/// \param lck lock
|
||||||
|
/// \param duration time to sleep in ms
|
||||||
|
/// \return Status code
|
||||||
|
Status WaitFor(std::unique_lock<std::mutex> *lck, int64_t duration);
|
||||||
|
|
||||||
void Interrupt() override;
|
void Interrupt() override;
|
||||||
|
|
||||||
void NotifyOne() noexcept;
|
void NotifyOne() noexcept;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -83,3 +83,33 @@ TEST_F(MindDataTestTaskManager, Test2) {
|
||||||
// Now we test the async Join
|
// Now we test the async Join
|
||||||
ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk());
|
ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Feature: WaitFor in CondVar.
|
||||||
|
/// Description: test WaitFor function
|
||||||
|
/// Expectation: no hangs or failures
|
||||||
|
TEST_F(MindDataTestTaskManager, Test3) {
|
||||||
|
(void)TaskManager::GetMasterThreadRc();
|
||||||
|
TaskGroup vg;
|
||||||
|
CondVar cv;
|
||||||
|
std::mutex mux;
|
||||||
|
Status rc;
|
||||||
|
rc = cv.Register(vg.GetIntrpService());
|
||||||
|
EXPECT_TRUE(rc.IsOk());
|
||||||
|
auto block_forever = [&cv, &mux]() -> Status {
|
||||||
|
std::unique_lock<std::mutex> lck(mux);
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||||
|
RETURN_IF_NOT_OK(cv.WaitFor(&lck, 1000 * 5));
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
auto f = [&vg, &block_forever]() -> Status {
|
||||||
|
RETURN_IF_NOT_OK(vg.CreateAsyncTask("Spawn block threads", block_forever));
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
rc = f();
|
||||||
|
|
||||||
|
vg.interrupt_all();
|
||||||
|
ASSERT_OK(rc);
|
||||||
|
// Now we test the async Join
|
||||||
|
ASSERT_OK(vg.join_all(Task::WaitFlag::kNonBlocking));
|
||||||
|
}
|
Loading…
Reference in New Issue