Add callback to Batch op

This commit is contained in:
hesham 2021-11-10 23:27:45 -05:00
parent 2e174427c9
commit 2662b6d5c3
10 changed files with 165 additions and 173 deletions

View File

@ -78,9 +78,14 @@ BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size,
Status BatchOp::operator()() {
RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
// Initialize callback
RETURN_IF_NOT_OK(callback_manager_.Init(this));
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
int64_t epoch_num = 0, batch_num = 0, cnt = 0;
int64_t ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
TensorRow new_row;
std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
@ -88,11 +93,21 @@ Status BatchOp::operator()() {
int32_t cur_batch_size = 0;
RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0)));
while (child_iterator_->EofHandled() == false) {
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
ep_step = 0;
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
while (new_row.empty() == false) {
// we only call stepBegin when a new batch is starting to be filled.
if (table->size() == 0) {
ep_step++;
total_step++;
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
table->emplace_back(new_row);
// if # of rows is enough to make 1 batch, send it to worker_queue
if (table->size() == static_cast<size_t>(cur_batch_size)) {
RETURN_IF_NOT_OK(worker_in_queues_[cnt % num_workers_]->EmplaceBack(
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->EmplaceBack(
std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt + 1 - epoch_num))));
cnt++;
table = std::make_unique<TensorQTable>();
@ -102,7 +117,7 @@ Status BatchOp::operator()() {
}
// Remainder logic: runs only when a partial batch is left over (table is non-empty) and drop_ is false
if (drop_ == false && table->empty() == false) {
RETURN_IF_NOT_OK(worker_in_queues_[cnt % num_workers_]->EmplaceBack(
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->EmplaceBack(
std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt + 1 - epoch_num))));
cnt++;
}
@ -111,7 +126,7 @@ Status BatchOp::operator()() {
batch_num = 0;
epoch_num++;
RETURN_IF_NOT_OK(
worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE))));
worker_in_queues_[NextWorkerID()]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE))));
RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num)));
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
@ -123,13 +138,13 @@ Status BatchOp::operator()() {
<< "reduce memory usage.";
}
#endif
UpdateRepeatAndEpochCounter();
} // end of EofHandled() == false
RETURN_IF_NOT_OK(
worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF))));
worker_in_queues_[NextWorkerID()]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF))));
// EOF received, send quit signal to all workers
for (int32_t ind = 0; ind < num_workers_; ind++) {
RETURN_IF_NOT_OK(
worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit))));
RETURN_IF_NOT_OK(SendQuitFlagToWorker(NextWorkerID()));
}
return Status::OK();
}
@ -507,7 +522,9 @@ Status BatchOp::ComputeColMap() {
// column names are unchanged
if (col_name_flag) {
if (out_col_names_.empty()) out_col_names_ = in_col_names_;
if (out_col_names_.empty()) {
out_col_names_ = in_col_names_;
}
column_name_id_map_ = child_map_;
return Status::OK();
}
@ -575,15 +592,14 @@ Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
}
return Status::OK();
}
Status BatchOp::WaitForWorkers() {
num_workers_paused_ = 0;
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(worker_in_queues_[i]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kWait))));
}
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
wait_for_workers_post_.Clear();
/// Queue a wait-control entry for a single worker
/// The entry carries no table (nullptr) and a CBatchInfo tagged batchCtrl::kWait.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from EmplaceBack is propagated on failure
Status BatchOp::SendWaitFlagToWorker(int32_t worker_id) {
  // build the control entry first, then hand it over to the worker's input queue
  auto wait_entry = std::make_pair(nullptr, CBatchInfo(batchCtrl::kWait));
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->EmplaceBack(std::move(wait_entry)));
  return Status::OK();
}
/// Queue a quit-control entry for a single worker
/// The entry carries no table (nullptr) and a CBatchInfo tagged batchCtrl::kQuit.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from EmplaceBack is propagated on failure
Status BatchOp::SendQuitFlagToWorker(int32_t worker_id) {
  // build the control entry first, then hand it over to the worker's input queue
  auto quit_entry = std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit));
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->EmplaceBack(std::move(quit_entry)));
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -255,7 +255,10 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
// @return Status The status code returned
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
#endif
Status WaitForWorkers() override;
Status SendWaitFlagToWorker(int32_t worker_id) override;
Status SendQuitFlagToWorker(int32_t worker_id) override;
Status ComputeColMap() override;
int32_t start_batch_size_;

View File

@ -338,6 +338,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// Default implementation: sends nothing and reports success.
/// \param worker_id id of the worker (unused here)
/// \return Status::OK()
virtual Status SendQuitFlagToWorker(int32_t worker_id) {
  return Status::OK();
}
/// Default implementation: sends nothing and reports success.
/// \param worker_id id of the worker (unused here)
/// \return Status::OK()
virtual Status SendWaitFlagToWorker(int32_t worker_id) {
  return Status::OK();
}
// \brief Add callbacks to DatasetOp; only ops that drive callback_manager_ (e.g. MapOp, BatchOp, MappableLeafOp) invoke them
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }

View File

@ -371,25 +371,15 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
}
}
Status MapOp::WaitForWorkers() {
// reset num_paused workers to 0
num_workers_paused_ = 0;
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
// a special row (id=-1, empty, none flag) is used to signal that worker needs to pause.
TensorRow waitRow(TensorRow::kFlagWait);
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::make_unique<MapWorkerJob>(waitRow)));
}
// wait until all workers are done processing their work in local_queue_
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
next_worker_id_ = 0;
// clear the WaitPost for the next Wait()
wait_for_workers_post_.Clear();
/// Queue a wait job for a single worker
/// The job wraps a TensorRow constructed from TensorRow::kFlagWait.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from Add is propagated on failure
Status MapOp::SendWaitFlagToWorker(int32_t worker_id) {
  TensorRow pause_row(TensorRow::kFlagWait);
  // wrap the flag row in a job, then transfer ownership of the job to the queue
  auto pause_job = std::make_unique<MapWorkerJob>(pause_row);
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::move(pause_job)));
  return Status::OK();
}
/// Queue a quit job for a single worker
/// The job wraps a TensorRow constructed from TensorRow::kFlagQuit. Exactly one job is queued per call.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from Add is propagated on failure
Status MapOp::SendQuitFlagToWorker(int32_t worker_id) {
  TensorRow quit_flag(TensorRow::kFlagQuit);
  // a single quit job is enough; a second one would just sit behind the first in the same queue
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::make_unique<MapWorkerJob>(quit_flag)));
  return Status::OK();
}
} // namespace dataset

View File

@ -174,12 +174,10 @@ class MapOp : public ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow> {
// @return - Status
Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
// This function should only be called from master thread. It intends to suspend the operation of all workers and
// have them wait on the QueueList. Master thread would send a token to each worker then wait on a WaitPost.
// Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker
// who does the increment wakes up the master.
// @return - Status
Status WaitForWorkers() override;
/// Send wait flag row to worker at worker_id to make it wait
/// \param worker_id id of the worker
/// \return Status code
Status SendWaitFlagToWorker(int32_t worker_id) override;
/// Send quit flag row to worker at worker_id to make it exit
/// \param worker_id id of the worker

View File

@ -145,6 +145,20 @@ class ParallelOp : public DatasetOp {
return Status::OK();
}
/// Ask every worker to pause, then block until wait_for_workers_post_ is signalled.
/// One wait flag is sent per worker through SendWaitFlagToWorker, with the target chosen by NextWorkerID().
/// \note Signalling wait_for_workers_post_ happens outside this function (presumably by the workers once they
///     have all paused - confirm against the worker entry code).
/// \return Status code; the first failing send or the failing Wait is propagated
Status WaitForWorkers() {
  // start counting paused workers from scratch
  num_workers_paused_ = 0;
  // hand out exactly num_workers_ wait flags
  for (int32_t flags_sent = 0; flags_sent < num_workers_; ++flags_sent) {
    RETURN_IF_NOT_OK(SendWaitFlagToWorker(NextWorkerID()));
  }
  // block until the post is signalled
  RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
  // restart worker selection from the first worker
  next_worker_id_ = 0;
  // re-arm the WaitPost so a later call can wait on it again
  wait_for_workers_post_.Clear();
  return Status::OK();
}
/// Add a new worker to the parallelOp. The function will have to wait for all workers to process current rows.
/// Then it adds a new thread to the list.
/// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to

View File

@ -33,11 +33,11 @@ Status MappableLeafOp::operator()() {
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(InitOp());
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
int64_t ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
while (true) { // each iteration is 1 epoch, breaks when IsLastIteration() is true
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
ep_step = 0;
@ -70,7 +70,7 @@ Status MappableLeafOp::operator()() {
}
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
for (int32_t i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(SendQuitFlagToWorker(i));
RETURN_IF_NOT_OK(SendQuitFlagToWorker(NextWorkerID()));
}
return Status::OK();
}
@ -113,19 +113,14 @@ Status MappableLeafOp::WorkerEntry(int32_t worker_id) {
}
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Unexpected nullptr received in worker.");
}
Status MappableLeafOp::WaitForWorkers() {
num_workers_paused_ = 0;
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(worker_in_queues_[NextWorkerID()]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait)));
}
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
next_worker_id_ = 0;
wait_for_workers_post_.Clear();
/// Queue a wait block for a single worker
/// The block is an IOBlock constructed from IOBlock::kDeIoBlockFlagWait.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from Add is propagated on failure
Status MappableLeafOp::SendWaitFlagToWorker(int32_t worker_id) {
  // create the control block, then transfer its ownership to the worker's input queue
  auto wait_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait);
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::move(wait_block)));
  return Status::OK();
}
/// Queue a quit block for a single worker
/// The block is an IOBlock constructed from IOBlock::kDeIoBlockNone. Exactly one block is queued per call.
/// \param worker_id index into worker_in_queues_ of the worker to signal
/// \return Status code; the error from Add is propagated on failure
Status MappableLeafOp::SendQuitFlagToWorker(int32_t worker_id) {
  // a single quit block is enough; a second one would just sit behind the first in the same queue
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockNone)));
  return Status::OK();
}
} // namespace dataset

View File

@ -96,7 +96,7 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, p
/// Reset function to be called after every epoch to reset the source op after
/// \return Status The status code returned
Status Reset() override;
Status WaitForWorkers() override;
Status SendWaitFlagToWorker(int32_t worker_id) override;
Status SendQuitFlagToWorker(int32_t worker_id) override;
};
} // namespace dataset

View File

@ -130,131 +130,97 @@ class MindDataTestCallback : public UT::DatasetOpTesting {
DatasetOpTesting::SetUp();
GlobalInit();
}
};
TEST_F(MindDataTestCallback, TestBasicCallback) {
MS_LOG(INFO) << "Doing: MindDataTestCallback-TestBasicCallback";
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
std::shared_ptr<DSCallback> cb1 = tst_cb;
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
ASSERT_OK(schema->AddColumn(col));
void TestBasicCallback(std::shared_ptr<ExecutionTree> tree, std::shared_ptr<DatasetOp> callback_node,
int32_t step_size) {
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(step_size);
std::shared_ptr<DSCallback> cb1 = tst_cb;
std::vector<std::shared_ptr<DSCallback>> cbs = {};
cbs.push_back(cb1);
callback_node->AddCallbacks(std::move(cbs));
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
int32_t op_connector_size = config_manager->op_connector_size();
int32_t num_workers = config_manager->num_parallel_workers();
std::shared_ptr<RandomDataOp> leaf =
std::make_shared<RandomDataOp>(num_workers, op_connector_size, 44, std::move(schema));
// config mapOp
std::vector<std::string> input_columns = {"label"};
std::vector<std::string> output_columns = {};
std::vector<std::shared_ptr<TensorOp>> op_list;
std::shared_ptr<TensorOp> my_no_op = std::make_shared<NoOp>();
op_list.push_back(my_no_op);
std::shared_ptr<MapOp> map_op =
std::make_shared<MapOp>(input_columns, output_columns, std::move(op_list), num_workers, op_connector_size);
std::vector<std::shared_ptr<DSCallback>> cbs = {};
cbs.push_back(cb1);
map_op->AddCallbacks(std::move(cbs));
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(2);
// start build then launch tree
leaf->SetTotalRepeats(2);
leaf->SetNumRepeatsPerEpoch(2);
map_op->SetTotalRepeats(2);
map_op->SetNumRepeatsPerEpoch(2);
std::shared_ptr<ExecutionTree> tree = Build({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (!tensor_map.empty()) {
ASSERT_OK(tree->Prepare());
ASSERT_OK(tree->Launch());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
while (!tensor_map.empty()) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
size_t len = 7;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
size_t len = 7;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
}
std::vector<std::shared_ptr<DatasetOp>> GenerateNodes() {
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
EXPECT_OK(schema->AddColumn(col));
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
int32_t op_connector_size = config_manager->op_connector_size();
int32_t num_workers = config_manager->num_parallel_workers();
int32_t num_rows = 44;
std::shared_ptr<RandomDataOp> leaf =
std::make_shared<RandomDataOp>(num_workers, op_connector_size, num_rows, std::move(schema));
// config mapOp
std::vector<std::string> input_columns = {"label"};
std::vector<std::string> output_columns = {};
std::vector<std::shared_ptr<TensorOp>> op_list;
std::shared_ptr<TensorOp> my_no_op = std::make_shared<NoOp>();
op_list.push_back(my_no_op);
std::shared_ptr<MapOp> map_op =
std::make_shared<MapOp>(input_columns, output_columns, std::move(op_list), num_workers, op_connector_size);
PadInfo pad_map;
std::shared_ptr<BatchOp> batch_op =
std::make_shared<BatchOp>(1, false, false, op_connector_size, num_workers, std::vector<std::string>{}, pad_map);
// config RepeatOp
int32_t num_repeats = 2;
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(num_repeats);
// start build then launch tree
leaf->SetTotalRepeats(num_repeats);
leaf->SetNumRepeatsPerEpoch(num_repeats);
map_op->SetTotalRepeats(num_repeats);
map_op->SetNumRepeatsPerEpoch(num_repeats);
batch_op->SetTotalRepeats(num_repeats);
batch_op->SetNumRepeatsPerEpoch(num_repeats);
return {leaf, map_op, batch_op, repeat_op};
}
};
/// Feature: Callback
/// Description: Test callbacks with mappable dataset (RandomDataset)
/// Expectation: number and order of callbacks generated are correct
TEST_F(MindDataTestCallback, TestMappableBasicCallback) {
MS_LOG(INFO) << "Doing: MindDataTestCallback-TestMappableBasicCallback";
// config callback
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
std::shared_ptr<DSCallback> cb1 = tst_cb;
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
ASSERT_OK(schema->AddColumn(col));
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
int32_t op_connector_size = config_manager->op_connector_size();
int32_t num_workers = config_manager->num_parallel_workers();
std::shared_ptr<RandomDataOp> leaf =
std::make_shared<RandomDataOp>(num_workers, op_connector_size, 44, std::move(schema));
// config mapOp
std::vector<std::string> input_columns = {"label"};
std::vector<std::string> output_columns = {};
std::vector<std::shared_ptr<TensorOp>> op_list;
std::shared_ptr<TensorOp> my_no_op = std::make_shared<NoOp>();
op_list.push_back(my_no_op);
std::shared_ptr<MapOp> map_op =
std::make_shared<MapOp>(input_columns, output_columns, std::move(op_list), num_workers, op_connector_size);
std::vector<std::shared_ptr<DSCallback>> cbs = {};
cbs.push_back(cb1);
leaf->AddCallbacks(std::move(cbs));
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(2);
// start build then launch tree
leaf->SetTotalRepeats(2);
leaf->SetNumRepeatsPerEpoch(2);
map_op->SetTotalRepeats(2);
map_op->SetNumRepeatsPerEpoch(2);
std::shared_ptr<ExecutionTree> tree = Build({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorMap tensor_map;
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
while (!tensor_map.empty()) {
rc = di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::sort(callback_names.begin(), callback_names.end());
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
size_t len = 7;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
/// Runs the shared basic-callback check three times, attaching the callback to a different op each time.
/// A fresh set of nodes and a fresh tree are built for every run.
TEST_F(MindDataTestCallback, TestBasicCallback) {
  MS_LOG(INFO) << "Doing: MindDataTestCallback-TestBasicCallback";
  const int32_t step_size = 64;
  // positions in the vector returned by GenerateNodes(): 1 = MapOp, 0 = leaf op, 2 = BatchOp
  const std::vector<size_t> callback_targets = {1, 0, 2};
  for (size_t target : callback_targets) {
    auto nodes = GenerateNodes();
    auto tree = Build(nodes);
    TestBasicCallback(tree, nodes[target], step_size);
  }
}
TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
@ -285,10 +251,11 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
map_op->AddCallbacks(std::move(cbs));
EXPECT_TRUE(rc.IsOk());
int32_t num_repeats = 2;
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(2);
std::shared_ptr<RepeatOp> repeat_op = std::make_shared<RepeatOp>(num_repeats);
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(2);
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(num_repeats);
// start build then launch tree
leaf->SetTotalRepeats(-2);
leaf->SetNumRepeatsPerEpoch(2);

View File

@ -167,6 +167,8 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
ASSERT_NE(to_number, nullptr);
ds = ds->Map({to_number}, {"col1"}, {"col1"});
ds->SetNumWorkers(1);
ds = ds->Batch(1);
ds->SetNumWorkers(1);
auto tree_adapter = std::make_shared<TreeAdapter>();
// Disable IR optimization pass
@ -174,10 +176,15 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
ASSERT_OK(tree_adapter->Compile(ds->IRNode(), 1));
auto tree_modifier = std::make_unique<TreeModifier>(tree_adapter.get());
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(2));
tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(20));
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>());
tree_modifier->AddChangeRequest(1, std::make_shared<ChangeNumWorkersRequest>(2));
tree_modifier->AddChangeRequest(1, std::make_shared<ChangeNumWorkersRequest>());
tree_modifier->AddChangeRequest(1, std::make_shared<ChangeNumWorkersRequest>(10));
tree_modifier->AddChangeRequest(1, std::make_shared<ResizeConnectorRequest>(20));
tree_modifier->AddChangeRequest(0, std::make_shared<ResizeConnectorRequest>(100));
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(2));
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>());
tree_modifier->AddChangeRequest(0, std::make_shared<ChangeNumWorkersRequest>(10));
std::vector<int32_t> expected_result = {1, 5, 9, 1, 5, 9};
@ -189,7 +196,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeModifier) {
while (row.size() != 0) {
auto tensor = row[0];
int32_t num;
ASSERT_OK(tensor->GetItemAt(&num, {}));
ASSERT_OK(tensor->GetItemAt(&num, {0}));
EXPECT_EQ(num, expected_result[i]);
ASSERT_OK(tree_adapter->GetNext(&row));
i++;