diff --git a/tests/ut/cpp/dataset/ir_callback_test.cc b/tests/ut/cpp/dataset/ir_callback_test.cc index 6ff41320d68..9872ff491df 100644 --- a/tests/ut/cpp/dataset/ir_callback_test.cc +++ b/tests/ut/cpp/dataset/ir_callback_test.cc @@ -52,36 +52,42 @@ class TestCallback : public DSCallback { } Status DSBegin(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("BGN"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); return Status::OK(); } Status DSEpochBegin(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("EPBGN"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); return Status::OK(); } Status DSNStepBegin(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("SPBGN"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); return Status::OK(); } Status DSEnd(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("END"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); return Status::OK(); } Status DSEpochEnd(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("EPEND"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); return Status::OK(); } Status DSNStepEnd(const CallbackParam &cb_param) override { + std::lock_guard guard(lock_); all_names_.push_back("SPEND"); all_step_nums_.push_back(cb_param.cur_step_num_); all_ep_nums_.push_back(cb_param.cur_epoch_num_); @@ -118,6 +124,7 @@ class TestCallback : public DSCallback { // name of the callback function in sequence, BGN, EPBGN, SPB, END, EPEND, SPEND std::vector all_names_; std::vector all_step_nums_, all_ep_nums_; + std::mutex lock_; }; } // namespace test @@ -298,56 +305,28 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { // config callback Status rc; std::shared_ptr tst_cb = std::make_shared(4); - std::shared_ptr cb1 = tst_cb; // turn off the epochs tst_cb->epoch_begin_ = false; tst_cb->epoch_end_ = false; - - // config leaf_op, use random_data to avoid I/O - std::shared_ptr config_manager = GlobalContext::config_manager(); - int32_t op_connector_size = config_manager->op_connector_size(); - int32_t num_workers = config_manager->num_parallel_workers(); - - std::unique_ptr schema = std::make_unique(); - TensorShape shape({}); // empty shape is a 1-value scalar Tensor - ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); - ASSERT_OK(schema->AddColumn(col)); - std::shared_ptr leaf = std::make_shared(4, op_connector_size, 4, std::move(schema)); + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("label", mindspore::DataType::kNumberTypeUInt32, {})); + std::shared_ptr ds = RandomData(4, schema); + ASSERT_NE(ds, nullptr); + ds->SetNumWorkers(1); // config mapOp - std::vector input_columns = {"label"}; - std::vector output_columns = {}; - std::vector> op_list; - std::shared_ptr my_no_op = std::make_shared(); - op_list.push_back(my_no_op); - std::shared_ptr map_op = - std::make_shared(input_columns, output_columns, std::move(op_list), num_workers, op_connector_size); - map_op->AddCallbacks({cb1}); - // config RepeatOp - std::shared_ptr repeat_op = std::make_shared(2); - // config EpochCtrlOp - std::shared_ptr epoch_ctrl_op = std::make_shared(2); - - // start build then launch tree - leaf->SetTotalRepeats(4); - leaf->SetNumRepeatsPerEpoch(2); - map_op->SetTotalRepeats(4); - map_op->SetNumRepeatsPerEpoch(2); - std::shared_ptr tree = Build({leaf, map_op, repeat_op, epoch_ctrl_op}); - rc = tree->Prepare(); - EXPECT_TRUE(rc.IsOk()); - rc = tree->Launch(); - EXPECT_TRUE(rc.IsOk()); - // Start the loop of reading tensors from our pipeline - DatasetIterator di(tree); - TensorMap tensor_map; - size_t num_epochs = 2; + ds = ds->Map({std::make_shared(mindspore::DataType::kNumberTypeUInt64)}, {"label"}, {}, {}, + nullptr, {tst_cb}); + ds->SetNumWorkers(1); + ASSERT_NE(ds, nullptr); + ds = ds->Repeat(2); + ASSERT_NE(ds, nullptr); + int32_t num_epochs = 2; + auto itr = ds->CreateIterator({}, num_epochs); for (int ep_num = 0; ep_num < num_epochs; ++ep_num) { - ASSERT_OK(di.GetNextAsMap(&tensor_map)); - EXPECT_TRUE(rc.IsOk()); - - while (tensor_map.size() != 0) { - rc = di.GetNextAsMap(&tensor_map); - EXPECT_TRUE(rc.IsOk()); + std::unordered_map row; + ASSERT_OK(itr->GetNextRow(&row)); + while (!row.empty()) { + ASSERT_OK(itr->GetNextRow(&row)); } }