From b5b333726f40813103147d78ad8e9add8e1fba23 Mon Sep 17 00:00:00 2001 From: Lixia Chen Date: Tue, 4 Aug 2020 10:10:18 -0400 Subject: [PATCH] Leaf ops do self-reset. --- .../datasetops/bucket_batch_by_length_op.cc | 10 -- .../datasetops/bucket_batch_by_length_op.h | 4 - .../engine/datasetops/cache_base_op.cc | 22 ++- .../dataset/engine/datasetops/cache_base_op.h | 3 - .../dataset/engine/datasetops/dataset_op.cc | 1 + .../engine/datasetops/epoch_ctrl_op.cc | 17 +-- .../dataset/engine/datasetops/filter_op.cc | 2 +- .../dataset/engine/datasetops/map_op/map_op.h | 6 - .../dataset/engine/datasetops/parallel_op.cc | 14 +- .../dataset/engine/datasetops/parallel_op.h | 20 ++- .../dataset/engine/datasetops/repeat_op.cc | 34 +---- .../dataset/engine/datasetops/repeat_op.h | 9 -- .../engine/datasetops/source/album_op.cc | 26 +++- .../engine/datasetops/source/album_op.h | 3 - .../engine/datasetops/source/celeba_op.cc | 26 +++- .../engine/datasetops/source/celeba_op.h | 2 - .../engine/datasetops/source/cifar_op.cc | 26 +++- .../engine/datasetops/source/cifar_op.h | 4 +- .../engine/datasetops/source/clue_op.cc | 3 + .../engine/datasetops/source/clue_op.h | 1 - .../engine/datasetops/source/coco_op.cc | 24 ++- .../engine/datasetops/source/coco_op.h | 3 - .../engine/datasetops/source/csv_op.cc | 3 + .../engine/datasetops/source/generator_op.cc | 12 +- .../engine/datasetops/source/generator_op.h | 2 - .../datasetops/source/image_folder_op.cc | 26 +++- .../datasetops/source/image_folder_op.h | 3 - .../engine/datasetops/source/io_block.h | 9 +- .../engine/datasetops/source/manifest_op.cc | 24 ++- .../engine/datasetops/source/manifest_op.h | 3 - .../engine/datasetops/source/mindrecord_op.cc | 51 ++++--- .../engine/datasetops/source/mindrecord_op.h | 3 - .../engine/datasetops/source/mnist_op.cc | 24 ++- .../engine/datasetops/source/mnist_op.h | 3 - .../datasetops/source/random_data_op.cc | 11 +- .../engine/datasetops/source/text_file_op.cc | 3 + .../engine/datasetops/source/text_file_op.h | 1 - .../engine/datasetops/source/tf_reader_op.cc | 3 + .../engine/datasetops/source/voc_op.cc | 24 ++- .../dataset/engine/datasetops/source/voc_op.h | 3 - .../dataset/engine/opt/post/repeat_pass.cc | 137 +++--------------- .../dataset/engine/opt/post/repeat_pass.h | 24 +-- .../data/dataset/testPyfuncMap/pyfuncmap.py | 10 +- .../dataset/test_bucket_batch_by_length.py | 29 ++++ .../python/dataset/test_datasets_cifarop.py | 2 +- .../ut/python/dataset/test_datasets_mnist.py | 2 +- tests/ut/python/dataset/test_filterop.py | 2 +- tests/ut/python/dataset/test_paddeddataset.py | 2 +- 48 files changed, 340 insertions(+), 336 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc index c23537c6f3d..5963619de97 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -218,16 +218,6 @@ Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t ba return Status::OK(); } -Status BucketBatchByLengthOp::Reset() { - batch_count_ = 0; - - for (int i = 0; i < buckets_.size(); i++) { - buckets_[i] = std::make_unique(); - } - - return Status::OK(); -} - // Computing the assignment of the column name map and check compute input columns. 
Status BucketBatchByLengthOp::ComputeColMap() { RETURN_IF_NOT_OK(DatasetOp::ComputeColMap()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h index 21fc55e2635..3fd446322b1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -126,10 +126,6 @@ class BucketBatchByLengthOp : public PipelineOp { // @return Status - The error code returned Status operator()() override; - // Function that is called by ResetOp at the end of every epoch - // @return Status - The error code returned - Status Reset() override; - private: Status ObtainElementLength(int32_t *out_element_length, TensorRow element); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index 554eb2b19b3..533de9521d5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -42,8 +42,7 @@ Status CacheBase::Reset() { RETURN_IF_NOT_OK(sampler_->ResetSampler()); } // Wake up the workers to get them going again in a new epoch - MS_LOG(DEBUG) << Name() << " resetting."; - epoch_sync_.Set(); + MS_LOG(DEBUG) << Name() << " performing a self-reset."; return Status::OK(); } CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, @@ -72,7 +71,6 @@ Status CacheBase::FetchSamplesToWorkers() { // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them // to the WorkerEntry. do { - epoch_sync_.Clear(); if (AllowCacheMiss() && wait_cnt > 0) { MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_ << " Total number of rows : " << row_cnt_; @@ -112,11 +110,17 @@ Status CacheBase::FetchSamplesToWorkers() { // If repeat but the not last repeat, wait for reset. if (!IsLastIteration()) { MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; - RETURN_IF_NOT_OK(epoch_sync_.Wait()); } else { // We can break out from the loop. break; } + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); UpdateRepeatAndEpochCounter(); } while (true); // Flow the eof before exit @@ -142,7 +146,13 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { std::unique_ptr blk; do { RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); - if (blk->eof()) { + if (blk->wait()) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
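+ // (The master-thread side of this handshake is the new ParallelOp::WaitForWorkers() in parallel_op.cc below: + // it queues one wait-flagged IOBlock per worker and then sleeps on wait_for_workers_post_ until the last + // paused worker posts it.)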
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (blk->eof()) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); } else if (blk->eoe()) { if (AllowCacheMiss()) { @@ -186,7 +196,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { } Status CacheBase::RegisterResources() { - RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index 2225d4f3350..333d4e190a7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -26,7 +26,6 @@ #include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "minddata/dataset/util/queue.h" @@ -88,7 +87,6 @@ class CacheBase : public ParallelOp { int64_t row_cnt_; std::atomic num_cache_miss_; std::shared_ptr cache_client_; - WaitPost epoch_sync_; int32_t rows_per_buffer_; Connector> keys_miss_; QueueMap prefetch_; @@ -110,7 +108,6 @@ class CacheBase : public ParallelOp { private: constexpr static int32_t connector_capacity_ = 1024; int32_t prefetch_size_; - QueueList> io_block_queues_; QueueList> prefetch_queues_; std::unique_ptr>> sampler_queue_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index 1b877779274..c9357c9ff0a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -434,6 +434,7 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { void DatasetOp::UpdateRepeatAndEpochCounter() { op_current_repeats_++; if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; + MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc index 1343fd46086..5ec26311bff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc @@ -51,15 +51,7 @@ void EpochCtrlOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_ - << "\nLeaf Nodes in execution path:"; - if (!eoe_ops_.empty()) { - for (size_t i = 0; i < eoe_ops_.size(); i++) { - out << "\n Operator: " << eoe_ops_[i]->id(); - } - } else { - out << " None."; - } + out << "\nCurrent 
epoch count: " << repeat_count_ << "\nMax epoch count: " << num_repeats_; out << "\n\n"; } } @@ -94,13 +86,6 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) { // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it. state_ = OpState::kDeOpIdle; - if (repeat_count_ != num_repeats_) { - for (auto &eoe_op : eoe_ops_) { - MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id(); - RETURN_IF_NOT_OK(eoe_op->Reset()); - } - } - return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc index dd672468b11..3e931bf2457 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -123,7 +123,6 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); if (in_buffer->eoe()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); - UpdateRepeatAndEpochCounter(); continue; } else if (in_buffer->eof()) { filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); @@ -200,6 +199,7 @@ Status FilterOp::Collector() { RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || in_pair.second == filterCtrl::kFilterEoe) { + if (in_pair.second == filterCtrl::kFilterEoe) UpdateRepeatAndEpochCounter(); uint32_t out_task_id = out_id_cnt % num_workers_; RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); out_id_cnt++; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h index de99a2587ea..6384b7ec9cd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h @@ -228,12 +228,6 @@ class MapOp : public ParallelOp { // Indices of the columns to process. std::vector to_process_indices_; - // Wait post used to perform the pausing logic in MapOp - WaitPost wait_for_workers_post_; - - // Count number of workers that have signaled master - std::atomic_int num_workers_paused_; - // Private function for worker/thread to loop continuously. 
It comprises the main // logic of MapOp: getting the data from previous Op, validating user specified column names, applying a list of TensorOps to each of the data, process the results and then diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc index abb827aea85..4463b72924e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc @@ -31,7 +31,9 @@ ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shar num_workers_(num_workers), num_producers_(num_workers), worker_connector_size_(1), - worker_connector_(nullptr) {} + worker_connector_(nullptr), + num_workers_paused_(0), + epoch_sync_flag_(false) {} // Creates the internal worker connector for the parallel op if the derived class wants to use it Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { @@ -82,5 +84,15 @@ Status ParallelOp::RegisterWorkerConnectors() { } return Status::OK(); } + +Status ParallelOp::WaitForWorkers() { + num_workers_paused_ = 0; + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK(io_block_queues_[i]->Add(std::make_unique(IOBlock::kDeIoBlockFlagWait))); + } + RETURN_IF_NOT_OK(wait_for_workers_post_.Wait()); + wait_for_workers_post_.Clear(); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h index 8d7ba6302ae..e09ef52e2cd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h @@ -21,6 +21,7 @@ #include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/util/status.h" namespace mindspore { @@ -117,10 +118,27 @@ class ParallelOp : public DatasetOp { // @return Status - The error code return virtual Status WorkerEntry(int32_t workerId) = 0; + /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp. + /// The expected behavior is this: when this function is invoked, it will block until all the workers + /// have finished their remaining work and gone to sleep. Since all ParallelOps use a QueueList to sync with the + /// master, they automatically wait on the QueueList when they are done.
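+ /// As a rough sketch, the two halves of the handshake (this function plus the wait-block handling that each + /// WorkerEntry gains in this patch) look like: + ///   master: num_workers_paused_ = 0; queue one IOBlock(kDeIoBlockFlagWait) per worker; + ///           wait_for_workers_post_.Wait();   // sleep until the last worker arrives + ///           wait_for_workers_post_.Clear(); + ///   worker: on a wait-flagged IOBlock: + ///           if (++num_workers_paused_ == num_workers_) wait_for_workers_post_.Set();  // last one wakes the master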
+ /// \return Status + Status WaitForWorkers() override; + + // Wait post used to perform the pausing logic + WaitPost wait_for_workers_post_; + + // Count number of workers that have signaled master + std::atomic_int num_workers_paused_; + + // Whether or not to sync worker threads at the end of each epoch + bool epoch_sync_flag_; + int32_t num_workers_; // The number of worker threads int32_t num_producers_; // The number of threads pushing to the out_connector_ int32_t worker_connector_size_; - std::unique_ptr worker_connector_; // The internal connector for worker threads + std::unique_ptr worker_connector_; // The internal connector for worker threads + QueueList> io_block_queues_; // queues of IOBlocks }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc index 0d512b19541..2e50c4d992a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -62,15 +62,7 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info PipelineOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_ - << "\nLeaf Nodes in execution path:"; - if (!eoe_ops_.empty()) { - for (size_t i = 0; i < eoe_ops_.size(); i++) { - out << "\n Operator: " << eoe_ops_[i]->id(); - } - } else { - out << " None."; - } + out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << num_repeats_; out << "\n\n"; } } @@ -108,7 +100,6 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t wo // Base-class override for handling cases when an eoe is received. Status RepeatOp::EoeReceived(int32_t worker_id) { UpdateRepeatAndEpochCounter(); - repeat_count_++; MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; @@ -116,15 +107,9 @@ Status RepeatOp::EoeReceived(int32_t worker_id) { if (repeat_count_ == num_repeats_) { repeat_count_ = 0; state_ = OpState::kDeOpIdle; - return Status::OK(); + } else { + state_ = OpState::kDeOpRunning; } - - // Invoke a reset against the eoe nodes only. - for (auto &eoe_op : eoe_ops_) { - MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id(); - RETURN_IF_NOT_OK(eoe_op->Reset()); - } - return Status::OK(); } @@ -153,19 +138,6 @@ int32_t RepeatOp::num_consumers() const { } } -// Drive reset actions if needed -Status RepeatOp::Reset() { - // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. - // In that case, we now have to bounce the reset down to our own eoe ops. - MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset."; - for (auto &eoe_op : eoe_ops_) { - MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id(); - RETURN_IF_NOT_OK(eoe_op->Reset()); - } - state_ = OpState::kDeOpRunning; - return Status::OK(); -} - int32_t RepeatOp::num_producers() const { if (child_.empty() || child_[0] == nullptr) { MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. 
Returning 0."; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h index bdd49535418..d2af23976e4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -101,10 +101,6 @@ class RepeatOp : public PipelineOp { // @param worker_id - The worker id Status EofReceived(int32_t worker_id) override; - /// \brief reset Op - /// \@return Status - The error code return - Status Reset() override; - // Base-class override. Return the number of workers in the first parent. // @param workerId - The worker id int32_t num_consumers() const override; @@ -133,10 +129,6 @@ class RepeatOp : public PipelineOp { /// \return The number of repeats that the user requested int32_t num_repeats() { return num_repeats_; } - // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes - // \param[in] eoe_op The input leaf/eoe operator to add to the list - void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } - protected: // The number of repeats that the user requested. // Note that num_repeats_ is different with op_total_repeats_ or op_num_repeats_per_epoch_ in base DatasetOp class. @@ -147,7 +139,6 @@ class RepeatOp : public PipelineOp { // Note that repeat_count_ is different with op_current_repeats_ in the base DatasetOp class // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats. int32_t repeat_count_; - std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc index efc6acece78..559cb1457f1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.cc @@ -161,11 +161,19 @@ Status AlbumOp::operator()() { io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); } return Status::OK(); - } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset + } else { // not the last repeat. RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -180,7 +188,13 @@ Status AlbumOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -468,9 +482,9 @@ void AlbumOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status AlbumOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } @@ -486,7 +500,7 @@ Status AlbumOp::LaunchThreadsAndInitOp() { } // registers QueueList and individual Queues for interrupt services RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); // launch main workers that load DataBuffers by reading all images RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1))); TaskManager::FindMe()->Post(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h index 3ef4e7bf894..d16a3eb2d65 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h @@ -30,7 +30,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/queue.h" @@ -289,9 +288,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { int64_t buf_cnt_; int64_t sampler_ind_; int64_t dirname_offset_; - WaitPost wp_; std::vector image_rows_; - QueueList> io_block_queues_; // queues of IOBlocks }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index f4a2f24cb44..8ff2cf5bf62 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -94,7 +94,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() { RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); @@ -311,11 +311,19 @@ Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); } return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + } else { // not the last repeat. 
RETURN_IF_NOT_OK( io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); } UpdateRepeatAndEpochCounter(); @@ -328,7 +336,13 @@ Status CelebAOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. + if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -410,8 +424,8 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status CelebAOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h index d0cf412a5ca..4d50bb16fd3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h @@ -229,8 +229,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { std::unique_ptr data_schema_; std::unique_ptr>> attr_info_queue_; int64_t num_rows_in_attr_file_; // rows number specified in attr file - QueueList> io_block_queues_; - WaitPost wp_; std::vector>> image_labels_vec_; std::string usage_; std::ifstream partition_file_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index c36b9ca150b..6a32f2c17db 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -140,11 +140,19 @@ Status CifarOp::operator()() { io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); } return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + } else { // not the last repeat. RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. 
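+ // (epoch_sync_flag_ defaults to false in ParallelOp; only ops that need a hard barrier here, such as + // MindRecordOp, turn it on.)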
+ RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -156,7 +164,7 @@ Status CifarOp::LaunchThreadsAndInitOp() { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK( tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); @@ -175,7 +183,13 @@ Status CifarOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. + if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -243,9 +257,9 @@ void CifarOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status CifarOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h index 39f3d76f41b..882b688fd50 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -26,7 +26,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/queue.h" @@ -233,11 +232,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp { int32_t rows_per_buffer_; std::string folder_path_; std::unique_ptr data_schema_; + int64_t row_cnt_; int64_t buf_cnt_; const std::string usage_; // can only be either "train" or "test" - WaitPost wp_; - QueueList> io_block_queues_; std::unique_ptr>> cifar_raw_data_block_; std::vector cifar_files_; std::vector, std::vector>> cifar_image_label_pairs_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 5fd861ebc8f..72983c73b46 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -119,6 +119,7 @@ Status ClueOp::Init() { } Status ClueOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a 
self-reset."; load_jagged_connector_ = true; load_io_block_queue_ = true; @@ -274,6 +275,8 @@ Status ClueOp::operator()() { } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); } UpdateRepeatAndEpochCounter(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h index eea6b5aa7ec..47e50044e2c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -25,7 +25,6 @@ #include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index 2c48a227f9c..f3c77bce450 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -185,8 +185,16 @@ Status CocoOp::operator()() { } else { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -208,9 +216,9 @@ void CocoOp::Print(std::ostream &out, bool show_all) const { } Status CocoOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); return Status::OK(); } @@ -377,7 +385,13 @@ Status CocoOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -609,7 +623,7 @@ Status CocoOp::LaunchThreadsAndInitOp() { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(this->ParseAnnotationIds()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h index ff18c6b8c2c..d7f923ff3db 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -27,7 +27,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/util/path.h" @@ -327,10 +326,8 @@ class CocoOp : public ParallelOp, public RandomAccessOp { std::shared_ptr sampler_; std::unique_ptr data_schema_; - WaitPost wp_; std::vector image_ids_; std::map image_index_; - QueueList> io_block_queues_; std::vector>> label_index_; std::map coordinate_map_; std::map> simple_item_map_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 67dd7cacf54..95983ca0693 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -479,6 +479,7 @@ Status CsvOp::CsvParser::InitCsvParser() { } Status CsvOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; load_jagged_connector_ = true; load_io_block_queue_ = true; @@ -572,6 +573,8 @@ Status CsvOp::operator()() { } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); } UpdateRepeatAndEpochCounter(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 0dee92730f3..2e9ddddc2e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -186,7 +186,6 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { Status GeneratorOp::operator()() { // Handshake with TaskManager to synchronize thread creation TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); std::unique_ptr fetched_buffer; bool eof = false; while (!eof) { @@ -228,12 +227,8 @@ Status GeneratorOp::operator()() { MS_LOG(DEBUG) << "Generator operator main execution loop complete."; eof = true; } else { - // Waiting for repeatOp to start new epoch - // If Reset() is called first by repeat op, this wait() will return 
right away. - // If Reset() is not called yet, this wait() will block until reset. - RETURN_IF_NOT_OK(wp_.Wait()); - // Clear the status of the wait post - wp_.Clear(); + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); } UpdateRepeatAndEpochCounter(); } @@ -243,9 +238,8 @@ Status GeneratorOp::operator()() { Status GeneratorOp::Reset() { // Reset Op state + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(this->Init()); - // Wake up master thread - wp_.Set(); return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index 1d7f2b97f36..175d1ce680f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -144,8 +144,6 @@ class GeneratorOp : public PipelineOp { py::object generator_; int32_t buffer_id_; - WaitPost wp_; - Status Init(); void Dealloc() noexcept; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 32fa5eceebe..43121d2a292 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -164,11 +164,19 @@ Status ImageFolderOp::operator()() { io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); } return Status::OK(); - } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset + } else { // not the last repeat. RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -183,7 +191,13 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -247,9 +261,9 @@ void ImageFolderOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status ImageFolderOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } @@ -365,7 +379,7 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() { RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); // The following code launch 3 threads group // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. // 2) Workers that pull foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h index 219a4e53f9c..53e7357e1ac 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -29,7 +29,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/util/path.h" @@ -263,9 +262,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { int64_t buf_cnt_; int64_t sampler_ind_; int64_t dirname_offset_; - WaitPost wp_; std::vector image_label_pairs_; - QueueList> io_block_queues_; // queues of IOBlocks std::unique_ptr> folder_name_queue_; std::unique_ptr> image_name_queue_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h index 86ae3c9c563..1f3c0ca19da 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h @@ -33,8 +33,9 @@ class IOBlock { public: enum IOBlockFlags : uint32_t { kDeIoBlockNone = 0, - kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch - kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program + kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch + kDeIoBlockFlagEof = 1u << 1, // end of IOBlocks for entire program + kDeIoBlockFlagWait = 1u << 2 // control signal for workers to suspend operations }; // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. @@ -73,6 +74,10 @@ class IOBlock { // @return T/F if the IOBlock is eof bool eof() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEof); } + // Does this block have the wait flag turned on? 
+ // @return T/F if the IOBlock is wait + bool wait() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagWait); } + // Adds a key to this block // @param key - The key to add to this block void AddKey(int64_t key) { index_keys_.push_back(key); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 0f36186e7b5..6b675d0112b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -127,8 +127,16 @@ Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { } else { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); } UpdateRepeatAndEpochCounter(); } Status ManifestOp::LaunchThreadsAndInitOp() { if (tree_ == nullptr) { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK( tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); @@ -159,7 +167,13 @@ Status ManifestOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -235,9 +249,9 @@ void ManifestOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status ManifestOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h index dbe7a28e94f..2a022868d28 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -26,7 +26,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/util/queue.h" @@ -242,8 +241,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { std::string usage_; int64_t buf_cnt_; - WaitPost wp_; - QueueList> io_block_queues_; std::map label_index_; std::vector>> image_labelname_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index 15fe89b69b5..6ba6a132f44 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -129,7 +129,9 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf num_padded_(num_padded), sample_json_(sample_json), sample_bytes_(sample_bytes) { - io_blk_queues_.Init(num_workers_, op_connector_queue_size); + io_block_queues_.Init(num_workers_, op_connector_queue_size); + epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all + // tasks are consumed by the worker threads would cause problem. } // Private helper method to encapsulate some common construction/reset tasks @@ -219,18 +221,27 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { Status MindRecordOp::WorkerEntry(int32_t worker_id) { TaskManager::FindMe()->Post(); std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { + if (io_block->wait()) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + continue; + } if (io_block->eoe()) { RETURN_IF_NOT_OK( out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); continue; } if (io_block->eof()) { RETURN_IF_NOT_OK( out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); continue; } @@ -255,7 +266,7 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) { } RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); } RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker."); } @@ -377,27 +388,31 @@ Status MindRecordOp::operator()() { while (true) { // each iterator is 1 epoch for (int32_t i = 0; i < buffers_needed_; ++i) { std::vector keys(1, i); - RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); } if (IsLastIteration()) { RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( + RETURN_IF_NOT_OK(io_block_queues_[i]->Add( std::move(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone)))); } return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + } else { RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - - // reset our buffer count and go to loop again. - RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); - shard_reader_wait_post_.Clear(); + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) RETURN_IF_NOT_OK(Reset()); UpdateRepeatAndEpochCounter(); } } @@ -406,10 +421,10 @@ Status MindRecordOp::operator()() { // info from it's previous execution and then initializes itself so that it can be executed // again. Status MindRecordOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. 
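// Note: reshuffling here is safe because MindRecordOp constructs with epoch_sync_flag_ = true, so WaitForWorkers() // has already paused every worker before this self-reset runs.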
shard_reader_->ShuffleTask(); - shard_reader_wait_post_.Set(); return Status::OK(); } @@ -419,8 +434,8 @@ Status MindRecordOp::LaunchThreadAndInitOp() { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } - RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); if (shard_reader_->Launch(true) == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h index 939e48b6166..dae29f5541a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -29,7 +29,6 @@ #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/status.h" #include "minddata/mindrecord/include/shard_column.h" @@ -247,8 +246,6 @@ class MindRecordOp : public ParallelOp { std::vector columns_blob_index_; // Blob Columns to load from dataset std::unique_ptr shard_reader_; - WaitPost shard_reader_wait_post_; - QueueList> io_blk_queues_; std::mutex ended_worker_mutex_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index 24d8635eb5a..1e282896bf2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -135,8 +135,16 @@ Status MnistOp::operator()() { } else { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -150,7 +158,13 @@ Status MnistOp::WorkerEntry(int32_t worker_id) { std::unique_ptr iOBlock; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); while (iOBlock != nullptr) { - if (iOBlock->eoe() == true) { + if (iOBlock->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread. 
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (iOBlock->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (iOBlock->eof() == true) { @@ -208,9 +222,9 @@ void MnistOp::Print(std::ostream &out, bool show_all) const { // Reset Sampler and wakeup Master thread (functor) Status MnistOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done return Status::OK(); } @@ -401,7 +415,7 @@ Status MnistOp::LaunchThreadsAndInitOp() { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1))); TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(this->WalkAllFiles()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index 91f4c2feb00..db9a9587b01 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -27,7 +27,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/queue.h" @@ -245,7 +244,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { int64_t buf_cnt_; int64_t row_cnt_; - WaitPost wp_; std::string folder_path_; // directory of image folder int32_t rows_per_buffer_; const std::string usage_; // can only be either "train" or "test" @@ -253,7 +251,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { std::vector image_label_pairs_; std::vector image_names_; std::vector label_names_; - QueueList> io_block_queues_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 7f0c56dbeb8..381e4ed9bb9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -239,10 +239,15 @@ Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { } } - // Wait for the reset to wake us up if we're not quitting if (!(*quitting)) { MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; - RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); + if (last_guy_in) { + // If we are the last worker, do reset to wake other workers up + RETURN_IF_NOT_OK(Reset()); + } else { + // If we are not the last worker, wait for the reset + RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); + } prev = guys_out_.fetch_add(1); bool last_guy_out = (prev + 1) == num_workers_; // Last guy out will clear the wait post and set the row counts @@ -365,7 +370,7 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow 
*new_row) { // info from its previous execution and then initializes itself so that it can be executed // again. Status RandomDataOp::Reset() { - MS_LOG(INFO) << "RandomDataOp resetting."; + MS_LOG(DEBUG) << Name() << " performing a self-reset."; // Ensure all guys are in the waitpost if (guys_in_ != num_workers_) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index 3d9b79a7dbf..7bc189eaf3a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -136,6 +136,7 @@ Status TextFileOp::Init() { } Status TextFileOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; load_jagged_connector_ = true; load_io_block_queue_ = true; @@ -432,6 +433,8 @@ Status TextFileOp::operator()() { } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); } UpdateRepeatAndEpochCounter(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h index 9dfb4ac2ae6..498dea08a79 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -27,7 +27,6 @@ #include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/wait_post.h" #include "minddata/dataset/engine/jagged_connector.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index ffc2a97ef7c..b112eb4138a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -317,6 +317,8 @@ Status TFReaderOp::operator()() { } else { jagged_buffer_connector_->DoReset(); buffer_id = 0; + // Self-reset to start a new iteration + RETURN_IF_NOT_OK(Reset()); } UpdateRepeatAndEpochCounter(); } @@ -709,6 +711,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table // Overrides base class reset method. Cleans up any state info from its previous execution and // reinitializes itself so that it can be executed again, as if it was just created.
Status TFReaderOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true load_jagged_connector_ = true; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index a12bddb57be..179b92d5bf9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -164,8 +164,16 @@ Status VOCOp::operator()() { } else { RETURN_IF_NOT_OK( io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); + } + + if (epoch_sync_flag_) { + // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for + // the current epoch. + RETURN_IF_NOT_OK(WaitForWorkers()); + } + // If not the last repeat, self-reset and go to loop again. + if (!IsLastIteration()) { + RETURN_IF_NOT_OK(Reset()); RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); } UpdateRepeatAndEpochCounter(); @@ -187,9 +195,9 @@ void VOCOp::Print(std::ostream &out, bool show_all) const { } Status VOCOp::Reset() { + MS_LOG(DEBUG) << Name() << " performing a self-reset."; RETURN_IF_NOT_OK(sampler_->ResetSampler()); row_cnt_ = 0; - wp_.Set(); return Status::OK(); } @@ -235,7 +243,13 @@ Status VOCOp::WorkerEntry(int32_t worker_id) { std::unique_ptr io_block; RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); while (io_block != nullptr) { - if (io_block->eoe() == true) { + if (io_block->wait() == true) { + // Sync io_block is a signal that master thread wants us to pause and sync with other workers. + // The last guy who comes to this sync point should reset the counter and wake up the master thread.
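+          // Master-side counterpart (sketch of VOCOp::operator() above): after queueing the EOE
+          // IOBlock, WaitForWorkers() sleeps when epoch_sync_flag_ is set; once released, the op
+          // self-resets via Reset() and resamples instead of waiting for an external reset.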
+ if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + } else if (io_block->eoe() == true) { RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); buffer_id = worker_id; } else if (io_block->eof() == true) { @@ -367,7 +381,7 @@ Status VOCOp::LaunchThreadsAndInitOp() { RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); } RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(this->ParseImageIds()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h index de824f60500..dca7a8e793d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -26,7 +26,6 @@ #include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" -#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/util/path.h" @@ -283,9 +282,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { int32_t rows_per_buffer_; std::unique_ptr data_schema_; - WaitPost wp_; std::vector image_ids_; - QueueList> io_block_queues_; std::map class_index_; std::map label_index_; std::map annotation_map_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc index d737a0fa1bc..83bf51404e9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -27,25 +27,10 @@ namespace mindspore { namespace dataset { RepeatPass::RepeatPass() - : is_repeated_(false), - nested_repeats_(0), - num_repeats_(1), - num_epochs_(1), - is_merge_(false), - is_cached_(false), - cache_lookup_(nullptr) {} + : num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {} // Identifies the subtree below this node as being in a repeated path of the tree. Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Create a new stack for eoe operators and push onto our stack of stacks. - std::unique_ptr new_stack = std::make_unique(); - eoe_op_stacks_.push(std::move(new_stack)); - // If we are already repeated, then this is a nested repeat. - if (is_repeated_) { - nested_repeats_++; - } - is_repeated_ = true; - // If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_. // Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely. if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) { @@ -73,9 +58,7 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modifie // that RepeatOp does. However, epoch control is actually simpler because it can // only exist as the root node so it doesn't need all the nested code. // Create a new stack for eoe operators and push onto our stack of stacks. 
- std::unique_ptr new_stack = std::make_unique(); - eoe_op_stacks_.push(std::move(new_stack)); - is_repeated_ = true; + // Get the total number of epochs from the EpochCtrlOp parameter num_epochs_ = node->num_repeats(); // Every node below this EpochCtrlOp should be repeated for num_epochs_ times. @@ -102,44 +85,16 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking - std::shared_ptr leaf_op = PopFromEOEOpStack(); - - while (leaf_op != nullptr) { - node->AddToEoeList(leaf_op); - leaf_op = PopFromEOEOpStack(); - } - - // At this point, we are done with the save area stack. It's a unique pointer to an empty stack - // at this time, so we can pop it to get rid of it. - op_stack *current_stack = eoe_op_stacks_.top().get(); - if (!current_stack->empty()) { - RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!"); - } - eoe_op_stacks_.pop(); - // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up - // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed - // from the save area, because the merge op above us may also take action on it later for a different - // case when there is no repeat in the merge leg. + // and set its total repeats. It is important that the op is removed from the save area, + // because the merge op above us may also take action on it later for a different case when + // there is no repeat in the merge leg. if (is_merge_ && cache_lookup_) { cache_lookup_->set_total_repeats(num_repeats_); cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); - node->AddToEoeList(std::move(cache_lookup_)); + cache_lookup_.reset(); } - // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. - // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. - if (nested_repeats_ > 0) { - AddToEOEOpStack(node); - nested_repeats_--; - } else { - // If we are not nested, or we were the top-most repeat, now we clear the flag - if (nested_repeats_ != 0) { - RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!"); - } - is_repeated_ = false; - } if (is_cached_) { AddToCachedOpStack(node); } @@ -155,13 +110,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Hooks up any identified eoe nodes under this repeat. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Pop the leaf ops from the save-area stack and add them to the eoe node tracking - std::shared_ptr leaf_op = PopFromEOEOpStack(); - while (leaf_op != nullptr) { - node->AddToEoeList(leaf_op); - leaf_op = PopFromEOEOpStack(); - } - is_repeated_ = false; node->set_total_repeats(num_repeats_); node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); // We finish the walk of this EpochCtrl's descendent nodes. @@ -172,31 +120,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) // CacheOp removes previous leaf ops and replaces them with itself Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { is_cached_ = false; - if (is_repeated_) { - // if we are a cache within a repeat path of the tree, then there will be - // eoe-generating ops in the eoe op stack in the tree. 
They are flagged as such so that the - repeat or epoch ctrl operators can work with them for repeat activity during runtime. - However, since a cache is present: - // - unflag those ops as being repeated ops - // - remove them from the eoe op stack so that repeat op above in the tree won't know about them - // - add ourself (the cache op), as an eoe op - // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead - // the repeating behaviours shall be invoked against the cache op. - std::shared_ptr leaf_op = PopFromEOEOpStack(); - while (leaf_op != nullptr) { - leaf_op = PopFromEOEOpStack(); - } - AddToEOEOpStack(std::static_pointer_cast(node)); - // adjust the total epochs and total repeats for ops under this cache op - std::shared_ptr cached_op = PopFromCachedOpStack(); - while (cached_op != nullptr) { - int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; - cached_op->set_total_repeats(cached_op_total_repeats); - // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 - cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); - cached_op = PopFromCachedOpStack(); - } + // If we are a cache within a repeat path of the tree, adjust the total repeats and total epochs of the cached ops + // so that those cached nodes become 1-time use (up to eoe) and are never repeated; instead, + // the repeating behaviour is invoked against the cache op itself. + std::shared_ptr cached_op = PopFromCachedOpStack(); + while (cached_op != nullptr) { + int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_; + cached_op->set_total_repeats(cached_op_total_repeats); + // Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1 + cached_op->set_num_repeats_per_epoch(cached_op_total_repeats); + cached_op = PopFromCachedOpStack(); } node->set_total_repeats(num_repeats_); @@ -207,13 +141,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up // for use with a controlling repeat above it. Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // If we are in a repeat path, then set our repeated flag - if (is_repeated_) { - // if we are a leaf node then save ourself in a stack for the repeat operator above us - if (node->IsLeaf()) { - AddToEOEOpStack(node); - } - } + // If we are under a cache op, then save ourselves to the cached op stack. if (is_cached_) { AddToCachedOpStack(node); } @@ -225,15 +153,11 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { // Turns off the tracking for operations under merge op Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Setting the flag is needed since we didn't call the base class DatasetOp version - if (is_repeated_) { - // If there was not any repeat in the merge cache miss leg, then the cache_lookup - // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack - if (cache_lookup_) { - cache_lookup_->set_total_repeats(num_repeats_); - node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); - AddToEOEOpStack(std::move(cache_lookup_)); - } + // If there was not any repeat in the merge cache miss leg, then the cache_lookup + // would not have been consumed yet. In that case, we need to set its total repeats.
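+  // (Sketch of the bookkeeping, assuming an EpochCtrl(3) root above a Repeat(2): num_epochs_ = 3
+  // and num_repeats_ = 6, so ops on this path get total_repeats = 6 and
+  // num_repeats_per_epoch = 6 / 3 = 2.)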
+ if (cache_lookup_) { + cache_lookup_->set_total_repeats(num_repeats_); + cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); } node->set_total_repeats(num_repeats_); node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); @@ -266,23 +190,6 @@ Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified return Status::OK(); } -// Adds an operator to the eoe operator stack save area -void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { - op_stack *current_stack = eoe_op_stacks_.top().get(); - current_stack->push(dataset_op); -} - -// Pops an operator from the eoe operator stack save area -std::shared_ptr RepeatPass::PopFromEOEOpStack() { - std::shared_ptr top_op = nullptr; - op_stack *current_stack = eoe_op_stacks_.top().get(); - if (current_stack != nullptr && !current_stack->empty()) { - top_op = current_stack->top(); - current_stack->pop(); - } - return top_op; -} - // Adds an operator to the cached operator stack save area void RepeatPass::AddToCachedOpStack(std::shared_ptr dataset_op) { cached_op_stacks_.push(dataset_op); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h index 082f8e2af3a..4345ecc6f61 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -106,15 +106,6 @@ class RepeatPass : public NodePass { Status RunOnNode(std::shared_ptr node, bool *modified) override; private: - /// \brief Adds an operator to the eoe operator stack save area - /// \param op - The dataset op to work add to eoe stack - /// \return Status - The error code return - void AddToEOEOpStack(std::shared_ptr dataset_op); - - /// \brief Pops an operator from the eoe operator stack save area - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromEOEOpStack(); - /// \brief Adds an operator to the cached operator stack save area /// \param op - The dataset op to add to the cached stack /// \return Status - The error code returned @@ -124,15 +115,12 @@ /// \return shared_ptr to the popped operator std::shared_ptr PopFromCachedOpStack(); - bool is_repeated_; // T/F if we are processing under a repeat - bool is_merge_; // T/F if we are processing under a cache merge op - bool is_cached_; // T/F is we are processing under a cache op - int32_t nested_repeats_; // A counter for nested repeats - int32_t num_repeats_; // A multiplier to the total number of repeats - int32_t num_epochs_; // To save the total number of epochs - std::stack> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting) - op_stack cached_op_stacks_; // A save area for ops under a cache op - std::shared_ptr cache_lookup_; // A save area for a cache lookup op + bool is_merge_; // T/F if we are processing under a cache merge op + bool is_cached_; // T/F if we are processing under a cache op + int32_t num_repeats_; // A multiplier to the total number of repeats + int32_t num_epochs_; // To save the total number of epochs + op_stack cached_op_stacks_; // A save area for ops under a cache op + std::shared_ptr cache_lookup_; // A save area for a cache lookup op }; } // namespace dataset } // namespace mindspore diff --git a/tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py b/tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py index 3b200da8700..c22eff972e3 100644 --- a/tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py +++ b/tests/ut/data/dataset/testPyfuncMap/pyfuncmap.py @@ -32,7 +32,7
@@ def test_case_0(): ds1 = ds1.map(operations=(lambda x: x + x), input_columns=col, output_columns="out") print("************** Output Tensor *****************") - for data in ds1.create_dict_iterator(): # each data is a dictionary + for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary # in this example, each dictionary has keys "image" and "label" print(data["out"]) print("************** Output Tensor *****************") @@ -52,7 +52,7 @@ def test_case_1(): ds1 = ds1.map(operations=(lambda x: (x, x + x)), input_columns=col, output_columns=["out0", "out1"]) print("************** Output Tensor *****************") - for data in ds1.create_dict_iterator(): # each data is a dictionary + for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary # in this example, each dictionary has keys "image" and "label" print("out0") print(data["out0"]) @@ -75,7 +75,7 @@ def test_case_2(): ds1 = ds1.map(operations=(lambda x, y: x + y), input_columns=col, output_columns="out") print("************** Output Tensor *****************") - for data in ds1.create_dict_iterator(): # each data is a dictionary + for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary # in this example, each dictionary has keys "image" and "label" print(data["out"]) @@ -97,7 +97,7 @@ def test_case_3(): output_columns=["out0", "out1", "out2"]) print("************** Output Tensor *****************") - for data in ds1.create_dict_iterator(): # each data is a dictionary + for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary # in this example, each dictionary has keys "image" and "label" print("out0") print(data["out0"]) @@ -123,7 +123,7 @@ def test_case_4(): output_columns=["out0", "out1", "out2"], num_parallel_workers=4) print("************** Output Tensor *****************") - for data in ds1.create_dict_iterator(): # each data is a dictionary + for data in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary # in this example, each dictionary has keys "image" and "label" print("out0") print(data["out0"]) diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index eb6aca92bf2..68a0964f00c 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -141,6 +141,34 @@ def test_bucket_batch_multi_bucket_no_padding(): assert output == expected_output +def test_bucket_batch_multi_bucket_no_padding_repeat(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [3, 3, 2, 2] + element_length_function = (lambda x: x[0] % 4) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function) + dataset = dataset.repeat(2) + + expected_output = [[[2], [6]], + [[3], [7]], + [[0], [4], [8]], + [[1], [5], [9]], + [[2], [6]], + [[3], [7]], + [[0], [4], [8]], + [[1], [5], [9]]] + + output = [] + for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + output.append(data["col1"].tolist()) + + assert output == expected_output + + def test_bucket_batch_multi_bucket_with_padding(): dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) @@ -471,6 +499,7 @@ def test_bucket_batch_invalid_column(): if 
__name__ == '__main__': test_bucket_batch_invalid_input() test_bucket_batch_multi_bucket_no_padding() + test_bucket_batch_multi_bucket_no_padding_repeat() test_bucket_batch_multi_bucket_with_padding() test_bucket_batch_single_bucket_no_padding() test_bucket_batch_single_bucket_with_padding() diff --git a/tests/ut/python/dataset/test_datasets_cifarop.py b/tests/ut/python/dataset/test_datasets_cifarop.py index a073be084b9..2eda0cfce33 100644 --- a/tests/ut/python/dataset/test_datasets_cifarop.py +++ b/tests/ut/python/dataset/test_datasets_cifarop.py @@ -406,7 +406,7 @@ def test_cifar_usage(): try: data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage) num_rows = 0 - for _ in data.create_dict_iterator(): + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): num_rows += 1 except (ValueError, TypeError, RuntimeError) as e: return str(e) diff --git a/tests/ut/python/dataset/test_datasets_mnist.py b/tests/ut/python/dataset/test_datasets_mnist.py index 4e2c48a3449..8d81bd7c724 100644 --- a/tests/ut/python/dataset/test_datasets_mnist.py +++ b/tests/ut/python/dataset/test_datasets_mnist.py @@ -240,7 +240,7 @@ def test_mnist_usage(): try: data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False) num_rows = 0 - for _ in data.create_dict_iterator(): + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): num_rows += 1 except (ValueError, TypeError, RuntimeError) as e: return str(e) diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py index b0a168aa9ef..f10d44ec609 100644 --- a/tests/ut/python/dataset/test_filterop.py +++ b/tests/ut/python/dataset/test_filterop.py @@ -424,7 +424,7 @@ def generator_big(maxid=20): # test with row_data_buffer > 1 def test_filter_by_generator_Partial(): - dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"]) + dataset = ds.GeneratorDataset(source=(lambda: generator_mc(99)), column_names=["col1", "col2"]) dataset_s = dataset.shuffle(4) dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index fd21ee58821..4dbc187447b 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -502,7 +502,7 @@ def test_celeba_padded(): data = data.repeat(2) count = 0 - for _ in data.create_dict_iterator(): + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): count = count + 1 assert count == 2
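For reference while reviewing, here is a minimal, self-contained model of the pause/self-reset handshake this patch standardizes across the leaf ops. It is a sketch only: WaitPost is approximated with a mutex and condition variable, and OnWaitBlock()/WaitForWorkers() stand in for the WorkerEntry wait branch and the ParallelOp helper used above rather than reproducing their actual signatures.

// Illustrative model of the leaf-op pause/self-reset handshake (not MindSpore code).
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class WaitPost {  // simplified stand-in for minddata's WaitPost
 public:
  void Set() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      posted_ = true;
    }
    cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return posted_; });
  }
  void Clear() {
    std::lock_guard<std::mutex> lk(mu_);
    posted_ = false;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool posted_ = false;
};

class LeafOpModel {
 public:
  explicit LeafOpModel(int num_workers) : num_workers_(num_workers) {}

  // Worker side: what WorkerEntry does when it pops a wait-flagged IOBlock.
  void OnWaitBlock() {
    if (++num_workers_paused_ == num_workers_) {
      wait_for_workers_post_.Set();  // the last paused worker wakes the master
    }
  }

  // Master side: what operator() does at an epoch boundary when epoch_sync_flag_ is set.
  void WaitForWorkers() {
    wait_for_workers_post_.Wait();   // sleep until every worker has paused
    wait_for_workers_post_.Clear();  // re-arm the post for the next epoch
    num_workers_paused_ = 0;
  }

  // Self-reset: the op re-initializes its own state instead of relying on an external reset.
  void Reset() { std::cout << "self-reset done, starting next iteration\n"; }

  const int num_workers_;
  std::atomic<int> num_workers_paused_{0};
  WaitPost wait_for_workers_post_;
};

int main() {
  LeafOpModel op(4);
  std::vector<std::thread> workers;
  for (int i = 0; i < op.num_workers_; ++i) {
    workers.emplace_back([&op] { op.OnWaitBlock(); });  // each worker reaches the sync point
  }
  op.WaitForWorkers();  // master blocks until the 4th worker arrives
  op.Reset();           // then performs the self-reset before the next repeat/epoch
  for (auto &w : workers) {
    w.join();
  }
  return 0;
}

The invariant the patch relies on is visible here: exactly one worker (the last to pause) posts wait_for_workers_post_, and the master clears both the post and num_workers_paused_ before the next epoch begins, so each leaf op can reset itself without a dedicated reset operator driving it.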