diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc index 1bb2af7c889..a356b04be44 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc @@ -25,7 +25,9 @@ namespace mindspore { namespace dataset { PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) { (void)py::class_>(*m, "TreeConsumer") - .def("Reset", [](TreeConsumer &self, int64_t step) { THROW_IF_ERROR(self.Reset(step)); }); + .def("Reset", [](TreeConsumer &self, int64_t step, uint64_t epoch) { + THROW_IF_ERROR(self.Reset(step, epoch)); + }); })); PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) { (void)py::class_>( diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index c9f09cc0e48..4dac024abb3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -348,7 +348,7 @@ Status ToDevice::Terminate() { return TreeConsumer::Terminate(); } -Status TreeConsumer::Reset(int64_t step) { +Status TreeConsumer::Reset(int64_t step, const int64_t epoch_num) { MS_LOG(INFO) << "Resetting TreeConsumer"; MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId(); @@ -374,7 +374,7 @@ Status TreeConsumer::Reset(int64_t step) { } #endif tree_adapter_ = std::make_unique(TreeAdapter::UsageFlag::kDeReset); - RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step)); + RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step, epoch_num)); RETURN_IF_NOT_OK(tree_adapter_->Launch()); MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId(); std::shared_ptr root2 = std::shared_ptr(tree_adapter_->GetRoot()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index a984f040b91..140f5f7d47d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -60,8 +60,9 @@ class TreeConsumer { /// Function to reset the current consumer to the provided step. /// The consumer will terminate the pipeline and create a new one with skip injected. /// \param step the step to reset the pipeline to. + /// \param epoch_num the epoch to reset the pipeline to. /// \return Status error code - Status Reset(int64_t step); + Status Reset(int64_t step, const int64_t epoch_num); /// Function to stop the consumer. 
/// \return Status error code diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index f59d41bea30..96eadb1a63b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -83,8 +83,8 @@ Status BatchOp::operator()() { RETURN_IF_NOT_OK(callback_manager_.Init(this)); // Synchronize with TaskManager TaskManager::FindMe()->Post(); - int64_t epoch_num = 0, batch_num = 0, cnt = 0; - int64_t ep_step = 0, total_step = 0; + int64_t epoch_num = op_current_epochs_; // in failover reset this can be greater than zero + int64_t ep_step = 0, total_step = 0, batch_num = 0, cnt = 0; RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); TensorRow new_row; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc index 0c9361b9e45..e292368cba5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -40,12 +40,12 @@ Status CacheLookupOp::WorkerEntry(int32_t worker_id) { RETURN_IF_NOT_OK(FetchFromCache(worker_id)); return Status::OK(); } -Status CacheLookupOp::ResetSampler() { return Status::OK(); } -Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { +Status CacheLookupOp::ResetSampler(const bool failover_reset) { return Status::OK(); } +Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) { RETURN_UNEXPECTED_IF_NULL(op); // We act like a sampler and as a dataset op. During handshake with leaf op, // We must wait until the leaf op has indexed everything. - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op, reset_count)); // Now we notify the main thread handshake has finished. 
leaf_op_wp_.Set(); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h index 377c6fc33b6..73b54658f56 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -41,8 +41,8 @@ class CacheLookupOp : public CacheBase, public SamplerRT { Status operator()() override; Status WorkerEntry(int32_t worker_id) override; // As a sampler, we override the following functions - Status ResetSampler() override; - Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + Status ResetSampler(const bool failover_reset = false) override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0) override; Status InitSampler() override; Status GetNextSample(TensorRow *out) override; void Print(std::ostream &out, bool show_all) const override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc index f8c303bc712..a6c5659e6f8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -435,6 +435,15 @@ void DatasetOp::UpdateRepeatAndEpochCounter() { MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; } +Status DatasetOp::SetEpoch(const int64_t epoch) { + CHECK_FAIL_RETURN_UNEXPECTED(epoch >= 0, + "New epoch value must be greater than or equal to 0, got: " + std::to_string(epoch)); + while (op_current_epochs_ < epoch) { + UpdateRepeatAndEpochCounter(); + } + return Status::OK(); +} + int64_t DatasetOp::GetTreeBatchSize() { if (child_.size() == 1) { return child_[0]->GetTreeBatchSize(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index ad273caa622..e4417254155 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -197,7 +197,7 @@ class DatasetOp : public std::enable_shared_from_this { // \brief During tree prepare phase, operators may have specific post-operations to perform depending on // their role. - // \notes Derived versions of this function should always call it's superclass version first + // \notes Derived versions of this function should always call their superclass version first // before providing their own implementations. virtual Status PrepareOperator(); @@ -213,6 +213,11 @@ class DatasetOp : public std::enable_shared_from_this { // \return T/F if this is an inlined operator bool inlined() const { return (oc_queue_size_ == 0); } + // \brief Set the epoch number for op manually. This is only used in reset mode. 
+ // \param[in] epoch The new epoch number to restart the pipeline from + // \return - Status + Status SetEpoch(const int64_t epoch); + // \brief Setter function, set the number of total repeats for the operator void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc index df14459e5b7..aa225f39b2f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc @@ -47,6 +47,19 @@ ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_con shuffle_last_row_idx_(0), shuffle_buffer_state_(kShuffleStateInit) {} +Status ShuffleOp::PrepareOperator() { + // Run any common code from super class first before adding our own + RETURN_IF_NOT_OK(DatasetOp::PrepareOperator()); + + // in reset mode, we need to move forward the random generator seed. + if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) { + for (auto i = 0; i < op_current_repeats_; i++) { + SelfReset(); + } + } + return Status::OK(); +} + // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by // itself rather than waiting for the reset driven from operators above it in the pipeline. Status ShuffleOp::SelfReset() { @@ -54,11 +67,11 @@ Status ShuffleOp::SelfReset() { // If reshuffle_each_epoch is false, then we always use the same seed for every // epoch. // If reshuffle_each_epoch is true, then the first epoch uses the given seed, - // and all subsequent epochs will then keep on using the rng_ without resetting it - if (!reshuffle_each_epoch_) { - rng_ = std::mt19937_64(shuffle_seed_); + // and we increment the seed by one in all subsequent epochs + if (reshuffle_each_epoch_) { + shuffle_seed_++; } - + rng_ = std::mt19937_64(shuffle_seed_); shuffle_buffer_ = std::make_unique(); shuffle_last_row_idx_ = 0; shuffle_buffer_state_ = kShuffleStateInit; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h index 2736fa677b8..b8526716bfd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h @@ -87,6 +87,13 @@ class ShuffleOp : public PipelineOp { // @return Name of the current Op std::string Name() const override { return kShuffleOp; } + // \brief During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // \notes Derived versions of this function should always call their superclass version first + // before providing their own implementations. + // @return Status The status code returned + Status PrepareOperator() override; + private: // Private function to add a new row to the shuffle buffer. 
// @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc index 78067789ae8..1a8159f9d6a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc @@ -50,6 +50,9 @@ Status SkipOp::GetNextRow(TensorRow *row) { bool eoe_received = false; while (skip_count_ < max_skips_) { RETURN_IF_NOT_OK(child_[0]->GetNextRow(row)); + if (row->eof()) { + return Status::OK(); + } if (row->eoe() && !once_only_) { eoe_received = true; break; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 2d7352a0ec1..e524664d9f0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -54,7 +54,10 @@ void GeneratorOp::Print(std::ostream &out, bool show_all) const { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status GeneratorOp::InitSampler() { if (sampler_ != nullptr) { - return sampler_->HandshakeRandomAccessOp(this); + // Let the sampler know if we are resetting the pipeline to a specific epoch (op_current_repeats_ > 0) + // to mimic the behaviour in that state and have repeatability. + // Note that number of repeats is used since in each epoch we may reset sampler multiple times. + return sampler_->HandshakeRandomAccessOp(this, op_current_repeats_); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc index d0d5d2d9d71..dba56ed285a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mappable_leaf_op.cc @@ -68,7 +68,8 @@ Status MappableLeafOp::operator()() { RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); TensorRow sample_row; RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row)); - while (true) { // each iteration is 1 epoch, breaks when IsLastIteration() is true + while (true) { // each iteration is 1 repeat (usually =1 epoch, unless we have a repeat node above us), breaks when + // IsLastIteration() is true if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) { ep_step = 0; RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); @@ -114,8 +115,10 @@ Status MappableLeafOp::Reset() { // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows Status MappableLeafOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); + // Let the sampler know if we are resetting the pipeline to a specific epoch (op_current_repeats_ > 0) + // to mimic the behaviour in that state and have repeatability. + // Note that number of repeats is used since in each epoch we may reset sampler multiple times. 
+ return sampler_->HandshakeRandomAccessOp(this, op_current_repeats_); } // contains the main logic of pulling a IOBlock from IOBlockQueue, load a row and push the row to out_connector_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc index 4fa96c2ed8d..66794cb40a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" +#include #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/engine/datasetops/source/io_block.h" @@ -39,7 +40,9 @@ NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connect load_io_block_queue_(true), shuffle_files_(shuffle_files), num_rows_per_shard_(0), - num_rows_(0) { + num_rows_(0), + shuffled_keys_({}), + seed_(0) { worker_connector_size_ = worker_connector_size; } @@ -244,22 +247,15 @@ bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, i return push; } -void NonMappableLeafOp::ShuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); +void NonMappableLeafOp::ShuffleKeys() { + std::mt19937 rng(num_devices_ == 1 ? GetSeed() : ++seed_); + std::shuffle(shuffled_keys_.begin(), shuffled_keys_.end(), rng); } Status NonMappableLeafOp::WaitToFillIOBlockQueue() { // must be called first if called by worker spanwed by taskgroup TaskManager::FindMe()->Post(); - std::vector i_keys; - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; while (true) { RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); io_block_queue_wait_post_.Clear(); @@ -269,9 +265,27 @@ Status NonMappableLeafOp::WaitToFillIOBlockQueue() { } if (shuffle_files_) { - ShuffleKeys(&i_keys, num_devices_ == 1 ? 
GetSeed() : ++seed); + ShuffleKeys(); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(shuffled_keys_)); + } + return Status::OK(); +} + +Status NonMappableLeafOp::PrepareOperator() { + // Run any common code from super class first before adding our own + RETURN_IF_NOT_OK(DatasetOp::PrepareOperator()); + + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + shuffled_keys_.push_back(it.key()); + } + // in reset mode, shuffled_keys_ needs to be ordered as in the resetting epoch + if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) { + for (auto i = 0; i < op_current_repeats_; i++) { + ShuffleKeys(); + } } - RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h index 4b5fe5c6881..9c77974a9c6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h @@ -79,6 +79,13 @@ class NonMappableLeafOp : public ParallelOp, TensorRow> // @return Name of the current Op std::string Name() const override { return "NonMappableLeafOp"; } + // \brief During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // \notes Derived versions of this function should always call their superclass version first + // before providing their own implementations. + // @return Status The status code returned + Status PrepareOperator() override; + protected: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. @@ -135,7 +142,7 @@ class NonMappableLeafOp : public ParallelOp, TensorRow> // @return Status - the error code returned. virtual Status CalculateNumRowsPerShard() = 0; - static void ShuffleKeys(std::vector *i_keys, uint32_t seed); + void ShuffleKeys(); // Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue @@ -159,6 +166,10 @@ class NonMappableLeafOp : public ParallelOp, TensorRow> bool shuffle_files_; int64_t num_rows_per_shard_; int64_t num_rows_; + + private: + std::vector shuffled_keys_; // to store shuffled filename indices + uint32_t seed_; // used to shuffle filename indices }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 89b458f5e2c..cb4b8f4526b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -157,8 +157,9 @@ Status DistributedSamplerRT::GetNextSample(TensorRow *out) { return Status::OK(); } -Status DistributedSamplerRT::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_tensor_, "[Internal ERROR] Reset() Sampler called early or late."); +Status DistributedSamplerRT::ResetSampler(const bool failover_reset) { + CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || cnt_ == samples_per_tensor_, + "[Internal ERROR] ResetSampler() called early or late."); cnt_ = 0; if (shuffle_ == true) { @@ -168,7 +169,7 @@ Status DistributedSamplerRT::ResetSampler() { } if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h index 3983eca85d6..48f895a18bc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -55,9 +55,10 @@ class DistributedSamplerRT : public SamplerRT { /// Init sampler, called by base class or python Status InitSampler() override; - /// \brief for next epoch of sampleIds - /// \return Status code - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; int64_t GetDeviceID() { return device_id_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc index 1f19e194f63..9adee038087 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc @@ -63,7 +63,7 @@ Status MindRecordSamplerRT::InitSampler() { return Status::OK(); } -Status MindRecordSamplerRT::ResetSampler() { +Status MindRecordSamplerRT::ResetSampler(const bool failover_reset) { // drive the shard reader reshuffle tasks to redo the sampling for another epoch // Note that when cache is attached, this function is driven by cache lookup op rather than mindrecord op. 
// Therefore, the reshuffle of tasks might happen in the middle of mindrecord's epoch diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h index 1b32987a9df..2a4836a7760 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h @@ -44,9 +44,10 @@ class MindRecordSamplerRT : public SamplerRT { // meant to be called by base class or python Status InitSampler() override; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; void SamplerPrint(std::ostream &out, bool show_all) const override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc index cc593cc047e..febd97a4a86 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -105,22 +105,29 @@ Status PKSamplerRT::GetNextSample(TensorRow *out) { return Status::OK(); } -Status PKSamplerRT::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late."); +Status PKSamplerRT::ResetSampler(const bool failover_reset) { + CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || next_id_ == num_samples_, + "[Internal ERROR] ResetSampler() called early or late."); next_id_ = 0; rnd_.seed(seed_++); if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); } -Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { +Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) { RETURN_UNEXPECTED_IF_NULL(op); RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); RETURN_IF_NOT_OK(InitSampler()); + // Move forward sampler's random generator if resetting the pipeline in fast_recovery mode + if (GlobalContext::config_manager()->fast_recovery()) { + for (auto i = 0; i < reset_count; i++) { + RETURN_IF_NOT_OK(ResetSampler(true)); // failover_reset = true + } + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h index 9e184ecebb3..f9330ff2a96 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -46,15 +46,17 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED // first handshake between leaf source op and Sampler. This func will determine the amount of data // in the dataset that we can sample from. 
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is + // @param reset_count - reset the random generator these many times (used in fast_recovery mode of reset) // @return - Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0) override; // init sampler, to be called by python or Handshake Status InitSampler() override; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; // Printer for debugging purposes. // @param out - output stream to write to diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc index 6753e06b24d..11bb5ace38a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -99,8 +99,11 @@ Status PythonSamplerRT::InitSampler() { return Status::OK(); } -Status PythonSamplerRT::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "[Internal ERROR] Reset() Sampler called early or late."); +Status PythonSamplerRT::ResetSampler(const bool failover_reset) { + if (failover_reset) { + return Status::OK(); + } + CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "[Internal ERROR] ResetSampler() called early or late."); need_to_reset_ = false; py::gil_scoped_acquire gil_acquire; if (Py_IsInitialized() == 0) { @@ -113,7 +116,7 @@ Status PythonSamplerRT::ResetSampler() { } if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h index 466fbde2fac..1ec1c621fc3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -40,9 +40,10 @@ class PythonSamplerRT : public SamplerRT { // @return Status Status InitSampler() override; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. 
+ /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; // Op calls this to get next Sample that contains all the sampleIds // @param TensorRow to be returned to corresponding Dataset Op diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc index 6c94e44ea60..dd68de681cd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -101,8 +101,9 @@ Status RandomSamplerRT::InitSampler() { return Status::OK(); } -Status RandomSamplerRT::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late."); +Status RandomSamplerRT::ResetSampler(const bool failover_reset) { + CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || next_id_ == num_samples_, + "[Internal ERROR] ResetSampler() called early or late."); next_id_ = 0; if (reshuffle_each_epoch_) { @@ -116,7 +117,7 @@ Status RandomSamplerRT::ResetSampler() { } if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h index e70edda3bec..5ffd1c3d27a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -46,9 +46,10 @@ class RandomSamplerRT : public SamplerRT { // meant to be called by base class or python Status InitSampler() override; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; void SamplerPrint(std::ostream &out, bool show_all) const override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index 284014fa3e9..7099b886b76 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -41,7 +41,7 @@ SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor) col_desc_(nullptr), is_initialized(false) {} -Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { +Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) { RETURN_UNEXPECTED_IF_NULL(op); std::shared_ptr child_sampler; if (HasChildSampler()) { @@ -52,7 +52,7 @@ Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { } // Handshake and init child first. 
- RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); + RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op, reset_count)); } CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "[Internal ERROR] RandomAccessOp init failed, as it is nullptr."); @@ -67,7 +67,12 @@ Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { // It's up to the derived class to check the validity of the two args // Because some sampler only needs one of the arg (weighted_random_sampler) RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback - + // Move forward sampler's random generator if resetting the pipeline in fast_recovery mode + if (GlobalContext::config_manager()->fast_recovery()) { + for (auto i = 0; i < reset_count; i++) { + RETURN_IF_NOT_OK(ResetSampler(true)); // failover_reset = true + } + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h index d68ba0edfad..40765bc95c2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -91,15 +91,20 @@ class SamplerRT { Status GetAllIdsThenReset(py::array *data); #endif - // for next epoch of sampleIds - // @return Status The status code returned - virtual Status ResetSampler() = 0; + /// \brief Reset for next epoch. + /// \note If failover_reset is set, any override of this function must support the scenario where consecutive calls to + /// it are executed successfully (to prepare the sampler for a specific epoch, including updating any random + /// generator's internal state) + /// \param[in] failover_reset - A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + virtual Status ResetSampler(const bool failover_reset = false) = 0; // first handshake between leaf source op and Sampler. This func will determine the amount of data // in the dataset that we can sample from. 
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is - // @return - virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); + // @param reset_count - reset the random generator these many times (used in fast_recovery mode of reset) + // @return status error code + virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0); // initialize sampler and perform checks on certain vars virtual Status InitSampler() { return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 7aad5d123d0..2be5b4fa17c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -93,13 +93,14 @@ Status SequentialSamplerRT::InitSampler() { return Status::OK(); } -Status SequentialSamplerRT::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late."); +Status SequentialSamplerRT::ResetSampler(const bool failover_reset) { + CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || id_count_ == num_samples_, + "[Internal ERROR] ResetSampler() called early or late."); current_id_ = start_index_; id_count_ = 0; if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 6dd6828b6e0..6752f67edb5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -39,9 +39,10 @@ class SequentialSamplerRT : public SamplerRT { // init sampler, called by python Status InitSampler() override; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; // Op calls this to get next Sample that contains all the sampleIds // @param TensorRow to be returned to corresponding Dataset Op diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc index 56e19d0c51b..85cca41bec8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.cc @@ -19,26 +19,30 @@ namespace mindspore { namespace dataset { -Status SkipFirstEpochSamplerRT::ResetSampler() { - if (id_count_ != num_samples_) { - std::string err_msg = - "[Internal ERROR] ResetSampler() called early or late. 
id_count_: " + std::to_string(id_count_) + - " num_samples_: " + std::to_string(num_samples_); - MS_LOG(ERROR) << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - current_id_ = 0; - id_count_ = 0; +Status SkipFirstEpochSamplerRT::ResetSampler(const bool failover_reset) { + // This is a special sampler for Failover Reset, its internal state should + // not reset when failover_reset is set to true. + if (!failover_reset) { + if (id_count_ != num_samples_) { + std::string err_msg = + "[Internal ERROR] ResetSampler() called early or late. id_count_: " + std::to_string(id_count_) + + " num_samples_: " + std::to_string(num_samples_); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } + current_id_ = 0; + id_count_ = 0; - if (!first_epoch_done_) { - num_samples_ += start_index_; - start_index_ = 0; - samples_per_tensor_ = num_samples_; - first_epoch_done_ = true; + if (!first_epoch_done_) { + num_samples_ += start_index_; + start_index_ = 0; + samples_per_tensor_ = num_samples_; + first_epoch_done_ = true; + } } if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h index 71b602c7b37..87fc9acab0b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/skip_first_epoch_sampler.h @@ -30,9 +30,10 @@ class SkipFirstEpochSamplerRT : public SequentialSamplerRT { // Destructor. ~SkipFirstEpochSamplerRT() = default; - // for next epoch of sampleIds - // @return Status The status code returned - Status ResetSampler() override; + /// \brief Reset for next epoch. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; /// \brief Gets the number of samples available /// \note Since this sampler returns different number of samples in the first epoch (compared to other epochs), this diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index 6473923dcb9..48d36dd44e7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -46,12 +46,12 @@ Status SubsetRandomSamplerRT::InitSampler() { } // Reset the internal variable to the initial state. -Status SubsetRandomSamplerRT::ResetSampler() { +Status SubsetRandomSamplerRT::ResetSampler(const bool failover_reset) { // Randomized the indices again. 
rand_gen_.seed(GetSeed()); std::shuffle(indices_.begin(), indices_.end(), rand_gen_); - return SubsetSamplerRT::ResetSampler(); + return SubsetSamplerRT::ResetSampler(failover_reset); } void SubsetRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index e3786b8f100..61bf702d039 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -43,9 +43,10 @@ class SubsetRandomSamplerRT : public SubsetSamplerRT { /// \return Status Status InitSampler() override; - /// Reset the internal variable to the initial state and reshuffle the indices. - /// \return Status - Status ResetSampler() override; + /// \brief Reset the internal variable(s) to the initial state and reshuffle the indices. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; /// Printer for debugging purposes. /// \param out - output stream to write to diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc index 237d569b1cc..28e8fe31f1d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.cc @@ -48,12 +48,12 @@ Status SubsetSamplerRT::InitSampler() { } // Reset the internal variable to the initial state. -Status SubsetSamplerRT::ResetSampler() { +Status SubsetSamplerRT::ResetSampler(const bool failover_reset) { // Reset the internal counters. sample_id_ = 0; if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h index 4fc8bfe66f5..7866a7d874a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h @@ -42,9 +42,10 @@ class SubsetSamplerRT : public SamplerRT { /// \return Status Status InitSampler() override; - /// Reset the internal variable to the initial state and reshuffle the indices. - /// \return Status - Status ResetSampler() override; + /// \brief Reset the internal variable(s) to the initial state and reshuffle the indices. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; /// Get the sample ids. /// \param[out] TensorRow where the sample ids will be placed. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 24fef1b968e..7946827fbf7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -94,7 +94,7 @@ void WeightedRandomSamplerRT::InitOnePassSampling() { } // Reset the internal variable to the initial state and reshuffle the indices. -Status WeightedRandomSamplerRT::ResetSampler() { +Status WeightedRandomSamplerRT::ResetSampler(const bool failover_reset) { sample_id_ = 0; rand_gen_.seed(GetSeed()); if (!replacement_) { @@ -104,7 +104,7 @@ Status WeightedRandomSamplerRT::ResetSampler() { } if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset)); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index 27f9ca5408e..da2a00becb5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -45,8 +45,10 @@ class WeightedRandomSamplerRT : public SamplerRT { // @return Status Status InitSampler() override; - // Reset the internal variable to the initial state and reshuffle the indices. - Status ResetSampler() override; + /// \brief Reset the internal variable(s) to the initial state and reshuffle the indices. + /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline + /// \return Status The status code returned + Status ResetSampler(const bool failover_reset = false) override; // Get the sample ids. // @param[out] TensorRow where the sample ids will be placed. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/add_skip_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/add_skip_pass.cc index 77c34cbdd88..c80a82c8e4e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/add_skip_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/add_skip_pass.cc @@ -87,7 +87,8 @@ Status AddSkipPass::RunOnTree(std::shared_ptr root_ir, bool *const CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Failed to inject SkipOp."); int64_t dataset_size = -1; - RETURN_IF_NOT_OK(root_ir->GetDatasetSize(nullptr, false, &dataset_size)); + std::shared_ptr size_getter = std::make_shared(); + RETURN_IF_NOT_OK(root_ir->GetDatasetSize(size_getter, false, &dataset_size)); CHECK_FAIL_RETURN_UNEXPECTED(dataset_size > 0, "Cannot reset the pipeline, dataset size is undefined"); int32_t num_epochs = finder.GetNumEpochs(); int64_t step = finder.GetStep(); @@ -105,11 +106,7 @@ Status AddSkipPass::RunOnTree(std::shared_ptr root_ir, bool *const } // in fast recovery, we start from current epoch and skip remaining steps (skip node will also be pushed down) if (GlobalContext::config_manager()->fast_recovery()) { - int32_t new_num_epochs = num_epochs - static_cast(step / dataset_size); int64_t skip_num = step % dataset_size; - - root_ir->SetNumEpochs(new_num_epochs); - auto skip_node = std::make_shared(skip_num); skip_node->SetOnceOnly(true); RETURN_IF_NOT_OK(node->InsertAbove(skip_node)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index 0483a5f71d1..981a4ee0ba2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -170,7 +170,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr ir, std return Status::OK(); } -Status TreeAdapter::Build(std::shared_ptr root_ir) { +Status TreeAdapter::Build(std::shared_ptr root_ir, int64_t epoch_num) { RETURN_UNEXPECTED_IF_NULL(root_ir); // Create ExecutionTree tree_ = std::make_unique(); @@ -180,6 +180,10 @@ Status TreeAdapter::Build(std::shared_ptr root_ir) { RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); + if (usage_ == kDeReset) { + RETURN_IF_NOT_OK(AdjustReset(epoch_num)); + } + // Prepare the tree RETURN_IF_NOT_OK(tree_->Prepare()); @@ -188,7 +192,8 @@ Status TreeAdapter::Build(std::shared_ptr root_ir) { return Status::OK(); } -Status TreeAdapter::Compile(const std::shared_ptr &input_ir, int32_t num_epochs, int64_t step) { +Status TreeAdapter::Compile(const std::shared_ptr &input_ir, int32_t num_epochs, int64_t step, + const int64_t epoch_num) { RETURN_UNEXPECTED_IF_NULL(input_ir); input_ir_ = input_ir; tree_state_ = kCompileStateIRGraphBuilt; @@ -227,11 +232,21 @@ Status TreeAdapter::Compile(const std::shared_ptr &input_ir, int32_ // Remember the root node root_ir_ = root_ir; - RETURN_IF_NOT_OK(Build(root_ir_)); + RETURN_IF_NOT_OK(Build(root_ir_, epoch_num)); tree_state_ = kCompileStateReady; return Status::OK(); } +Status TreeAdapter::AdjustReset(const int64_t epoch_num) { + if (GlobalContext::config_manager()->fast_recovery() && epoch_num > 0) { + MS_LOG(INFO) << "Adjusting dataset pipeline for failover reset to start on epoch: " << (epoch_num + 1); + for (auto op = tree_->begin(); op != tree_->end(); op++) { + op->SetEpoch(epoch_num); + } + } + return Status::OK(); +} + Status TreeAdapter::GetNext(TensorRow *row) { RETURN_UNEXPECTED_IF_NULL(tree_); 
RETURN_UNEXPECTED_IF_NULL(row); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 19a57c3495e..0c6720f21c3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -57,7 +57,8 @@ class TreeAdapter { // This function performs syntax checking, semantics checking, optimizes, and then builds // the Execution tree. - Status Compile(const std::shared_ptr &input_ir, int32_t num_epochs = -1, int64_t step = 0); + Status Compile(const std::shared_ptr &input_ir, int32_t num_epochs = -1, int64_t step = 0, + const int64_t epoch_num = 0); // Return the root node of the IR after cloned from the parsed IR tree std::shared_ptr RootIRNode() const { return root_ir_; } @@ -115,11 +116,14 @@ class TreeAdapter { Status PostPass(std::shared_ptr ir); // Build an Execution tree - Status Build(std::shared_ptr root_ir); + Status Build(std::shared_ptr root_ir, const int64_t epoch_num = 0); // This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree. Status BuildExecutionTreeRecur(std::shared_ptr ir, std::shared_ptr *op); + // Adjust the pipeline (eg, move rng_ forward) if in reset mode + Status AdjustReset(const int64_t epoch_num); + std::unordered_map column_name_map_; std::shared_ptr input_ir_; std::shared_ptr root_ir_; diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py index 085965d5bec..1725866725a 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets.py +++ b/mindspore/python/mindspore/dataset/engine/datasets.py @@ -115,16 +115,17 @@ def _get_training_dataset(): return _train_dataset -def _reset_training_dataset(step): +def _reset_training_dataset(step, epoch): """ - Reset the training dataset to the given step number. + Reset the training dataset to the given step and epoch number. Args: step (int): Global step number. + epoch (int): Global epoch number """ dataset = _get_training_dataset() if dataset is not None: - dataset._reset(step) # pylint: disable=W0212 + dataset._reset(step, epoch) # pylint: disable=W0212 else: raise RuntimeError("Training dataset is not set.") diff --git a/mindspore/python/mindspore/dataset/engine/iterators.py b/mindspore/python/mindspore/dataset/engine/iterators.py index 554ce5434bf..e07dd08de0f 100644 --- a/mindspore/python/mindspore/dataset/engine/iterators.py +++ b/mindspore/python/mindspore/dataset/engine/iterators.py @@ -169,14 +169,15 @@ class Iterator: self._col_names = self.__ori_dataset.get_col_names() return self._col_names - def _reset(self, step): + def _reset(self, step, epoch): """ - Reset the iterator to the given step number. + Reset the iterator to the given step number and epoch number. Args: - step (int): Global step number. 
+ step (int): Global step number + epoch (int): Global epoch number """ - self._iterator.Reset(step) + self._iterator.Reset(step, epoch) def _transform_md_to_output(self, t): if self._output_numpy: diff --git a/mindspore/python/mindspore/train/model.py b/mindspore/python/mindspore/train/model.py index 43886086fee..3413f266dea 100644 --- a/mindspore/python/mindspore/train/model.py +++ b/mindspore/python/mindspore/train/model.py @@ -829,7 +829,7 @@ class Model: os.remove(cb_params.latest_ckpt_file) raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\ + cb_params.latest_ckpt_file) from e - _reset_training_dataset(cb_params.cur_step_num) + _reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num) self.need_load_ckpt = False def _reset_training_step_for_normal_process(self, cb_params, dataset_helper): @@ -858,9 +858,9 @@ class Model: self.epoch_iter = recovery_epoch_num cb_params.cur_epoch_num = self.epoch_iter + 1 cb_params.last_save_ckpt_step = cb_params.cur_step_num - _reset_training_dataset(cb_params.cur_step_num) + _reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num) else: - _reset_training_dataset(0) + _reset_training_dataset(0, 0) _set_recovery_context(need_reset=False) diff --git a/tests/ut/cpp/dataset/c_api_dataset_config_test.cc b/tests/ut/cpp/dataset/c_api_dataset_config_test.cc index ea7f1ede9e7..ac8eb260457 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_config_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_config_test.cc @@ -183,7 +183,7 @@ TEST_F(MindDataTestPipeline, TestCallShuffleTwice) { uint32_t original_seed = config::get_seed(); uint32_t original_num_parallel_workers = config::get_num_parallel_workers(); MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; - config::set_seed(654); + config::set_seed(655); // not all seeds satisfy the assertions in this test. config::set_num_parallel_workers(1); // Create a TextFile Dataset with single text file which has three samples diff --git a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz index 6e2486024fd..38b50438af3 100644 Binary files a/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz and b/tests/ut/data/dataset/golden/test_2ops_shuffle_repeat.npz differ diff --git a/tests/ut/python/dataset/test_reset.py b/tests/ut/python/dataset/test_reset.py index e0605ab87c8..97aa982fb7d 100644 --- a/tests/ut/python/dataset/test_reset.py +++ b/tests/ut/python/dataset/test_reset.py @@ -24,6 +24,10 @@ from util_minddataset import add_and_remove_cv_file # pylint: disable=no-value-for-parameter +# Need to run all these tests in separate processes since MD internally stores +# "training" dataset in a global variable every time. 
+pytestmark = pytest.mark.forked + def create_np_dataset(size): dimensions = (size, 4, 3, 2) @@ -77,14 +81,13 @@ def create_random_imagenet_dataset(repeat_size, sampler=None, num_parallel_worke data = data.repeat(repeat_size) crop_op1 = vision.RandomCrop(4) operations = [vision.Decode(to_pil=to_pil), crop_op1] - if to_pil: # include a pyfunc in test if to_pil is True + if to_pil: # include a pyfunc in test if to_pil is True operations.append(lambda x: x.rotate(45)) data = data.map(operations=operations, input_columns=[ "image"], num_parallel_workers=num_parallel_workers, python_multiprocessing=True) if batch_func: - data = data.batch( - batch_size=2, per_batch_map=batch_func, - num_parallel_workers=num_parallel_workers, python_multiprocessing=True) + data = data.batch(batch_size=2, per_batch_map=batch_func, input_columns=["label"], + num_parallel_workers=num_parallel_workers, python_multiprocessing=True) data = data.project(["image"]) return data @@ -98,7 +101,7 @@ def create_minddata_dataset(size): return data -def run_reset(data, num_epochs, failure_point: int, reset_step: int): +def run_reset(data, num_epochs: int, failure_point: int): size = data.get_dataset_size() expected = [] expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) @@ -107,50 +110,51 @@ def run_reset(data, num_epochs, failure_point: int, reset_step: int): expected.append(d) del expected_itr - actual_before_reset = [] + expected2 = [] itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212 cur_step: int = 0 failed = False for _ in range(num_epochs): for d in itr: - actual_before_reset.append(d) - if cur_step == failure_point: - ds.engine.datasets._reset_training_dataset(reset_step) # pylint: disable=W0212 + expected2.append(d) + if cur_step + 1 == failure_point: + # pylint: disable=W0212 + ds.engine.datasets._reset_training_dataset(failure_point, failure_point // size) failed = True break cur_step += 1 if failed: break - actual_after_reset = [] if failed: - for _ in range(reset_step // size, num_epochs): + for _ in range(failure_point // size, num_epochs): for d in itr: - actual_after_reset.append(d) + expected2.append(d) with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."): for _ in itr: - pass + expected2.append(d) - for x, y in zip(expected[:failure_point], actual_before_reset): - np.testing.assert_array_equal(x, y) - - for x, y in zip(expected[reset_step:], actual_after_reset): + assert len(expected) == len(expected2) + for x, y in zip(expected, expected2): np.testing.assert_array_equal(x, y) def run_reset_error(data, num_epochs: int, failure_point: int): itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) # pylint: disable=unused-variable ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212 + dataset_size = data.get_dataset_size() if failure_point > 0: with pytest.raises(RuntimeError) as err: - ds.engine.datasets._reset_training_dataset(failure_point) # pylint: disable=W0212 + # pylint: disable=W0212 + ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size) assert "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs." 
             in str(err.value)
     else:
         with pytest.raises(RuntimeError) as err:
-            ds.engine.datasets._reset_training_dataset(failure_point)  # pylint: disable=W0212
+            # pylint: disable=W0212
+            ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
         assert "Cannot reset the pipeline, reset step must be >= 0." in str(err.value)

@@ -165,8 +169,7 @@ def test_reset_np():
     failure_steps = (dataset_size * num_epochs) // 10
     data = create_np_dataset(size=dataset_size)
     for failure_point in range(0, dataset_size * num_epochs, failure_steps):
-        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
-            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+        run_reset(data, num_epochs=num_epochs, failure_point=failure_point)


 def test_reset_cifar1():
@@ -180,8 +183,7 @@
     failure_steps = (dataset_size * num_epochs) // 5
     data = create_cifar_dataset1(size=dataset_size)
     for failure_point in range(0, dataset_size * num_epochs, failure_steps):
-        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
-            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+        run_reset(data, num_epochs=num_epochs, failure_point=failure_point)


 def test_reset_cifar2():
@@ -195,8 +197,7 @@
     failure_steps = (dataset_size * num_epochs) // 5
     data = create_cifar_dataset2(size=dataset_size)
     for failure_point in range(0, dataset_size * num_epochs, failure_steps):
-        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
-            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+        run_reset(data, num_epochs=num_epochs, failure_point=failure_point)


 def test_reset_imagenet():
@@ -210,8 +211,7 @@
     failure_steps = (dataset_size * num_epochs) // 4
     data = create_imagenet_dataset(size=dataset_size)
     for failure_point in range(0, dataset_size * num_epochs, failure_steps):
-        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
-            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+        run_reset(data, num_epochs=num_epochs, failure_point=failure_point)


 def test_reset_mindrecord(add_and_remove_cv_file):  # pylint: disable=unused-argument, redefined-outer-name
@@ -225,8 +225,7 @@ def test_reset_mindrecord(add_and_remove_cv_file):  # pylint: disable=unused-arg
     failure_steps = (dataset_size * num_epochs) // 10
     data = create_minddata_dataset(size=dataset_size)
     for failure_point in range(0, dataset_size * num_epochs, failure_steps):
-        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
-            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+        run_reset(data, num_epochs=num_epochs, failure_point=failure_point)


 def test_reset_np_error():
@@ -243,13 +242,13 @@ def test_reset_np_error():
         run_reset_error(data, num_epochs=num_epochs, failure_point=failure_point)


-def random_col(col1, col2, batch_info):
-    return ([np.random.rand(1) for a in col1], [np.random.rand(1) for b in col2])
+def random_col(col1, batch_info):
+    return ([np.random.rand(1) for a in col1],)


 @pytest.mark.parametrize("num_parallel_workers", (4, 5))
 @pytest.mark.parametrize("sampler", (ds.RandomSampler(), None))
-@pytest.mark.parametrize("to_pil, batch_func", [(False, None), (True, random_col)]) # test C ops and Python ops (MP)
+@pytest.mark.parametrize("to_pil, batch_func", [(False, None), (True, random_col)])  # test C ops and Python ops (MP)
 def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_func):
     """
     Feature: Dataset recovery
@@ -277,7 +276,7 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
     del expected_itr
     dataset_size = data.get_dataset_size()
     # try different failure points
-    for failure_point in (5, 6, 19, 22):
+    for failure_point in (5, 6, 22):
         expected2 = []
         expected2_itr = data.create_tuple_iterator(
             num_epochs=num_epochs, output_numpy=True)
@@ -291,23 +290,22 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
                 failure = True
                 break
         if failure:
-            ds.engine.datasets._reset_training_dataset(failure_point)  # pylint: disable=W0212
+            # pylint: disable=W0212
+            ds.engine.datasets._reset_training_dataset(failure_point, failure_point // dataset_size)
             failure = False
             for d in expected2_itr:
                 expected2.append(d)
         del expected2_itr

         # verify count and values of failover with original run
-        assert len(expected) == len(expected2)
-        for a, b in zip(expected, expected2):
-            assert np.array_equal(a[0], b[0])
+        np.testing.assert_array_equal(expected, expected2)

     ds.config.set_seed(original_seed)
     ds.config.set_fast_recovery(original_fast_recovery)
     ds.config.set_enable_shared_mem(original_shared_mem)


-@pytest.mark.parametrize("to_pil", (False, True)) # test C ops and Python ops with MP=true
+@pytest.mark.parametrize("to_pil", (False, True))  # test C ops and Python ops with MP=true
 @pytest.mark.parametrize("num_parallel_workers", (4, 5))
 @pytest.mark.parametrize("shard_id", (0, 1, 2, 3))
 def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
@@ -345,8 +343,7 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
     # try different failure points
     for failure_point in (3, 7, 9):
         expected2 = []
-        expected2_itr = data.create_tuple_iterator(
-            num_epochs=num_epochs, output_numpy=True)
+        expected2_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
         ds.engine.datasets._set_training_dataset(expected2_itr)  # pylint: disable=W0212
         failure = False
         for epoch in range(num_epochs):
@@ -356,21 +353,225 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
                 failure = True
                 break
         if failure:
-            ds.engine.datasets._reset_training_dataset(failure_point)  # pylint: disable=W0212
+            ds.engine.datasets._reset_training_dataset(failure_point, epoch)  # pylint: disable=W0212
             failure = False
             for d in expected2_itr:
                 expected2.append(d)

         # verify count and values of failover with original run
-        assert len(expected) == len(expected2)
-        for a, b in zip(expected, expected2):
-            assert np.array_equal(a, b)
+        np.testing.assert_array_equal(expected, expected2)

     ds.config.set_seed(original_seed)
     ds.config.set_fast_recovery(original_fast_recovery)
     ds.config.set_enable_shared_mem(original_shared_mem)


+def test_reset_shuffle():
+    """
+    Feature: Dataset recovery
+    Description: The random generator in the shuffle operation resets to the correct internal state
+    Expectation: Same dataset after reset
+    """
+    original_seed = ds.config.get_seed()
+    original_fast_recovery = ds.config.get_fast_recovery()
+    ds.config.set_seed(1)
+    ds.config.set_fast_recovery(True)
+
+    source = [(np.array([x])) for x in range(10)]
+    data1 = ds.NumpySlicesDataset(source, ["data"], sampler=ds.SequentialSampler())
+    data1 = data1.shuffle(3)
+    data1 = data1.skip(1)
+    num_epochs = 3
+
+    expected = []
+    expected_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
+    for epoch in range(num_epochs):
+        for step, d in enumerate(expected_itr):
+            expected.append(d)
+
+    failure_point = 13
+    expected2 = []
+    expected2_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
+    ds.engine.datasets._set_training_dataset(expected2_itr)  # pylint: disable=W0212
+    failure = False
+    for epoch in range(num_epochs):
+        for step, d in enumerate(expected2_itr):
+            expected2.append(d)
+            if epoch * data1.get_dataset_size() + step + 1 == failure_point:
+                failure = True
+                break
+        if failure:
+            ds.engine.datasets._reset_training_dataset(failure_point, epoch)  # pylint: disable=W0212
+            failure = False
+            for step, d in enumerate(expected2_itr):
+                expected2.append(d)
+
+    with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
+        for step, d in enumerate(expected2_itr):
+            expected2.append(d)
+    np.testing.assert_array_equal(expected, expected2)
+
+    ds.config.set_seed(original_seed)
+    ds.config.set_fast_recovery(original_fast_recovery)
+
+
+@pytest.mark.parametrize("sampler", (ds.RandomSampler(), ds.SequentialSampler()))
+def test_reset_sampler(sampler):
+    """
+    Feature: Dataset recovery
+    Description: The samplers for source operations reset to the correct internal state.
+    Expectation: Same dataset after reset
+    """
+    original_seed = ds.config.get_seed()
+    original_fast_recovery = ds.config.get_fast_recovery()
+    ds.config.set_seed(1)
+    ds.config.set_fast_recovery(True)
+
+    source = [(np.array([x]),) for x in range(10)]
+    data1 = ds.NumpySlicesDataset(source, ["data"], sampler=sampler)
+    data1 = data1.skip(1)
+    data1 = data1.repeat(2)
+    data1 = data1.skip(1)
+    num_epochs = 3
+
+    expected_itr = data1.create_tuple_iterator(
+        num_epochs=num_epochs, output_numpy=True)
+    expected = []
+    for epoch in range(num_epochs):
+        for step, d in enumerate(expected_itr):
+            expected.append(d)
+
+    failure_point = 13
+    expected2 = []
+    expected2_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
+    ds.engine.datasets._set_training_dataset(expected2_itr)  # pylint: disable=W0212
+    failure = False
+
+    for epoch in range(num_epochs):
+        for step, d in enumerate(expected2_itr):
+            expected2.append(d)
+            if epoch * data1.get_dataset_size() + step + 1 == failure_point:
+                failure = True
+                break
+        if failure:
+            ds.engine.datasets._reset_training_dataset(failure_point, epoch)  # pylint: disable=W0212
+            failure = False
+            for step, d in enumerate(expected2_itr):
+                expected2.append(d)
+
+    with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
+        for step, d in enumerate(expected2_itr):
+            expected2.append(d)
+    np.testing.assert_array_equal(expected, expected2)
+
+    ds.config.set_seed(original_seed)
+    ds.config.set_fast_recovery(original_fast_recovery)
+
+
+@pytest.mark.parametrize("fast_recovery", (False, True))
+def test_reset_batch(fast_recovery):
+    """
+    Feature: Dataset recovery
+    Description: The BatchInfo argument of the batch operation contains the correct information (epoch number)
+    Expectation: Test succeeds
+    """
+    original_fast_recovery = ds.config.get_fast_recovery()
+    ds.config.set_fast_recovery(fast_recovery)
+
+    num_epochs = 5
+    repeat_size = 4
+    skip_size = 12
+
+    def get_epoch_num(col1, batch_info):
+        return (np.array(batch_info.get_epoch_num()),)
+
+    data1 = ds.NumpySlicesDataset(np.arange(10).reshape(10, 1))
+    data1 = data1.repeat(repeat_size)
+    data1 = data1.skip(skip_size)
+    data1 = data1.batch(batch_size=7, per_batch_map=get_epoch_num,
+                        num_parallel_workers=1, python_multiprocessing=False)
+
+    itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
+    ds.engine.datasets._set_training_dataset(itr)  # pylint: disable=W0212
+
+    failure = False
+    failure_point = 25
+    expected = np.repeat(np.arange(5), 4).reshape((20, 1))
+    expected2 = []
+
+    for epoch in range(num_epochs):
+        for step, d in enumerate(itr):
+            expected2.append(d)
+            if epoch * data1.get_dataset_size() + step + 1 == failure_point:
+                failure = True
+                break
+        if failure:
+            ds.engine.datasets._reset_training_dataset(failure_point, epoch)  # pylint: disable=W0212
+            failure = False
+            for step, d in enumerate(itr):
+                expected2.append(d)
+
+    with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
+        for d in itr:
+            expected2.append(d)
+    np.testing.assert_array_equal(expected, expected2)
+
+    ds.config.set_fast_recovery(original_fast_recovery)
+
+
+def test_reset_nonmappable():
+    """
+    Feature: Dataset recovery
+    Description: The order of rows read in the normal and reset runs is identical for a TFRecord dataset.
+    Expectation: Test succeeds
+    """
+    original_seed = ds.config.get_seed()
+    original_fast_recovery = ds.config.get_fast_recovery()
+
+    num_epochs = 10
+    num_repeats = 5
+    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
+                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
+
+    # run a pipeline and collect rows
+    def get_res(shard_id, num_repeats, failure_point):
+        ds.config.set_seed(1)
+        ds.config.set_fast_recovery(True)
+
+        data1 = ds.TFRecordDataset(tf_files, num_shards=4, shard_id=shard_id, num_samples=5, shuffle=ds.Shuffle.FILES)
+        data1 = data1.repeat(num_repeats)
+        itr = data1.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+        ds.engine.datasets._set_training_dataset(itr)  # pylint: disable=W0212
+        dataset_size = data1.get_dataset_size()
+
+        res = list()
+        failure = False
+        for epoch in range(num_epochs):
+            for step, item in enumerate(itr):
+                res.append(item["scalars"][0])
+                if epoch * dataset_size + step + 1 == failure_point:
+                    failure = True
+                    break
+            if failure:
+                # pylint: disable=W0212
+                ds.engine.datasets._reset_training_dataset(failure_point, failure_point // dataset_size)
+                failure = False
+                # let's collect the remaining rows of this epoch
+                if failure_point % dataset_size != 0:
+                    for step, item in enumerate(itr):
+                        res.append(item["scalars"][0])
+        return res
+
+    shard_id = 0
+    expected = get_res(0, 5, -1)  # no reset in this run
+    # try different failure points and compare against 'expected'
+    for failure_point in range(100):
+        expected2 = get_res(shard_id, num_repeats, failure_point)
+        np.testing.assert_array_equal(expected, expected2)
+
+    ds.config.set_seed(original_seed)
+    ds.config.set_fast_recovery(original_fast_recovery)
+
+
 if __name__ == "__main__":
     test_reset_np()
     test_reset_cifar1()
@@ -380,3 +581,7 @@ if __name__ == "__main__":
     test_reset_np_error()
     test_repeatable_reset_imagenet()
     test_repeatable_reset_distributed()
+    test_reset_shuffle()
+    test_reset_sampler(ds.RandomSampler())
+    test_reset_batch(False)
+    test_reset_nonmappable()