Callback changes

Add AddNewWorkers logic to parallel ops
hesham 2021-10-09 10:28:14 -04:00
parent 5e1dab83b8
commit 85e077b60b
11 changed files with 196 additions and 115 deletions

View File

@ -24,6 +24,15 @@ namespace dataset {
void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsBeginNeeded()) begin_indices_.push_back(ind);
if (callbacks_[ind]->IsEndNeeded()) end_indices_.push_back(ind);
if (callbacks_[ind]->IsEpochBeginNeeded()) epoch_begin_indices_.push_back(ind);
if (callbacks_[ind]->IsEpochEndNeeded()) epoch_end_indices_.push_back(ind);
if (callbacks_[ind]->IsNStepBeginNeeded()) step_begin_indices_.push_back(ind);
if (callbacks_[ind]->IsNStepEndNeeded()) step_end_indices_.push_back(ind);
}
}
Status CallbackManager::Init(DatasetOp *op) {
@ -43,18 +52,9 @@ Status CallbackManager::Init(DatasetOp *op) {
Status CallbackManager::Begin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
for (size_t ind : begin_indices_) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
}
return Status::OK();
@ -63,18 +63,11 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) {
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
for (size_t ind : epoch_begin_indices_) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
}
return Status::OK();
@ -83,20 +76,11 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
for (size_t ind : step_begin_indices_) {
if ((cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
}
return Status::OK();
}
@ -104,18 +88,11 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
Status CallbackManager::End(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
RETURN_OK_IF_TRUE(end_indices_.empty());
// Now do the actual callback
for (size_t ind : callback_inds) {
for (size_t ind : end_indices_) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
}
return Status::OK();
@ -124,18 +101,9 @@ Status CallbackManager::End(const CallbackParam &cb_param) {
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
for (size_t ind : epoch_end_indices_) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
}
return Status::OK();
@ -144,20 +112,11 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
RETURN_UNEXPECTED_IF_NULL(op_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
for (size_t ind : step_end_indices_) {
if ((cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
}
return Status::OK();
}
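The refactor above computes, once at registration time, which callbacks need each event and caches their indices, so Begin/End/StepBegin/StepEnd no longer re-scan the whole callback list before every dispatch. A minimal standalone sketch of that register-once, dispatch-by-cached-index pattern (simplified interface and names, not the MindSpore classes):

// Sketch only: a cut-down callback manager that caches per-event indices at registration time.
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct Callback {
  virtual ~Callback() = default;
  virtual bool NeedsStepEnd() const { return false; }
  virtual void OnStepEnd(int /*step*/) {}
};

struct PrintEveryStep : Callback {
  bool NeedsStepEnd() const override { return true; }
  void OnStepEnd(int step) override { std::cout << "step " << step << " done\n"; }
};

class Manager {
 public:
  void Add(std::shared_ptr<Callback> cb) {
    callbacks_.push_back(std::move(cb));
    // Record the index once, so StepEnd() only visits callbacks that asked for the event.
    if (callbacks_.back()->NeedsStepEnd()) step_end_indices_.push_back(callbacks_.size() - 1);
  }
  void StepEnd(int step) {
    for (size_t ind : step_end_indices_) callbacks_[ind]->OnStepEnd(step);
  }

 private:
  std::vector<std::shared_ptr<Callback>> callbacks_;
  std::vector<size_t> step_end_indices_;
};

int main() {
  Manager m;
  m.Add(std::make_shared<PrintEveryStep>());
  m.StepEnd(1);  // prints "step 1 done"
  return 0;
}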

View File

@ -77,6 +77,12 @@ class CallbackManager {
bool enabled_; // flag to enable callback, if false, all functions would return immediately
DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
std::vector<size_t> begin_indices_;
std::vector<size_t> end_indices_;
std::vector<size_t> epoch_begin_indices_;
std::vector<size_t> epoch_end_indices_;
std::vector<size_t> step_begin_indices_;
std::vector<size_t> step_end_indices_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -217,6 +217,8 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
} else if (table_pair.second.ctrl_ == batchCtrl::kEOF) {
RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
} else if (table_pair.second.ctrl_ == batchCtrl::kWait) {
RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagWait)));
} else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) {
TensorRow new_row;
RETURN_IF_NOT_OK(MakeBatchedRow(std::move(table_pair), &new_row));
@ -573,7 +575,15 @@ Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
}
return Status::OK();
}
Status BatchOp::WaitForWorkers() { return Status::OK(); }
Status BatchOp::WaitForWorkers() {
num_workers_paused_ = 0;
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(worker_in_queues_[i]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kWait))));
}
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
wait_for_workers_post_.Clear();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
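BatchOp::WaitForWorkers was previously a no-op; it now follows the same pause protocol as the other parallel ops: reset the paused counter, enqueue one kWait control item per worker, block on wait_for_workers_post_ until every worker has acknowledged, then clear the post. A rough standalone sketch of that handshake, using a plain condition variable as a stand-in for the internal WaitPost (an assumption about its semantics, not the real class):

// Sketch only: "one pause sentinel per worker, wait for N acknowledgements".
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class PauseLatch {
 public:
  explicit PauseLatch(int expected) : expected_(expected) {}
  // Called by each worker once it has drained its queue and seen the pause sentinel.
  void Ack() {
    std::lock_guard<std::mutex> lk(mu_);
    if (++count_ == expected_) cv_.notify_one();
  }
  // Called by the master thread; the reset mirrors wait_for_workers_post_.Clear().
  void WaitAndReset() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return count_ == expected_; });
    count_ = 0;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_ = 0;
  int expected_;
};

int main() {
  const int num_workers = 4;
  PauseLatch latch(num_workers);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_workers; i++) {
    workers.emplace_back([&latch] {
      // A real worker would pop work items here until it pops the kWait sentinel.
      latch.Ack();
    });
  }
  latch.WaitAndReset();  // master: every worker has paused
  std::cout << "all workers paused\n";
  for (auto &t : workers) t.join();
  return 0;
}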

View File

@ -37,7 +37,7 @@ namespace dataset {
using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>;
enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };
enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3, kWait = 4 };
// Parameters associated with one batch.
// This struct is used for both internal control and python callback.

View File

@ -112,6 +112,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// Getter function to get all of our parents.
std::vector<DatasetOp *> parents() const;
virtual Status AddNewWorkers(int32_t num_new_workers = 1) {
return Status(StatusCode::kMDUnexpectedError, "Add new workers is not supported for non-ParallelOps");
}
// \brief Inserts an operator as the parent of the current op.
// \notes Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op.
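The new virtual above lets the base DatasetOp reject worker addition by default, while parallel ops override it with a real implementation. A tiny sketch of that default-error / override pattern with simplified stand-in types (Status here is a two-field struct, not the real class):

// Sketch only: base class refuses the operation, the parallel subclass opts in.
#include <iostream>
#include <string>
#include <utility>

struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status Error(std::string m) { return {false, std::move(m)}; }
};

class DatasetOpLike {
 public:
  virtual ~DatasetOpLike() = default;
  virtual Status AddNewWorkers(int num_new_workers = 1) {
    (void)num_new_workers;
    return Status::Error("Add new workers is not supported for non-ParallelOps");
  }
};

class ParallelOpLike : public DatasetOpLike {
 public:
  Status AddNewWorkers(int num_new_workers = 1) override {
    num_workers_ += num_new_workers;  // the real op also pauses workers and spawns threads
    return Status::OK();
  }

 private:
  int num_workers_ = 4;
};

int main() {
  ParallelOpLike op;
  DatasetOpLike &base = op;  // callers only see the DatasetOp interface
  std::cout << (base.AddNewWorkers(2).ok ? "added" : "rejected") << "\n";  // prints "added"
  return 0;
}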

View File

@ -116,8 +116,8 @@ Status MapOp::operator()() {
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
// num_rows received, including eoe, num_epoch, num_step of current epoch
int64_t num_rows = 0, ep_step = 0, total_step = 0;
int64_t ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
@ -127,6 +127,7 @@ Status MapOp::operator()() {
while (!new_row.eof()) {
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
ep_step = 0;
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
while (!new_row.eoe()) {
@ -142,35 +143,27 @@ Status MapOp::operator()() {
RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
}
// check whether this is the end of a real epoch (not every eoe marks the end of an epoch)
if ((op_current_repeats_ + 1) % GetOpNumRepeatsPerEpoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
ep_step = 0;
}
// Propagate the eoe row to worker
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(new_row));
RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(worker_job)));
UpdateRepeatAndEpochCounter();
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
}
// End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1
// Handle eof logic, this code might never be reached if epoch_ctrl = -1.
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(new_row));
RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(worker_job)));
// Quit all workers, this code might never be reached if EpochCtrl is -1.
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
TensorRow quit_flag(TensorRow::kFlagQuit);
auto quit = std::make_unique<MapWorkerJob>(quit_flag);
RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(quit)));
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::move(quit)));
}
return Status::OK();
@ -195,31 +188,18 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
while (true) {
// Handle special logic where row carries a ctrl flag.
if (in_row.Flags() != TensorRow::kFlagNone) {
if (in_row.wait()) {
// When a worker receives the signal from the master thread, it increments an atomic int
// The last one to increment the counter wakes up the master thread
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
// This will block the worker until the master thread gives it new work
} else if (in_row.eoe()) {
// Calling base class EoeReceived to forward eoe row.
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(in_row)));
} else if (in_row.eof()) {
// Calling base class EofReceived to forward eof row.
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(in_row)));
} else if (in_row.quit()) {
if (in_row.quit()) {
break;
}
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_row, &job_list));
continue;
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(in_row)));
} else {
CHECK_FAIL_RETURN_UNEXPECTED(in_row.size() != 0, "MapOp got an empty TensorRow.");
TensorRow out_row;
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
RETURN_IF_NOT_OK(WorkerCompute(in_row, &out_row, job_list));
// Push the row onto the connector for next operator to consume.
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(out_row)));
}
CHECK_FAIL_RETURN_UNEXPECTED(in_row.size() != 0, "MapOp got an empty TensorRow.");
TensorRow out_row;
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
RETURN_IF_NOT_OK(WorkerCompute(in_row, &out_row, job_list));
// Push the row onto the connector for next operator to consume.
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(out_row)));
// Fetch next data row and map job list
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_row, &job_list));
}
@ -398,10 +378,11 @@ Status MapOp::WaitForWorkers() {
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
// a special row (id=-1, empty, none flag) is used to signal that the worker needs to pause
TensorRow waitRow(TensorRow::kFlagWait);
RETURN_IF_NOT_OK(worker_in_queues_[wkr_id]->Add(std::make_unique<MapWorkerJob>(waitRow)));
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::make_unique<MapWorkerJob>(waitRow)));
}
// wait until all workers are done processing their work in local_queue_
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
next_worker_id_ = 0;
// clear the WaitPost for the next Wait()
wait_for_workers_post_.Clear();
return Status::OK();
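MapOp's master loop now addresses workers through NextWorkerID() instead of num_rows++ % num_workers_, and WaitForWorkers resets next_worker_id_ to 0 after the pause, so dispatch stays aligned with however many workers exist once new ones are added. A minimal sketch of such a resettable round-robin counter (illustrative only, not the ParallelOp members themselves):

// Sketch only: round-robin worker selection that tolerates a growing worker count.
#include <iostream>

class RoundRobin {
 public:
  explicit RoundRobin(int num_workers) : num_workers_(num_workers) {}
  int Next() {
    int id = next_;
    next_ = (next_ + 1) % num_workers_;  // modulo the *current* worker count
    return id;
  }
  void AddWorker() { num_workers_++; }
  void Reset() { next_ = 0; }  // what MapOp::WaitForWorkers does after the pause

 private:
  int num_workers_;
  int next_ = 0;
};

int main() {
  RoundRobin rr(2);
  for (int i = 0; i < 3; i++) std::cout << rr.Next();
  std::cout << "\n";  // 010 with two workers
  rr.AddWorker();     // a third worker joins
  rr.Reset();
  for (int i = 0; i < 3; i++) std::cout << rr.Next();
  std::cout << "\n";  // 012 once the counter is reset
  return 0;
}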

View File

@ -45,7 +45,8 @@ class ParallelOp : public DatasetOp {
num_workers_(num_workers),
worker_connector_size_(op_connector_size),
num_workers_paused_(0),
epoch_sync_flag_(false) {
epoch_sync_flag_(false),
next_worker_id_(0) {
// reduce excessive memory usage with high parallelism
// when num_workers > 4, reduce op_connector_size to have similar total size if there were only 4 workers
constexpr int32_t worker_limit = 4;
@ -106,17 +107,62 @@ class ParallelOp : public DatasetOp {
virtual Status Collector() {
TaskManager::FindMe()->Post();
uint64_t ctr = 0;
// num_rows received (including eoe rows), plus the step counters of the current epoch and in total
int64_t num_rows = 0, ep_step = 0, total_step = 0;
int32_t current_repeats = 0, current_epochs = 0;
TensorRow row;
do {
RETURN_IF_NOT_OK(worker_out_queues_[ctr++ % num_workers_]->PopFront(&row));
if (row.eoe() || row.eof() || !row.skip()) {
RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
RETURN_IF_NOT_OK(worker_out_queues_[num_rows++ % num_workers_]->PopFront(&row));
if (row.wait()) {
// When the collector receives the wait signal from a worker thread, it increments an atomic int
// Once num_workers_ signals are received, it wakes up the master thread
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
num_rows = 0;
}
continue;
} else if (row.eoe()) {
current_repeats++;
// check whether this is the end of a real epoch (not every eoe marks the end of an epoch)
if (current_repeats % GetOpNumRepeatsPerEpoch() == 0) {
current_epochs++;
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(current_epochs, ep_step, total_step)));
ep_step = 0;
}
} else if (row.eof()) {
RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(current_epochs + 1, ep_step, total_step)));
} else if (row.skip()) {
continue;
} else if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
++ep_step;
++total_step;
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(current_epochs + 1, ep_step, total_step)));
}
RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
} while (!row.eof());
return Status::OK();
}
/// Add new workers to the ParallelOp. The function waits for all workers to finish processing their current rows,
/// then adds the new threads to the list.
/// \note The caller of this function has to be the main thread of the Op, since it is the only entity responsible for
/// pushing rows to worker_in_queues_
/// \return Status The status code returned
Status AddNewWorkers(int32_t num_new_workers = 1) override {
// wait for workers to process the current rows
RETURN_IF_NOT_OK(WaitForWorkers());
for (int32_t i = 0; i < num_new_workers; i++) {
worker_in_queues_.AddQueue(tree_->AllTasks());
worker_out_queues_.AddQueue(tree_->AllTasks());
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), nullptr, id()));
num_workers_++;
MS_LOG(INFO) << "A new worker has been added to op: " << Name() << "::" << id()
<< " num_workers=" << num_workers_;
}
return Status::OK();
}
// Wait post used to perform the pausing logic
WaitPost wait_for_workers_post_;
@ -128,6 +174,15 @@ class ParallelOp : public DatasetOp {
/// The number of worker threads
int32_t num_workers_;
int32_t NextWorkerID() {
int32_t next_worker = next_worker_id_;
next_worker_id_ = (next_worker_id_ + 1) % num_workers_;
return next_worker;
}
std::atomic_int next_worker_id_;
/// The size of the input/output worker queues
int32_t worker_connector_size_;
/// queues to hold the input rows to workers
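AddNewWorkers above pauses the pipeline, adds one input and one output queue per new worker, spawns a WorkerEntry task for each, and bumps num_workers_. A rough standalone sketch of the same grow-while-quiescent idea on a plain std::thread pool (illustrative only; the real op goes through the tree's TaskGroup and its Queue classes):

// Sketch only: grow a worker pool while it is quiescent; stand-in for AddNewWorkers.
#include <iostream>
#include <thread>
#include <vector>

class Pool {
 public:
  explicit Pool(int n) { for (int i = 0; i < n; i++) Spawn(); }
  ~Pool() { for (auto &t : threads_) t.join(); }
  int Size() const { return static_cast<int>(threads_.size()); }

  void AddNewWorkers(int num_new_workers = 1) {
    // The real code first calls WaitForWorkers() so no rows are in flight,
    // and registers a new input/output queue for every worker it adds.
    for (int i = 0; i < num_new_workers; i++) Spawn();
  }

 private:
  void Spawn() {
    int id = static_cast<int>(threads_.size());
    threads_.emplace_back([id] {
      (void)id;  // a real WorkerEntry(id) would pop from its own input queue here
    });
  }
  std::vector<std::thread> threads_;
};

int main() {
  Pool pool(4);
  pool.AddNewWorkers(2);
  std::cout << "num_workers=" << pool.Size() << "\n";  // num_workers=6
  return 0;
}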

View File

@ -96,11 +96,7 @@ Status MappableLeafOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) {
if (io_block->wait()) {
// Sync io_block is a signal that master thread wants us to pause and sync with other workers.
// The last guy who comes to this sync point should reset the counter and wake up the master thread.
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagWait)));
} else if (io_block->eoe()) {
RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
} else if (io_block->eof()) {

View File

@ -234,6 +234,11 @@ class QueueList {
~QueueList() = default;
Status AddQueue(TaskGroup *vg) {
queue_list_.emplace_back(std::make_unique<Queue<T>>(queue_list_[0]->capacity()));
return queue_list_[queue_list_.size() - 1]->Register(vg);
}
private:
// Queue contains non-copyable objects, so it cannot be added to a vector due to the vector
// requirement that objects must have copy semantics. To resolve this, we use a vector of unique

View File

@ -98,15 +98,21 @@ class TestCallback : public DSCallback {
bool IsNStepEndNeeded() override { return step_end_; }
std::vector<std::string> all_names(size_t len) {
return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len);
std::vector<std::string> res(all_names_.begin(), all_names_.begin() + len);
std::sort(res.begin(), res.end());
return res;
}
std::vector<int64_t> all_step_nums(size_t len) {
return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len);
std::vector<int64_t> res(all_step_nums_.begin(), all_step_nums_.begin() + len);
std::sort(res.begin(), res.end());
return res;
}
std::vector<int64_t> all_ep_nums(size_t len) {
return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
std::vector<int64_t> res(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
std::sort(res.begin(), res.end());
return res;
}
// flag for turning callback on and off
@ -179,6 +185,7 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
@ -219,7 +226,7 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(2);
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(-1);
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(2);
// start build then launch tree
leaf->SetTotalRepeats(-2);
leaf->SetNumRepeatsPerEpoch(2);
@ -246,6 +253,7 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND",
"EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};
@ -288,7 +296,7 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(2);
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(-1);
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(2);
// start build then launch tree
leaf->SetTotalRepeats(-2);
@ -316,6 +324,7 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND",
"SPBGN", "SPEND", "SPBGN", "SPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2};
@ -360,6 +369,7 @@ TEST_F(MindDataTestCallback, TestCAPICallback) {
ASSERT_OK(tree_adapter->GetNext(&row));
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called

View File

@ -79,6 +79,47 @@ class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
pass
def verify_events(events, epoch_num, step_num, step_size=1, map_num=1, repeat=1):
'''
Make sure that the events are in the correct order.
* begin is the first
* epoch x begin before epoch x end
* epoch x end before epoch x+1 begin
* step x begin before step x end
* step x begin before step x+1 begin
* step x end before step x+1 end
'''
assert events[0][0] == "begin_0_0_0"
epochs = list(filter(lambda e: 'epoch' in e[0], events))
i = 0
while i < len(epochs):
epoch_num = epochs[i][0].split('_')[2]
e_type = epochs[i][0].split('_')[1]
assert str(i // 2 + 1) == epoch_num
assert e_type == "begin"
i += 1
epoch_num = epochs[i][0].split('_')[2]
e_type = epochs[i][0].split('_')[1]
assert str(i // 2 + 1) == epoch_num
assert e_type == "end"
i += 1
steps = list(filter(lambda e: 'step' in e[0], events))
steps = [(s[0].split('_')[1], s[0].split('_')[-1]) for s in steps]
steps_map = {}
max_step = 0
for s in steps:
if s[1] in steps_map:
assert steps_map[s[1]] == 'begin'
assert s[0] == 'end'
else:
assert s[0] == 'begin'
steps_map[s[1]] = 'begin'
assert int(s[1]) > max_step
max_step = max(max_step, int(s[1]))
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
events = []
cb_id = list(range(map_num))
@ -121,6 +162,11 @@ def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
pass
expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
expected_events = [e[0] for e in expected_events]
verify_events(events, epochs, steps, step_size, repeat)
events = [e[0] for e in events]
expected_events.sort()
events.sort()
assert expected_events == events
@ -141,6 +187,9 @@ def build_test_case_2cbs(epochs, steps):
pass
expected_events = generate_expected(epochs, steps)
expected_events.sort()
events1.sort()
events2.sort()
assert expected_events == events1
assert expected_events == events2
@ -449,6 +498,12 @@ def test_callbacks_one_cb():
('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
('epoch_end_2_4_8', [3])]
events1.sort()
events2.sort()
events3.sort()
expected_events1.sort()
expected_events2.sort()
expected_events3.sort()
assert events1 == expected_events1
assert events2 == expected_events2
assert events3 == expected_events3