forked from mindspore-Ecosystem/mindspore
Allow join_all to be blocking
This commit is contained in:
parent
aeb4c52f2d
commit
a8fa847556
|
@ -22,6 +22,7 @@
|
|||
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
|
||||
#include "dataset/engine/gnn/local_edge.h"
|
||||
#include "dataset/engine/gnn/local_node.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
|
||||
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
||||
|
||||
|
@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() {
|
|||
n_feature_maps_.resize(num_workers_);
|
||||
e_feature_maps_.resize(num_workers_);
|
||||
default_feature_maps_.resize(num_workers_);
|
||||
std::vector<std::future<Status>> r_codes(num_workers_);
|
||||
TaskGroup vg;
|
||||
|
||||
shard_reader_ = std::make_unique<ShardReader>();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
|
||||
|
@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() {
|
|||
|
||||
// launching worker threads
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
|
||||
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
|
||||
}
|
||||
// wait for threads to finish and check its return code
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||
RETURN_IF_NOT_OK(r_codes[wkr_id].get());
|
||||
}
|
||||
vg.join_all(Task::WaitFlag::kBlocking);
|
||||
RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u
|
|||
}
|
||||
|
||||
Status GraphLoader::WorkerEntry(int32_t worker_id) {
|
||||
// Handshake
|
||||
TaskManager::FindMe()->Post();
|
||||
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
|
||||
while (rows.empty() == false) {
|
||||
RETURN_IF_INTERRUPTED();
|
||||
for (const auto &tupled_row : rows) {
|
||||
std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
|
||||
mindrecord::json col_jsn = std::get<1>(tupled_row);
|
||||
|
|
|
@ -108,20 +108,27 @@ Status Task::Run() {
|
|||
return rc;
|
||||
}
|
||||
|
||||
Status Task::Join() {
|
||||
Status Task::Join(WaitFlag blocking) {
|
||||
if (running_) {
|
||||
RETURN_UNEXPECTED_IF_NULL(MyTaskGroup());
|
||||
auto interrupt_svc = MyTaskGroup()->GetIntrpService();
|
||||
try {
|
||||
// There is a race condition in the global resource tracking such that a thread can miss the
|
||||
// interrupt and becomes blocked on a conditional variable forever. As a result, calling
|
||||
// join() will not come back. We need some timeout version of join such that if the thread
|
||||
// doesn't come back in a reasonable of time, we will send the interrupt again.
|
||||
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
|
||||
// We can't tell which conditional_variable this thread is waiting on. So we may need
|
||||
// to interrupt everything one more time.
|
||||
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
|
||||
interrupt_svc->InterruptAll();
|
||||
if (blocking == WaitFlag::kBlocking) {
|
||||
// If we are asked to wait, then wait
|
||||
thrd_.get();
|
||||
} else if (blocking == WaitFlag::kNonBlocking) {
|
||||
// There is a race condition in the global resource tracking such that a thread can miss the
|
||||
// interrupt and becomes blocked on a conditional variable forever. As a result, calling
|
||||
// join() will not come back. We need some timeout version of join such that if the thread
|
||||
// doesn't come back in a reasonable of time, we will send the interrupt again.
|
||||
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
|
||||
// We can't tell which conditional_variable this thread is waiting on. So we may need
|
||||
// to interrupt everything one more time.
|
||||
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
|
||||
interrupt_svc->InterruptAll();
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Unknown WaitFlag");
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << get_id();
|
||||
|
|
|
@ -42,9 +42,10 @@ class TaskManager;
|
|||
class Task : public IntrpResource {
|
||||
public:
|
||||
friend class TaskManager;
|
||||
|
||||
friend class TaskGroup;
|
||||
|
||||
enum class WaitFlag : int { kBlocking, kNonBlocking };
|
||||
|
||||
Task(const std::string &myName, const std::function<Status()> &f);
|
||||
|
||||
// Future objects are not copyable.
|
||||
|
@ -74,7 +75,7 @@ class Task : public IntrpResource {
|
|||
// Run the task
|
||||
Status Run();
|
||||
|
||||
Status Join();
|
||||
Status Join(WaitFlag wf = WaitFlag::kBlocking);
|
||||
|
||||
bool Running() const { return running_; }
|
||||
|
||||
|
|
|
@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
|
|||
|
||||
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
|
||||
|
||||
Status TaskGroup::join_all() {
|
||||
Status TaskGroup::join_all(Task::WaitFlag wf) {
|
||||
Status rc;
|
||||
Status rc2;
|
||||
SharedLock lck(&rw_lock_);
|
||||
for (Task &tk : grp_list_) {
|
||||
rc = tk.Join();
|
||||
rc = tk.Join(wf);
|
||||
if (rc.IsError()) {
|
||||
rc2 = rc;
|
||||
}
|
||||
|
@ -294,7 +294,7 @@ Status TaskGroup::join_all() {
|
|||
Status TaskGroup::DoServiceStop() {
|
||||
intrp_svc_->ServiceStop();
|
||||
interrupt_all();
|
||||
return (join_all());
|
||||
return (join_all(Task::WaitFlag::kNonBlocking));
|
||||
}
|
||||
|
||||
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
|
||||
|
|
|
@ -122,7 +122,7 @@ class TaskGroup : public Service {
|
|||
|
||||
void interrupt_all() noexcept;
|
||||
|
||||
Status join_all();
|
||||
Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
|
||||
|
||||
int size() const noexcept { return grp_list_.count; }
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) {
|
|||
return rc;
|
||||
});
|
||||
vg_.GetIntrpService()->InterruptAll();
|
||||
vg_.join_all();
|
||||
vg_.join_all(Task::WaitFlag::kNonBlocking);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIntrpService, Test2) {
|
||||
|
@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) {
|
|||
return rc;
|
||||
});
|
||||
vg_.GetIntrpService()->InterruptAll();
|
||||
vg_.join_all();
|
||||
}
|
||||
vg_.join_all(Task::WaitFlag::kNonBlocking);
|
||||
}
|
|
@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) {
|
|||
vg.interrupt_all();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
// Now we test the async Join
|
||||
ASSERT_TRUE(vg.join_all().IsOk());
|
||||
ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue