diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h
index 4af185fb50e..095d36f31cb 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/python_tree_consumer.h
@@ -49,7 +49,7 @@ class PythonPullBasedIteratorConsumer : public PullBasedIteratorConsumer {
  public:
   /// Constructor which will call the base class default constructor.
   /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
-  explicit PythonPullBasedIteratorConsumer(int32_t num_epochs = -1) : PullBasedIteratorConsumer() {}
+  explicit PythonPullBasedIteratorConsumer(int32_t num_epochs = -1) : PullBasedIteratorConsumer(num_epochs) {}
   ~PythonPullBasedIteratorConsumer() = default;

   /// Returns the next row in a vector format
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
index 62b43395091..6af1c8c13a4 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
@@ -678,6 +678,12 @@ int64_t BatchOp::GetTreeBatchSize() {

 Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
   RETURN_UNEXPECTED_IF_NULL(row);
+  if (eoe_received_) {
+    UpdateRepeatAndEpochCounter();
+    *row = TensorRow(TensorRow::kFlagEOE);
+    eoe_received_ = false;
+    return Status::OK();
+  }
   std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
   child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
   int32_t cur_batch_size = 0;
@@ -685,6 +691,16 @@ Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
   for (int i = 0; i < cur_batch_size; i++) {
     TensorRow new_row;
     RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(&new_row));
+    if (new_row.eoe()) {
+      if (!drop_) {
+        eoe_received_ = true;
+      } else {
+        *row = new_row;
+        UpdateRepeatAndEpochCounter();
+        return Status::OK();
+      }
+      break;
+    }
     if (!new_row.empty()) {
       table->emplace_back(new_row);
       if (table->size() == static_cast<size_t>(cur_batch_size)) {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
index 82c7bdf1e5e..9b27c30dcb6 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
@@ -315,6 +315,9 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>, TensorRow> {
   /// \brief Gets the implementation status for operator in pull mode
   /// \return implementation status
   ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
+
+ private:
+  bool eoe_received_ = false;
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
index 971f3f8ee4d..ac2150116f7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc
@@ -79,5 +79,24 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
 }

 int64_t EpochCtrlOp::GetTreeRepeatCount() { return child_[0]->GetTreeRepeatCount(); }
+
+Status EpochCtrlOp::GetNextRowPullMode(TensorRow *const row) {
+  RETURN_UNEXPECTED_IF_NULL(row);
+  if (child_.empty()) {
+    RETURN_STATUS_UNEXPECTED("[Internal ERROR] EpochCtrlOp can't be the leaf node (first operator) of the pipeline.");
+  }
+
+  // `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
+  RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
+
+  // Only intercept EOE for EoeReceived processing; after that, the EOE is forwarded to the next op.
+  // Other TensorRows containing data or EOF are simply forwarded.
+  // EOF can be forwarded directly because this op does not spawn any threads, so no clean-up is required.
+  if (row->eoe()) {
+    RETURN_IF_NOT_OK(EoeReceived(0));
+  }
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h
index c1555b41b87..2d4d5541b2b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h
@@ -49,6 +49,11 @@ class EpochCtrlOp : public RepeatOp {
   Status EoeReceived(int32_t worker_id) override;

   int64_t GetTreeRepeatCount() override;
+
+  /// \brief In pull mode, gets the next row
+  /// \param[out] row Fetched TensorRow
+  /// \return Status The status code returned
+  Status GetNextRowPullMode(TensorRow *const row) override;
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
index 784181316d9..e5f3acd94d6 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
@@ -484,7 +484,11 @@ std::vector<int32_t> MapOp::GetMPWorkerPIDs() const {

 Status MapOp::GetNextRowPullMode(TensorRow *const row) {
   TensorRow new_row;
   RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(&new_row));
+  if (new_row.eoe()) {
+    UpdateRepeatAndEpochCounter();
+  }
   if (new_row.empty()) {
+    (*row) = std::move(new_row);
     return Status::OK();
   }
   auto column_name_id_map = child_[0]->column_name_id_map();
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
index 8e02389e814..30b5135f74b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
@@ -130,5 +130,28 @@ Status RepeatOp::Reset() {
 }

 int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }
+
+Status RepeatOp::GetNextRowPullMode(TensorRow *const row) {
+  RETURN_UNEXPECTED_IF_NULL(row);
+  if (child_.empty()) {
+    RETURN_STATUS_UNEXPECTED(
+      "[Internal ERROR] Pipeline init failed, RepeatOp can't be the leaf node (first operator) of the pipeline.");
+  }
+  RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
+  // Loop until a non-EOE row is received
+  while (row->eoe()) {
+    MS_LOG(INFO) << "RepeatOp::GetNextRowPullMode eoe received.";
+    RETURN_IF_NOT_OK(EoeReceived(0));
+    if (state_ == OpState::kDeOpIdle) {
+      return Status::OK();
+    }
+    RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
+  }
+  // Check if the last row is an EOF
+  if (row->eof()) {
+    RETURN_IF_NOT_OK(EofReceived(0));
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
index b6d4af1dc56..0a8b7a9d3a7 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
@@ -97,6 +97,11 @@ class RepeatOp : public PipelineOp {
   std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;  // List of operators that can generate EOE underneath this repeat.

+  /// \brief In pull mode, gets the next row
+  /// \param[out] row Fetched TensorRow
+  /// \return Status The status code returned
+  Status GetNextRowPullMode(TensorRow *const row) override;
+
  protected:
   // The number of repeats that the user requested.
   // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class.
@@ -107,6 +112,10 @@ class RepeatOp : public PipelineOp {
   // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class
   // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
   int32_t repeat_count_;
+
+  /// \brief Gets the implementation status for operator in pull mode
+  /// \return implementation status
+  ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
index 18cfe7195da..fc285d53cbb 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
@@ -78,6 +78,9 @@ Status SkipOp::GetNextRowPullMode(TensorRow *const row) {
   bool eoe_received = false;
   while (skip_count_ < max_skips_) {
     RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
+    if (row->eof()) {
+      return Status::OK();
+    }
     if (row->eoe() && !once_only_) {
       eoe_received = true;
       break;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc
index dd0f7dba844..f94e77a8c20 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc
@@ -22,7 +22,11 @@ namespace mindspore {
 namespace dataset {
 MappableLeafOp::MappableLeafOp(int32_t num_wkrs, int32_t queue_size, std::shared_ptr<SamplerRT> sampler)
-    : ParallelOp(num_wkrs, queue_size, std::move(sampler)), sample_ids_(nullptr), curr_row_(0), prepared_data_{false} {}
+    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
+      sample_ids_(nullptr),
+      curr_row_(0),
+      prepared_data_{false},
+      eof_handled_{false} {}

 #ifdef ENABLE_PYTHON
 Status MappableLeafOp::ImageDecrypt(const std::string &path, std::shared_ptr<Tensor> *tensor,
@@ -110,6 +114,7 @@ Status MappableLeafOp::operator()() {
 Status MappableLeafOp::Reset() {
   MS_LOG(DEBUG) << Name() << " performing a self-reset.";
   RETURN_IF_NOT_OK(sampler_->ResetSampler());
+  curr_row_ = 0;
   return Status::OK();
 }
@@ -166,15 +171,21 @@ Status MappableLeafOp::GetNextRowPullMode(TensorRow *const row) {
     RETURN_IF_NOT_OK(InitPullMode());
     prepared_data_ = true;
   }
+  if (eof_handled_) {
+    *row = TensorRow(TensorRow::kFlagEOF);
+    return Status::OK();
+  }
+  TensorRow sample_row;
   if (sample_ids_ == nullptr) {
     RETURN_IF_NOT_OK(this->InitSampler());
-    TensorRow sample_row;
     RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
     CHECK_FAIL_RETURN_UNEXPECTED(sample_row.size() > 0, "GetNextRowPullMode: Expect at least one sample in sampler.");
     sample_ids_ = sample_row[0];
   }
   if (curr_row_ + 1 > sample_ids_->Size()) {
     *row = TensorRow(TensorRow::kFlagEOE);
+    RETURN_IF_NOT_OK(ResetAndUpdateRepeat());
+    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
     return Status::OK();
   }
   int64_t key;
@@ -183,5 +194,15 @@ Status MappableLeafOp::GetNextRowPullMode(TensorRow *const row) {
   curr_row_++;
   return Status::OK();
 }
+
+Status MappableLeafOp::ResetAndUpdateRepeat() {
+  if (!IsLastIteration()) {
+    RETURN_IF_NOT_OK(Reset());
+    UpdateRepeatAndEpochCounter();
+  } else {
+    eof_handled_ = true;
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.h
index 042005d595e..2a12dec6ed2 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.h
@@ -85,6 +85,7 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, public RandomAccessOp {
   TensorPtr sample_ids_;  // sample id pointer for pull mode
   uint32_t curr_row_;     // current row number count for pull mode
   bool prepared_data_;    // flag to indicate whether the data is prepared before LoadTensorRow for pull mode
+  bool eof_handled_;      // whether this op has already handled an EOF

   /// Initialize Sampler, calls sampler->Init() within
   /// @return Status The status code returned
@@ -126,6 +127,11 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, public RandomAccessOp {
   /// \return Status The status code returned
   virtual Status LoadTensorRowPullMode(row_id_type row_id, TensorRow *row) { return LoadTensorRow(row_id, row); }

+  /// \brief Reset the op and update the repeat and epoch counters if this is not the last iteration;
+  ///     otherwise mark EOF as handled so subsequent fetches return an EOF row.
+  /// \return Status The status code returned
+  virtual Status ResetAndUpdateRepeat();
+
   /// \brief Gets the implementation status for operator in pull mode
   /// \return implementation status
   ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.cc
index f9d2fd8c86b..df838be9539 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.cc
@@ -17,6 +17,9 @@
 #include "minddata/dataset/engine/tree_adapter_lite.h"
 #include "minddata/dataset/engine/ir/datasetops/root_node.h"
 #include "minddata/dataset/engine/opt/pass.h"
+#ifndef ENABLE_ANDROID
+#include "minddata/dataset/engine/opt/post/repeat_pass.h"
+#endif
 #include "minddata/dataset/engine/opt/pre/debug_mode_pass.h"
 #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
 #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
@@ -73,6 +76,10 @@ Status TreeAdapterLite::BuildTree(std::shared_ptr<DatasetNode> root_ir) {

 Status TreeAdapterLite::GetNextRow(TensorRow *const row) {
   RETURN_UNEXPECTED_IF_NULL(root_);
   RETURN_IF_NOT_OK(root_->GetNextRowPullMode(row));
+  if (row->eof()) {
+    std::string err = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs.";
+    RETURN_STATUS_UNEXPECTED(err);
+  }
   RETURN_UNEXPECTED_IF_NULL(row);
   return Status::OK();
 }
@@ -97,6 +104,20 @@ Status TreeAdapterLite::PrePass(std::shared_ptr<DatasetNode> ir) const {
   return Status::OK();
 }

+Status TreeAdapterLite::PostPass(std::shared_ptr<DatasetNode> ir) const {
+  RETURN_UNEXPECTED_IF_NULL(ir);
+  // Vector of actions in post-pass phase
+  std::vector<std::unique_ptr<IRPass>> actions;
+#ifndef ENABLE_ANDROID
+  MS_LOG(INFO) << "Running repeat pass.";
+  (void)actions.emplace_back(std::make_unique<RepeatPass>());
+  bool modified = false;
+  RETURN_IF_NOT_OK(actions[0]->Run(ir, &modified));
+  MS_LOG(INFO) << "Repeat pass completed.";
+#endif
+  return Status::OK();
+}
+
 Status TreeAdapterLite::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs) {
   RETURN_UNEXPECTED_IF_NULL(input_ir);
   input_ir_ = input_ir;
@@ -119,6 +140,9 @@ Status TreeAdapterLite::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs) {
   // Pre-pass of the IR tree
   RETURN_IF_NOT_OK(PrePass(root_ir));
   MS_LOG(INFO) << "Plan after PrePass:" << '\n' << *root_ir << '\n';
+
+  RETURN_IF_NOT_OK(PostPass(root_ir));
+  MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';

   root_ir_ = root_ir;
   RETURN_IF_NOT_OK(BuildTree(root_ir));
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.h
index 580b7511d95..d517f3db59e 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter_lite.h
@@ -51,6 +51,9 @@ class TreeAdapterLite {
   // Run the mandatory pass checking the syntax and semantics of the IR tree
   Status PrePass(std::shared_ptr<DatasetNode> ir) const;

+  // Run the mandatory pass augmenting the IR tree
+  Status PostPass(std::shared_ptr<DatasetNode> ir) const;
+
   std::shared_ptr<DatasetNode> input_ir_;
   std::shared_ptr<DatasetNode> root_ir_;
diff --git a/tests/ut/python/dataset/test_pipeline_debug_mode.py b/tests/ut/python/dataset/test_pipeline_debug_mode.py
index 4bf742f32df..4b31891c5c1 100644
--- a/tests/ut/python/dataset/test_pipeline_debug_mode.py
+++ b/tests/ut/python/dataset/test_pipeline_debug_mode.py
@@ -24,7 +24,6 @@ from mindspore import log as logger

 # the global configuration setting of debug_mode may impact other tests running in parallel.
 pytestmark = pytest.mark.forked

-DATA_DIR_10 = "../data/dataset/testCifar10Data"
 DEBUG_MODE = False
 SEED_VAL = 0  # seed will be set internally in debug mode, save original seed value to restore.
@@ -163,17 +162,28 @@ def test_pipeline_debug_mode_concat():
     """
     logger.info("test_pipeline_debug_mode_concat")
     data_dir = "../data/dataset/testCelebAData/"
+    num_repeat = 3
     data1 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
     data2 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
     data3 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
     data4 = data1.concat(data2)
     data5 = data3 + data4
-    num_rows = 0
-    for item1 in data5.create_tuple_iterator(num_epochs=1):
-        assert len(item1) == 2
-        assert item1[0].shape == (2268, 4032, 3)
-        num_rows += 1
-    assert num_rows == 12
+    data5 = data5.repeat(num_repeat)
+    num_epoch = 2
+    epoch_count = 0
+    sample_row = 12
+    sample_count = 0
+    for _ in range(num_epoch):
+        num_rows = 0
+        for item1 in data5.create_tuple_iterator(num_epochs=1):
+            assert len(item1) == 2
+            assert item1[0].shape == (2268, 4032, 3)
+            num_rows += 1
+        epoch_count += 1
+        sample_count += num_rows
+        assert num_rows == sample_row * num_repeat
+    assert epoch_count == num_epoch
+    assert sample_count == num_repeat * num_epoch * sample_row


 def test_pipeline_debug_mode_map_random():
@@ -223,28 +233,63 @@ def test_pipeline_debug_mode_imdb_shuffle():
     Expectation: The data is processed successfully in the same order.
     """
     logger.info("test_pipeline_debug_mode_imdb_shuffle")
-    buffer_size = 5

     # apply dataset operations
     data1 = ds.IMDBDataset("../data/dataset/testIMDBDataset", shuffle=True)
-    data1 = data1.shuffle(buffer_size=buffer_size)

     # Verify dataset size
     data1_size = data1.get_dataset_size()
     logger.info("dataset size is: {}".format(data1_size))
     assert data1_size == 8
-
+    expect_output = [["train_pos_1.txt", 1], ["train_pos_0.txt", 1], ["train_neg_0.txt", 0], ["test_pos_1.txt", 1],
+                     ["test_neg_1.txt", 0], ["test_pos_0.txt", 1], ["test_neg_0.txt", 0], ["train_neg_1.txt", 0]]
     num_iter = 0
     for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
         # each data is a dictionary
         # in this example, each dictionary has keys "text" and "label"
         logger.info("text is {}".format(item["text"]))
+        assert item["text"] == expect_output[num_iter][0]
         logger.info("label is {}".format(item["label"]))
+        assert item["label"] == expect_output[num_iter][1]
         num_iter += 1
     logger.info("Number of data in data1: {}".format(num_iter))
     assert num_iter == 8


+def test_pipeline_debug_mode_multi_epoch_map_pyfunc():
+    """
+    Feature: Pipeline debug mode.
+    Description: Test creating dict iterator with map(PyFunc) with num_epochs > 1.
+    Expectation: Successful.
+    """
+    logger.info("test_pipeline_debug_mode_multi_epoch_map_pyfunc")
+    data = ds.CelebADataset("../data/dataset/testCelebAData/", sampler=ds.SequentialSampler(),
+                            decode=True)
+    num_repeat = 5
+    sample_row = 4
+    data = data.repeat(num_repeat)
+    data = data.map(operations=[(lambda x: x - 1), (lambda x: x * 2)], input_columns=["image"])
+    num_epoch = 7
+    epoch_count = 0
+    sample_count = 0
+    iter1 = data.create_dict_iterator(num_epochs=num_epoch)
+    for _ in range(num_epoch):
+        num_rows = 0
+        for item in iter1:
+            assert len(item) == 2
+            assert item["image"].shape == (2268, 4032, 3)
+            num_rows += 1
+        assert num_rows == sample_row * num_repeat
+        sample_count += num_rows
+        epoch_count += 1
+    assert epoch_count == num_epoch
+    assert sample_count == num_repeat * num_epoch * sample_row
+
+    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
+    with pytest.raises(RuntimeError, match=err_msg):
+        iter1.__next__()
+
+
 if __name__ == '__main__':
     setup_function()
     test_pipeline_debug_mode_tuple()
@@ -257,4 +302,5 @@ if __name__ == '__main__':
     test_pipeline_debug_mode_shuffle()
     test_pipeline_debug_mode_map_random()
     test_pipeline_debug_mode_imdb_shuffle()
+    test_pipeline_debug_mode_multi_epoch_map_pyfunc()
     teardown_function()
diff --git a/tests/ut/python/dataset/test_pipeline_debug_mode_cifar.py b/tests/ut/python/dataset/test_pipeline_debug_mode_cifar.py
index 562328da815..e8917886091 100644
--- a/tests/ut/python/dataset/test_pipeline_debug_mode_cifar.py
+++ b/tests/ut/python/dataset/test_pipeline_debug_mode_cifar.py
@@ -611,6 +611,40 @@ def test_cifar100ops():
     assert "Input count is not within the required interval of" in str(error_info.value)


+def test_pipeline_debug_mode_multi_epoch_cifar10():
+    """
+    Feature: Pipeline debug mode.
+    Description: Test creating tuple iterator in cifar10 dataset with multi epochs.
+    Expectation: Output is equal to the expected output.
+    """
+    logger.info("test_pipeline_debug_mode_multi_epoch_cifar10")
+    data_dir_10 = "../data/dataset/testCifar10Data"
+    num_repeat = 2
+    batch_size = 32
+    limit_dataset = 100
+    # apply dataset operations
+    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
+    data1 = data1.repeat(num_repeat)
+    data1 = data1.batch(batch_size, True)
+    num_epoch = 5
+    iter1 = data1.create_tuple_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    sample_count = 0
+    for _ in range(num_epoch):
+        row_count = 0
+        for _ in iter1:
+            # in this example, each row has columns "image" and "label"
+            row_count += 1
+        assert row_count == int(limit_dataset * num_repeat / batch_size)
+        logger.debug("row_count: {}".format(row_count))
+        epoch_count += 1
+        sample_count += row_count
+    assert epoch_count == num_epoch
+    logger.debug("total epochs: {}".format(epoch_count))
+    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
+    logger.debug("total sample: {}".format(sample_count))
+
+
 if __name__ == '__main__':
     setup_function()
     test_cifar10_content_check()
@@ -629,4 +663,5 @@ if __name__ == '__main__':
     test_cifar10_with_chained_sampler_get_dataset_size()
     test_cifar10_pk_sampler_get_dataset_size()
     test_cifar100ops()
+    test_pipeline_debug_mode_multi_epoch_cifar10()
     teardown_function()
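Editor's note: the sketch below is not part of the patch. It illustrates the multi-epoch pull-mode behavior the changes above enable, mirroring test_pipeline_debug_mode_multi_epoch_cifar10. The dataset path, num_samples value, and the call to ds.config.set_debug_mode() are assumptions taken from the test files (set_debug_mode is what the tests' setup_function uses to switch the pipeline to pull-based execution); adjust them for your environment.

    # Hedged usage sketch: multi-epoch iteration under dataset pipeline debug mode.
    import mindspore.dataset as ds

    ds.config.set_debug_mode(True)  # debug mode drives the tree via GetNextRowPullMode

    data = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=100)
    data = data.repeat(2)        # handled by the new RepeatOp::GetNextRowPullMode
    data = data.batch(32, True)  # drop_remainder=True; BatchOp now tracks EOE in pull mode

    num_epoch = 5
    iter1 = data.create_tuple_iterator(num_epochs=num_epoch)
    for _ in range(num_epoch):
        for _ in iter1:  # 100 samples x 2 repeats / 32 per batch -> 6 batches per epoch
            pass

    # Fetching past num_epochs now fails fast in TreeAdapterLite::GetNextRow with:
    # RuntimeError: EOF buffer encountered. User tries to fetch data beyond the
    # specified number of epochs.

With batch(32, True), the 200 repeated rows yield int(200 / 32) = 6 batches per epoch, and the trailing 8 rows are dropped each epoch, which matches the row-count assertion in the cifar10 test.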