[MD] fix sampler and shuffle in fast_recovery mode of reset

This commit is contained in:
mhmotallebi 2022-11-05 16:30:07 -04:00
parent ac2d982c16
commit 35ba3bcbaf
46 changed files with 509 additions and 176 deletions

View File

@ -25,7 +25,9 @@ namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer")
.def("Reset", [](TreeConsumer &self, int64_t step) { THROW_IF_ERROR(self.Reset(step)); });
.def("Reset", [](TreeConsumer &self, int64_t step, uint64_t epoch) {
THROW_IF_ERROR(self.Reset(step, epoch));
});
}));
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(

View File

@ -348,7 +348,7 @@ Status ToDevice::Terminate() {
return TreeConsumer::Terminate();
}
Status TreeConsumer::Reset(int64_t step) {
Status TreeConsumer::Reset(int64_t step, const int64_t epoch_num) {
MS_LOG(INFO) << "Resetting TreeConsumer";
MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId();
@ -374,7 +374,7 @@ Status TreeConsumer::Reset(int64_t step) {
}
#endif
tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step));
RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step, epoch_num));
RETURN_IF_NOT_OK(tree_adapter_->Launch());
MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId();
std::shared_ptr<DatasetOp> root2 = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());

View File

@ -60,8 +60,9 @@ class TreeConsumer {
/// Function to reset the current consumer to the provided step.
/// The consumer will terminate the pipeline and create a new one with skip injected.
/// \param step the step to reset the pipeline to.
/// \param epoch_num the epoch to reset the pipeline to.
/// \return Status error code
Status Reset(int64_t step);
Status Reset(int64_t step, const int64_t epoch_num);
/// Function to stop the consumer.
/// \return Status error code

View File

@ -83,8 +83,8 @@ Status BatchOp::operator()() {
RETURN_IF_NOT_OK(callback_manager_.Init(this));
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
int64_t epoch_num = 0, batch_num = 0, cnt = 0;
int64_t ep_step = 0, total_step = 0;
int64_t epoch_num = op_current_epochs_; // in failover reset this can be greater than zero
int64_t ep_step = 0, total_step = 0, batch_num = 0, cnt = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
TensorRow new_row;

View File

@ -40,12 +40,12 @@ Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheLookupOp::ResetSampler() { return Status::OK(); }
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
Status CacheLookupOp::ResetSampler(const bool failover_reset) { return Status::OK(); }
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
RETURN_UNEXPECTED_IF_NULL(op);
// We act like a sampler and as a dataset op. During handshake with leaf op,
// We must wait until the leaf op has indexed everything.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op, reset_count));
// Now we notify the main thread handshake has finished.
leaf_op_wp_.Set();
return Status::OK();

View File

@ -41,8 +41,8 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
// As a sampler, we override the following functions
Status ResetSampler() override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
Status ResetSampler(const bool failover_reset = false) override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0) override;
Status InitSampler() override;
Status GetNextSample(TensorRow *out) override;
void Print(std::ostream &out, bool show_all) const override;

View File

@ -435,6 +435,15 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
}
Status DatasetOp::SetEpoch(const int64_t epoch) {
CHECK_FAIL_RETURN_UNEXPECTED(epoch >= 0,
"New epoch value must be greater than or equal to 0, got: " + std::to_string(epoch));
while (op_current_epochs_ < epoch) {
UpdateRepeatAndEpochCounter();
}
return Status::OK();
}
int64_t DatasetOp::GetTreeBatchSize() {
if (child_.size() == 1) {
return child_[0]->GetTreeBatchSize();
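
For orientation, SetEpoch fast-forwards an operator's repeat and epoch counters by replaying UpdateRepeatAndEpochCounter until the requested epoch is reached. A minimal Python sketch of that bookkeeping (the repeats_per_epoch parameter and the helper name are illustrative, not taken from the diff):

def set_epoch(target_epoch, repeats_per_epoch):
    # Mirrors DatasetOp::SetEpoch: refuse negative epochs, then replay the
    # counter update until the epoch counter reaches the target.
    if target_epoch < 0:
        raise ValueError("New epoch value must be greater than or equal to 0")
    repeats, epochs = 0, 0
    while epochs < target_epoch:
        repeats += 1                          # one repeat completed
        if repeats % repeats_per_epoch == 0:
            epochs += 1                       # a full set of repeats completes an epoch
    return repeats, epochs

# Resetting to epoch 3 with 2 repeats per epoch leaves the counters at (6, 3).
print(set_epoch(3, repeats_per_epoch=2))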

View File

@ -197,7 +197,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// \brief During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// \notes Derived versions of this function should always call it's superclass version first
// \notes Derived versions of this function should always call their superclass version first
// before providing their own implementations.
virtual Status PrepareOperator();
@ -213,6 +213,11 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// \return T/F if this is an inlined operator
bool inlined() const { return (oc_queue_size_ == 0); }
// \brief Set the epoch number for op manually. This is only used in reset mode.
// \param[in] epoch The new epoch number to restart the pipeline from
// \return - Status
Status SetEpoch(const int64_t epoch);
// \brief Setter function, set the number of total repeats for the operator
void SetTotalRepeats(int32_t total_repeats) { op_total_repeats_ = total_repeats; }

View File

@ -47,6 +47,19 @@ ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_con
shuffle_last_row_idx_(0),
shuffle_buffer_state_(kShuffleStateInit) {}
Status ShuffleOp::PrepareOperator() {
// Run any common code from super class first before adding our own
RETURN_IF_NOT_OK(DatasetOp::PrepareOperator());
// in reset mode, we need to advance the random generator seed to match the epoch we are resetting to
if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) {
for (auto i = 0; i < op_current_repeats_; i++) {
RETURN_IF_NOT_OK(SelfReset());
}
}
return Status::OK();
}
// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
// itself rather than waiting for the reset driven from operators above it in the pipeline.
Status ShuffleOp::SelfReset() {
@ -54,11 +67,11 @@ Status ShuffleOp::SelfReset() {
// If reshuffle_each_epoch is false, then we always use the same seed for every
// epoch.
// If reshuffle_each_epoch is true, then the first epoch uses the given seed,
// and all subsequent epochs will then keep on using the rng_ without resetting it
if (!reshuffle_each_epoch_) {
rng_ = std::mt19937_64(shuffle_seed_);
// and we increment the seed by one in all subsequent epochs
if (reshuffle_each_epoch_) {
shuffle_seed_++;
}
rng_ = std::mt19937_64(shuffle_seed_);
shuffle_buffer_ = std::make_unique<TensorTable>();
shuffle_last_row_idx_ = 0;
shuffle_buffer_state_ = kShuffleStateInit;
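
The reason replaying SelfReset works: with reshuffle_each_epoch_ set, every epoch bumps shuffle_seed_ by one and rebuilds rng_ from it, so calling SelfReset once per completed repeat in PrepareOperator leaves the generator where an uninterrupted run would have it. A simplified Python sketch under that assumption (it treats each epoch as one full shuffle rather than a streaming buffer):

import random

def self_reset(state):
    # ShuffleOp::SelfReset with reshuffle_each_epoch_=True: bump the seed, reseed the rng
    state["seed"] += 1
    state["rng"] = random.Random(state["seed"])

def epoch_order(state, rows):
    order = list(rows)
    state["rng"].shuffle(order)
    return order

rows = list(range(10))

# clean run: three epochs back to back
clean = {"seed": 1, "rng": random.Random(1)}
clean_orders = []
for _ in range(3):
    clean_orders.append(epoch_order(clean, rows))
    self_reset(clean)

# recovered run: fresh op, then one SelfReset per completed repeat (here 2)
recovered = {"seed": 1, "rng": random.Random(1)}
for _ in range(2):
    self_reset(recovered)
assert epoch_order(recovered, rows) == clean_orders[2]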

View File

@ -87,6 +87,13 @@ class ShuffleOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kShuffleOp; }
// \brief During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// \notes Derived versions of this function should always call their superclass version first
// before providing their own implementations.
// @return Status The status code returned
Status PrepareOperator() override;
private:
// Private function to add a new row to the shuffle buffer.
// @return Status The status code returned

View File

@ -50,6 +50,9 @@ Status SkipOp::GetNextRow(TensorRow *row) {
bool eoe_received = false;
while (skip_count_ < max_skips_) {
RETURN_IF_NOT_OK(child_[0]->GetNextRow(row));
if (row->eof()) {
return Status::OK();
}
if (row->eoe() && !once_only_) {
eoe_received = true;
break;

View File

@ -54,7 +54,10 @@ void GeneratorOp::Print(std::ostream &out, bool show_all) const {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status GeneratorOp::InitSampler() {
if (sampler_ != nullptr) {
return sampler_->HandshakeRandomAccessOp(this);
// Let the sampler know if we are resetting the pipeline to a specific epoch (op_current_repeats_ > 0)
// so it can reproduce the random state it would have reached at that point.
// Note that the repeat count is passed since the sampler may be reset multiple times within a single epoch.
return sampler_->HandshakeRandomAccessOp(this, op_current_repeats_);
}
return Status::OK();
}

View File

@ -68,7 +68,8 @@ Status MappableLeafOp::operator()() {
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
while (true) { // each iteration is 1 epoch, breaks when IsLastIteration() is true
while (true) { // each iteration is 1 repeat (usually =1 epoch, unless we have a repeat node above us), breaks when
// IsLastIteration() is true
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
ep_step = 0;
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
@ -114,8 +115,10 @@ Status MappableLeafOp::Reset() {
// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status MappableLeafOp::InitSampler() {
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK();
// Let the sampler know if we are resetting the pipeline to a specific epoch (op_current_repeats_ > 0)
// so it can reproduce the random state it would have reached at that point.
// Note that the repeat count is passed since the sampler may be reset multiple times within a single epoch.
return sampler_->HandshakeRandomAccessOp(this, op_current_repeats_);
}
// contains the main logic of pulling a IOBlock from IOBlockQueue, load a row and push the row to out_connector_

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
@ -39,7 +40,9 @@ NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connect
load_io_block_queue_(true),
shuffle_files_(shuffle_files),
num_rows_per_shard_(0),
num_rows_(0) {
num_rows_(0),
shuffled_keys_({}),
seed_(0) {
worker_connector_size_ = worker_connector_size;
}
@ -244,22 +247,15 @@ bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, i
return push;
}
void NonMappableLeafOp::ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
void NonMappableLeafOp::ShuffleKeys() {
std::mt19937 rng(num_devices_ == 1 ? GetSeed() : ++seed_);
std::shuffle(shuffled_keys_.begin(), shuffled_keys_.end(), rng);
}
Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
// must be called first if called by a worker spawned by the task group
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
@ -269,9 +265,27 @@ Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
ShuffleKeys();
}
RETURN_IF_NOT_OK(FillIOBlockQueue(shuffled_keys_));
}
return Status::OK();
}
Status NonMappableLeafOp::PrepareOperator() {
// Run any common code from super class first before adding our own
RETURN_IF_NOT_OK(DatasetOp::PrepareOperator());
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
shuffled_keys_.push_back(it.key());
}
// in reset mode, shuffled_keys_ needs to be reshuffled into the order it had in the epoch being reset to
if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) {
for (auto i = 0; i < op_current_repeats_; i++) {
ShuffleKeys();
}
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
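
The same replay idea applies to the file-level shuffle here: shuffled_keys_ is shuffled in place once per epoch (with GetSeed() on a single device, or an incrementing seed_ when sharded), so re-running ShuffleKeys once per completed repeat in PrepareOperator restores the file order of the epoch being reset to. A hedged Python sketch of the sharded path (names and sizes are illustrative):

import random

def shuffle_keys(keys, seed):
    random.Random(seed).shuffle(keys)     # in-place, like std::shuffle with a seeded rng

files = list(range(8))                    # stand-in for filename indices

# clean run: four epochs, seed incremented before each shuffle
clean, seed = list(files), 0
orders = []
for _ in range(4):
    seed += 1
    shuffle_keys(clean, seed)
    orders.append(list(clean))

# recovered run resetting into epoch 4: replay the three completed repeats first
recovered, seed = list(files), 0
for _ in range(3):
    seed += 1
    shuffle_keys(recovered, seed)
seed += 1
shuffle_keys(recovered, seed)
assert recovered == orders[3]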

View File

@ -79,6 +79,13 @@ class NonMappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>
// @return Name of the current Op
std::string Name() const override { return "NonMappableLeafOp"; }
// \brief During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// \notes Derived versions of this function should always call their superclass version first
// before providing their own implementations.
// @return Status The status code returned
Status PrepareOperator() override;
protected:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
@ -135,7 +142,7 @@ class NonMappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>
// @return Status - the error code returned.
virtual Status CalculateNumRowsPerShard() = 0;
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);
void ShuffleKeys();
// Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue
@ -159,6 +166,10 @@ class NonMappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>
bool shuffle_files_;
int64_t num_rows_per_shard_;
int64_t num_rows_;
private:
std::vector<int64_t> shuffled_keys_; // to store shuffled filename indices
uint32_t seed_; // used to shuffle filename indices
};
} // namespace dataset
} // namespace mindspore

View File

@ -157,8 +157,9 @@ Status DistributedSamplerRT::GetNextSample(TensorRow *out) {
return Status::OK();
}
Status DistributedSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_tensor_, "[Internal ERROR] Reset() Sampler called early or late.");
Status DistributedSamplerRT::ResetSampler(const bool failover_reset) {
CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || cnt_ == samples_per_tensor_,
"[Internal ERROR] ResetSampler() called early or late.");
cnt_ = 0;
if (shuffle_ == true) {
@ -168,7 +169,7 @@ Status DistributedSamplerRT::ResetSampler() {
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -55,9 +55,10 @@ class DistributedSamplerRT : public SamplerRT {
/// Init sampler, called by base class or python
Status InitSampler() override;
/// \brief for next epoch of sampleIds
/// \return Status code
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
int64_t GetDeviceID() { return device_id_; }

View File

@ -63,7 +63,7 @@ Status MindRecordSamplerRT::InitSampler() {
return Status::OK();
}
Status MindRecordSamplerRT::ResetSampler() {
Status MindRecordSamplerRT::ResetSampler(const bool failover_reset) {
// drive the shard reader reshuffle tasks to redo the sampling for another epoch
// Note that when cache is attached, this function is driven by cache lookup op rather than mindrecord op.
// Therefore, the reshuffle of tasks might happen in the middle of mindrecord's epoch

View File

@ -44,9 +44,10 @@ class MindRecordSamplerRT : public SamplerRT {
// meant to be called by base class or python
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
void SamplerPrint(std::ostream &out, bool show_all) const override;

View File

@ -105,22 +105,29 @@ Status PKSamplerRT::GetNextSample(TensorRow *out) {
return Status::OK();
}
Status PKSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late.");
Status PKSamplerRT::ResetSampler(const bool failover_reset) {
CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || next_id_ == num_samples_,
"[Internal ERROR] ResetSampler() called early or late.");
next_id_ = 0;
rnd_.seed(seed_++);
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();
}
Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
RETURN_UNEXPECTED_IF_NULL(op);
RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
RETURN_IF_NOT_OK(InitSampler());
// Move forward sampler's random generator if resetting the pipeline in fast_recovery mode
if (GlobalContext::config_manager()->fast_recovery()) {
for (auto i = 0; i < reset_count; i++) {
RETURN_IF_NOT_OK(ResetSampler(true)); // failover_reset = true
}
}
return Status::OK();
}

View File

@ -46,15 +46,17 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @param reset_count - reset the random generator this many times (used in fast_recovery mode of reset)
// @return
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0) override;
// init sampler, to be called by python or Handshake
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
// Printer for debugging purposes.
// @param out - output stream to write to

View File

@ -99,8 +99,11 @@ Status PythonSamplerRT::InitSampler() {
return Status::OK();
}
Status PythonSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "[Internal ERROR] Reset() Sampler called early or late.");
Status PythonSamplerRT::ResetSampler(const bool failover_reset) {
if (failover_reset) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "[Internal ERROR] ResetSampler() called early or late.");
need_to_reset_ = false;
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
@ -113,7 +116,7 @@ Status PythonSamplerRT::ResetSampler() {
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -40,9 +40,10 @@ class PythonSamplerRT : public SamplerRT {
// @return Status
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
// Op calls this to get next Sample that contains all the sampleIds
// @param TensorRow to be returned to corresponding Dataset Op

View File

@ -101,8 +101,9 @@ Status RandomSamplerRT::InitSampler() {
return Status::OK();
}
Status RandomSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late.");
Status RandomSamplerRT::ResetSampler(const bool failover_reset) {
CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || next_id_ == num_samples_,
"[Internal ERROR] ResetSampler() called early or late.");
next_id_ = 0;
if (reshuffle_each_epoch_) {
@ -116,7 +117,7 @@ Status RandomSamplerRT::ResetSampler() {
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -46,9 +46,10 @@ class RandomSamplerRT : public SamplerRT {
// meant to be called by base class or python
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
void SamplerPrint(std::ostream &out, bool show_all) const override;

View File

@ -41,7 +41,7 @@ SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor)
col_desc_(nullptr),
is_initialized(false) {}
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
RETURN_UNEXPECTED_IF_NULL(op);
std::shared_ptr<SamplerRT> child_sampler;
if (HasChildSampler()) {
@ -52,7 +52,7 @@ Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
// Handshake and init child first.
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op, reset_count));
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "[Internal ERROR] RandomAccessOp init failed, as it is nullptr.");
@ -67,7 +67,12 @@ Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
// It's up to the derived class to check the validity of the two args
// Because some sampler only needs one of the arg (weighted_random_sampler)
RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
// Move forward sampler's random generator if resetting the pipeline in fast_recovery mode
if (GlobalContext::config_manager()->fast_recovery()) {
for (auto i = 0; i < reset_count; i++) {
RETURN_IF_NOT_OK(ResetSampler(true)); // failover_reset = true
}
}
return Status::OK();
}
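
For the samplers, the handshake now fast-forwards the random state rather than restoring it: after InitSampler, ResetSampler(true) is invoked once per completed repeat, which is why every override must tolerate back-to-back calls. A toy Python sketch of the idea (ToySampler and its seeding scheme are assumptions, loosely modelled on the seed_++ pattern seen in PKSamplerRT):

import random

class ToySampler:
    def __init__(self, num_rows, seed=0):
        self.num_rows, self.seed = num_rows, seed
        self.rng = random.Random(self.seed)

    def reset(self):
        # ResetSampler(): reseed for the next epoch
        self.seed += 1
        self.rng = random.Random(self.seed)

    def epoch_ids(self):
        ids = list(range(self.num_rows))
        self.rng.shuffle(ids)
        return ids

# clean run: ids produced for four consecutive epochs
clean = ToySampler(10)
ids_per_epoch = [clean.epoch_ids()]
for _ in range(3):
    clean.reset()
    ids_per_epoch.append(clean.epoch_ids())

# recovered run: the handshake replays reset() once per completed repeat (here 2)
recovered = ToySampler(10)
for _ in range(2):
    recovered.reset()
assert recovered.epoch_ids() == ids_per_epoch[2]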

View File

@ -91,15 +91,20 @@ class SamplerRT {
Status GetAllIdsThenReset(py::array *data);
#endif
// for next epoch of sampleIds
// @return Status The status code returned
virtual Status ResetSampler() = 0;
/// \brief Reset for next epoch.
/// \note If failover_reset is set, any override of this function must support being called several times in a row
/// (to fast-forward the sampler to a specific epoch, including advancing any random generator's internal state).
/// \param[in] failover_reset - A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
virtual Status ResetSampler(const bool failover_reset = false) = 0;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @return
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
// @param reset_count - reset the random generator this many times (used in fast_recovery mode of reset)
// @return status error code
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0);
// initialize sampler and perform checks on certain vars
virtual Status InitSampler() { return Status::OK(); }

View File

@ -93,13 +93,14 @@ Status SequentialSamplerRT::InitSampler() {
return Status::OK();
}
Status SequentialSamplerRT::ResetSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late.");
Status SequentialSamplerRT::ResetSampler(const bool failover_reset) {
CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || id_count_ == num_samples_,
"[Internal ERROR] ResetSampler() called early or late.");
current_id_ = start_index_;
id_count_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -39,9 +39,10 @@ class SequentialSamplerRT : public SamplerRT {
// init sampler, called by python
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
// Op calls this to get next Sample that contains all the sampleIds
// @param TensorRow to be returned to corresponding Dataset Op

View File

@ -19,26 +19,30 @@
namespace mindspore {
namespace dataset {
Status SkipFirstEpochSamplerRT::ResetSampler() {
if (id_count_ != num_samples_) {
std::string err_msg =
"[Internal ERROR] ResetSampler() called early or late. id_count_: " + std::to_string(id_count_) +
" num_samples_: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
current_id_ = 0;
id_count_ = 0;
Status SkipFirstEpochSamplerRT::ResetSampler(const bool failover_reset) {
// This is a special sampler for Failover Reset, its internal state should
// not reset when failover_reset is set to true.
if (!failover_reset) {
if (id_count_ != num_samples_) {
std::string err_msg =
"[Internal ERROR] ResetSampler() called early or late. id_count_: " + std::to_string(id_count_) +
" num_samples_: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
current_id_ = 0;
id_count_ = 0;
if (!first_epoch_done_) {
num_samples_ += start_index_;
start_index_ = 0;
samples_per_tensor_ = num_samples_;
first_epoch_done_ = true;
if (!first_epoch_done_) {
num_samples_ += start_index_;
start_index_ = 0;
samples_per_tensor_ = num_samples_;
first_epoch_done_ = true;
}
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -30,9 +30,10 @@ class SkipFirstEpochSamplerRT : public SequentialSamplerRT {
// Destructor.
~SkipFirstEpochSamplerRT() = default;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
/// \brief Reset for next epoch.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
/// \brief Gets the number of samples available
/// \note Since this sampler returns different number of samples in the first epoch (compared to other epochs), this

View File

@ -46,12 +46,12 @@ Status SubsetRandomSamplerRT::InitSampler() {
}
// Reset the internal variable to the initial state.
Status SubsetRandomSamplerRT::ResetSampler() {
Status SubsetRandomSamplerRT::ResetSampler(const bool failover_reset) {
// Randomized the indices again.
rand_gen_.seed(GetSeed());
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
return SubsetSamplerRT::ResetSampler();
return SubsetSamplerRT::ResetSampler(failover_reset);
}
void SubsetRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {

View File

@ -43,9 +43,10 @@ class SubsetRandomSamplerRT : public SubsetSamplerRT {
/// \return Status
Status InitSampler() override;
/// Reset the internal variable to the initial state and reshuffle the indices.
/// \return Status
Status ResetSampler() override;
/// \brief Reset the internal variable(s) to the initial state and reshuffle the indices.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
/// Printer for debugging purposes.
/// \param out - output stream to write to

View File

@ -48,12 +48,12 @@ Status SubsetSamplerRT::InitSampler() {
}
// Reset the internal variable to the initial state.
Status SubsetSamplerRT::ResetSampler() {
Status SubsetSamplerRT::ResetSampler(const bool failover_reset) {
// Reset the internal counters.
sample_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -42,9 +42,10 @@ class SubsetSamplerRT : public SamplerRT {
/// \return Status
Status InitSampler() override;
/// Reset the internal variable to the initial state and reshuffle the indices.
/// \return Status
Status ResetSampler() override;
/// \brief Reset the internal variable(s) to the initial state and reshuffle the indices.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
/// Get the sample ids.
/// \param[out] TensorRow where the sample ids will be placed.

View File

@ -94,7 +94,7 @@ void WeightedRandomSamplerRT::InitOnePassSampling() {
}
// Reset the internal variable to the initial state and reshuffle the indices.
Status WeightedRandomSamplerRT::ResetSampler() {
Status WeightedRandomSamplerRT::ResetSampler(const bool failover_reset) {
sample_id_ = 0;
rand_gen_.seed(GetSeed());
if (!replacement_) {
@ -104,7 +104,7 @@ Status WeightedRandomSamplerRT::ResetSampler() {
}
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->ResetSampler());
RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
}
return Status::OK();

View File

@ -45,8 +45,10 @@ class WeightedRandomSamplerRT : public SamplerRT {
// @return Status
Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices.
Status ResetSampler() override;
/// \brief Reset the internal variable(s) to the initial state and reshuffle the indices.
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
/// \return Status The status code returned
Status ResetSampler(const bool failover_reset = false) override;
// Get the sample ids.
// @param[out] TensorRow where the sample ids will be placed.

View File

@ -87,7 +87,8 @@ Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Failed to inject SkipOp.");
int64_t dataset_size = -1;
RETURN_IF_NOT_OK(root_ir->GetDatasetSize(nullptr, false, &dataset_size));
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
RETURN_IF_NOT_OK(root_ir->GetDatasetSize(size_getter, false, &dataset_size));
CHECK_FAIL_RETURN_UNEXPECTED(dataset_size > 0, "Cannot reset the pipeline, dataset size is undefined");
int32_t num_epochs = finder.GetNumEpochs();
int64_t step = finder.GetStep();
@ -105,11 +106,7 @@ Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const
}
// in fast recovery, we start from the current epoch and skip the remaining steps (the skip node will also be pushed down)
if (GlobalContext::config_manager()->fast_recovery()) {
int32_t new_num_epochs = num_epochs - static_cast<int32_t>(step / dataset_size);
int64_t skip_num = step % dataset_size;
root_ir->SetNumEpochs(new_num_epochs);
auto skip_node = std::make_shared<SkipNode>(skip_num);
skip_node->SetOnceOnly(true);
RETURN_IF_NOT_OK(node->InsertAbove(skip_node));
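
The skip injection and the epoch handed to AdjustReset both come from the same arithmetic: the global step determines how many epochs were fully consumed and how many rows of the current epoch must be skipped. A small worked example (the numbers are made up):

def fast_recovery_plan(step, dataset_size, num_epochs):
    completed_epochs = step // dataset_size   # epochs already fully consumed
    skip_num = step % dataset_size            # rows of the current epoch to skip
    remaining_epochs = num_epochs - completed_epochs
    return completed_epochs, skip_num, remaining_epochs

# dataset of 10 rows, 5 epochs requested, failure at global step 23:
# restart in epoch 3 (2 completed), skip the first 3 rows, 3 epochs left to run.
print(fast_recovery_plan(step=23, dataset_size=10, num_epochs=5))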

View File

@ -170,7 +170,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std
return Status::OK();
}
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int64_t epoch_num) {
RETURN_UNEXPECTED_IF_NULL(root_ir);
// Create ExecutionTree
tree_ = std::make_unique<ExecutionTree>();
@ -180,6 +180,10 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
RETURN_IF_NOT_OK(BuildExecutionTreeRecur(root_ir->Children()[0], &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
if (usage_ == kDeReset) {
RETURN_IF_NOT_OK(AdjustReset(epoch_num));
}
// Prepare the tree
RETURN_IF_NOT_OK(tree_->Prepare());
@ -188,7 +192,8 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
return Status::OK();
}
Status TreeAdapter::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs, int64_t step) {
Status TreeAdapter::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs, int64_t step,
const int64_t epoch_num) {
RETURN_UNEXPECTED_IF_NULL(input_ir);
input_ir_ = input_ir;
tree_state_ = kCompileStateIRGraphBuilt;
@ -227,11 +232,21 @@ Status TreeAdapter::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_
// Remember the root node
root_ir_ = root_ir;
RETURN_IF_NOT_OK(Build(root_ir_));
RETURN_IF_NOT_OK(Build(root_ir_, epoch_num));
tree_state_ = kCompileStateReady;
return Status::OK();
}
Status TreeAdapter::AdjustReset(const int64_t epoch_num) {
if (GlobalContext::config_manager()->fast_recovery() && epoch_num > 0) {
MS_LOG(INFO) << "Adjusting dataset pipeline for failover reset to start on epoch: " << (epoch_num + 1);
for (auto op = tree_->begin(); op != tree_->end(); op++) {
RETURN_IF_NOT_OK(op->SetEpoch(epoch_num));
}
}
return Status::OK();
}
Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_UNEXPECTED_IF_NULL(row);

View File

@ -57,7 +57,8 @@ class TreeAdapter {
// This function performs syntax checking, semantics checking, optimizes, and then builds
// the Execution tree.
Status Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs = -1, int64_t step = 0);
Status Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs = -1, int64_t step = 0,
const int64_t epoch_num = 0);
// Return the root node of the IR after cloned from the parsed IR tree
std::shared_ptr<DatasetNode> RootIRNode() const { return root_ir_; }
@ -115,11 +116,14 @@ class TreeAdapter {
Status PostPass(std::shared_ptr<DatasetNode> ir);
// Build an Execution tree
Status Build(std::shared_ptr<DatasetNode> root_ir);
Status Build(std::shared_ptr<DatasetNode> root_ir, const int64_t epoch_num = 0);
// This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree.
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
// Adjust the pipeline (e.g., move rng_ forward) if in reset mode
Status AdjustReset(const int64_t epoch_num);
std::unordered_map<std::string, int32_t> column_name_map_;
std::shared_ptr<DatasetNode> input_ir_;
std::shared_ptr<DatasetNode> root_ir_;

View File

@ -115,16 +115,17 @@ def _get_training_dataset():
return _train_dataset
def _reset_training_dataset(step):
def _reset_training_dataset(step, epoch):
"""
Reset the training dataset to the given step number.
Reset the training dataset to the given step and epoch number.
Args:
step (int): Global step number.
epoch (int): Global epoch number.
"""
dataset = _get_training_dataset()
if dataset is not None:
dataset._reset(step) # pylint: disable=W0212
dataset._reset(step, epoch) # pylint: disable=W0212
else:
raise RuntimeError("Training dataset is not set.")

View File

@ -169,14 +169,15 @@ class Iterator:
self._col_names = self.__ori_dataset.get_col_names()
return self._col_names
def _reset(self, step):
def _reset(self, step, epoch):
"""
Reset the iterator to the given step number.
Reset the iterator to the given step number and epoch number.
Args:
step (int): Global step number.
step (int): Global step number
epoch (int): Global epoch number
"""
self._iterator.Reset(step)
self._iterator.Reset(step, epoch)
def _transform_md_to_output(self, t):
if self._output_numpy:
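
A hedged usage sketch of the updated private API (the recover helper is illustrative, and it assumes the iterator was already registered via _set_training_dataset as in the tests below): callers now pass the global step together with the epoch that step falls in, typically derived as step // dataset_size.

import mindspore.dataset as ds

def recover(data, failure_step):
    """Reset the registered training dataset to failure_step after a fault."""
    dataset_size = data.get_dataset_size()
    # pylint: disable=W0212
    ds.engine.datasets._reset_training_dataset(failure_step, failure_step // dataset_size)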

View File

@ -829,7 +829,7 @@ class Model:
os.remove(cb_params.latest_ckpt_file)
raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
+ cb_params.latest_ckpt_file) from e
_reset_training_dataset(cb_params.cur_step_num)
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
self.need_load_ckpt = False
def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
@ -858,9 +858,9 @@ class Model:
self.epoch_iter = recovery_epoch_num
cb_params.cur_epoch_num = self.epoch_iter + 1
cb_params.last_save_ckpt_step = cb_params.cur_step_num
_reset_training_dataset(cb_params.cur_step_num)
_reset_training_dataset(cb_params.cur_step_num, cb_params.cur_epoch_num)
else:
_reset_training_dataset(0)
_reset_training_dataset(0, 0)
_set_recovery_context(need_reset=False)

View File

@ -183,7 +183,7 @@ TEST_F(MindDataTestPipeline, TestCallShuffleTwice) {
uint32_t original_seed = config::get_seed();
uint32_t original_num_parallel_workers = config::get_num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
config::set_seed(654);
config::set_seed(655); // not all seeds satisfy the assertions in this test.
config::set_num_parallel_workers(1);
// Create a TextFile Dataset with single text file which has three samples

View File

@ -24,6 +24,10 @@ from util_minddataset import add_and_remove_cv_file
# pylint: disable=no-value-for-parameter
# Need to run all these tests in separate processes since MD internally stores
# "training" dataset in a global variable every time.
pytestmark = pytest.mark.forked
def create_np_dataset(size):
dimensions = (size, 4, 3, 2)
@ -77,14 +81,13 @@ def create_random_imagenet_dataset(repeat_size, sampler=None, num_parallel_worke
data = data.repeat(repeat_size)
crop_op1 = vision.RandomCrop(4)
operations = [vision.Decode(to_pil=to_pil), crop_op1]
if to_pil: # include a pyfunc in test if to_pil is True
if to_pil: # include a pyfunc in test if to_pil is True
operations.append(lambda x: x.rotate(45))
data = data.map(operations=operations, input_columns=[
"image"], num_parallel_workers=num_parallel_workers, python_multiprocessing=True)
if batch_func:
data = data.batch(
batch_size=2, per_batch_map=batch_func,
num_parallel_workers=num_parallel_workers, python_multiprocessing=True)
data = data.batch(batch_size=2, per_batch_map=batch_func, input_columns=["label"],
num_parallel_workers=num_parallel_workers, python_multiprocessing=True)
data = data.project(["image"])
return data
@ -98,7 +101,7 @@ def create_minddata_dataset(size):
return data
def run_reset(data, num_epochs, failure_point: int, reset_step: int):
def run_reset(data, num_epochs: int, failure_point: int):
size = data.get_dataset_size()
expected = []
expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
@ -107,50 +110,51 @@ def run_reset(data, num_epochs, failure_point: int, reset_step: int):
expected.append(d)
del expected_itr
actual_before_reset = []
expected2 = []
itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212
cur_step: int = 0
failed = False
for _ in range(num_epochs):
for d in itr:
actual_before_reset.append(d)
if cur_step == failure_point:
ds.engine.datasets._reset_training_dataset(reset_step) # pylint: disable=W0212
expected2.append(d)
if cur_step + 1 == failure_point:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // size)
failed = True
break
cur_step += 1
if failed:
break
actual_after_reset = []
if failed:
for _ in range(reset_step // size, num_epochs):
for _ in range(failure_point // size, num_epochs):
for d in itr:
actual_after_reset.append(d)
expected2.append(d)
with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
for _ in itr:
pass
expected2.append(d)
for x, y in zip(expected[:failure_point], actual_before_reset):
np.testing.assert_array_equal(x, y)
for x, y in zip(expected[reset_step:], actual_after_reset):
assert len(expected) == len(expected2)
for x, y in zip(expected, expected2):
np.testing.assert_array_equal(x, y)
def run_reset_error(data, num_epochs: int, failure_point: int):
itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) # pylint: disable=unused-variable
ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212
dataset_size = data.get_dataset_size()
if failure_point > 0:
with pytest.raises(RuntimeError) as err:
ds.engine.datasets._reset_training_dataset(failure_point) # pylint: disable=W0212
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
assert "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs." in str(err.value)
else:
with pytest.raises(RuntimeError) as err:
ds.engine.datasets._reset_training_dataset(failure_point) # pylint: disable=W0212
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point % dataset_size)
assert "Cannot reset the pipeline, reset step must be >= 0." in str(err.value)
@ -165,8 +169,7 @@ def test_reset_np():
failure_steps = (dataset_size * num_epochs) // 10
data = create_np_dataset(size=dataset_size)
for failure_point in range(0, dataset_size * num_epochs, failure_steps):
for reset_step in range(0, dataset_size * num_epochs, failure_steps):
run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
run_reset(data, num_epochs=num_epochs, failure_point=failure_point)
def test_reset_cifar1():
@ -180,8 +183,7 @@ def test_reset_cifar1():
failure_steps = (dataset_size * num_epochs) // 5
data = create_cifar_dataset1(size=dataset_size)
for failure_point in range(0, dataset_size * num_epochs, failure_steps):
for reset_step in range(0, dataset_size * num_epochs, failure_steps):
run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
run_reset(data, num_epochs=num_epochs, failure_point=failure_point)
def test_reset_cifar2():
@ -195,8 +197,7 @@ def test_reset_cifar2():
failure_steps = (dataset_size * num_epochs) // 5
data = create_cifar_dataset2(size=dataset_size)
for failure_point in range(0, dataset_size * num_epochs, failure_steps):
for reset_step in range(0, dataset_size * num_epochs, failure_steps):
run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
run_reset(data, num_epochs=num_epochs, failure_point=failure_point)
def test_reset_imagenet():
@ -210,8 +211,7 @@ def test_reset_imagenet():
failure_steps = (dataset_size * num_epochs) // 4
data = create_imagenet_dataset(size=dataset_size)
for failure_point in range(0, dataset_size * num_epochs, failure_steps):
for reset_step in range(0, dataset_size * num_epochs, failure_steps):
run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
run_reset(data, num_epochs=num_epochs, failure_point=failure_point)
def test_reset_mindrecord(add_and_remove_cv_file): # pylint: disable=unused-argument, redefined-outer-name
@ -225,8 +225,7 @@ def test_reset_mindrecord(add_and_remove_cv_file): # pylint: disable=unused-arg
failure_steps = (dataset_size * num_epochs) // 10
data = create_minddata_dataset(size=dataset_size)
for failure_point in range(0, dataset_size * num_epochs, failure_steps):
for reset_step in range(0, dataset_size * num_epochs, failure_steps):
run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
run_reset(data, num_epochs=num_epochs, failure_point=failure_point)
def test_reset_np_error():
@ -243,13 +242,13 @@ def test_reset_np_error():
run_reset_error(data, num_epochs=num_epochs, failure_point=failure_point)
def random_col(col1, col2, batch_info):
return ([np.random.rand(1) for a in col1], [np.random.rand(1) for b in col2])
def random_col(col1, batch_info):
return ([np.random.rand(1) for a in col1],)
@pytest.mark.parametrize("num_parallel_workers", (4, 5))
@pytest.mark.parametrize("sampler", (ds.RandomSampler(), None))
@pytest.mark.parametrize("to_pil, batch_func", [(False, None), (True, random_col)]) # test C ops and Python ops (MP)
@pytest.mark.parametrize("to_pil, batch_func", [(False, None), (True, random_col)]) # test C ops and Python ops (MP)
def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_func):
"""
Feature: Dataset recovery
@ -277,7 +276,7 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
del expected_itr
dataset_size = data.get_dataset_size()
# try different failure points
for failure_point in (5, 6, 19, 22):
for failure_point in (5, 6, 22):
expected2 = []
expected2_itr = data.create_tuple_iterator(
num_epochs=num_epochs, output_numpy=True)
@ -291,23 +290,22 @@ def test_repeatable_reset_imagenet(sampler, num_parallel_workers, to_pil, batch_
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point) # pylint: disable=W0212
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, failure_point // dataset_size)
failure = False
for d in expected2_itr:
expected2.append(d)
del expected2_itr
# verify count and values of failover with original run
assert len(expected) == len(expected2)
for a, b in zip(expected, expected2):
assert np.array_equal(a[0], b[0])
np.testing.assert_array_equal(expected, expected2)
ds.config.set_seed(original_seed)
ds.config.set_fast_recovery(original_fast_recovery)
ds.config.set_enable_shared_mem(original_shared_mem)
@pytest.mark.parametrize("to_pil", (False, True)) # test C ops and Python ops with MP=true
@pytest.mark.parametrize("to_pil", (False, True)) # test C ops and Python ops with MP=true
@pytest.mark.parametrize("num_parallel_workers", (4, 5))
@pytest.mark.parametrize("shard_id", (0, 1, 2, 3))
def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
@ -345,8 +343,7 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
# try different failure points
for failure_point in (3, 7, 9):
expected2 = []
expected2_itr = data.create_tuple_iterator(
num_epochs=num_epochs, output_numpy=True)
expected2_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(expected2_itr) # pylint: disable=W0212
failure = False
for epoch in range(num_epochs):
@ -356,21 +353,225 @@ def test_repeatable_reset_distributed(shard_id, num_parallel_workers, to_pil):
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point) # pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
failure = False
for d in expected2_itr:
expected2.append(d)
# verify count and values of failover with original run
assert len(expected) == len(expected2)
for a, b in zip(expected, expected2):
assert np.array_equal(a, b)
np.testing.assert_array_equal(expected, expected2)
ds.config.set_seed(original_seed)
ds.config.set_fast_recovery(original_fast_recovery)
ds.config.set_enable_shared_mem(original_shared_mem)
def test_reset_shuffle():
"""
Feature: Dataset recovery
Description: The random generator in the shuffle operation resets to the correct internal state
Expectation: Same dataset after reset
"""
original_seed = ds.config.get_seed()
original_fast_recovery = ds.config.get_fast_recovery()
ds.config.set_seed(1)
ds.config.set_fast_recovery(True)
source = [(np.array([x])) for x in range(10)]
data1 = ds.NumpySlicesDataset(source, ["data"], sampler=ds.SequentialSampler())
data1 = data1.shuffle(3)
data1 = data1.skip(1)
num_epochs = 3
expected = []
expected_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
for epoch in range(num_epochs):
for step, d in enumerate(expected_itr):
expected.append(d)
failure_point = 13
expected2 = []
expected2_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(expected2_itr) # pylint: disable=W0212
failure = False
for epoch in range(num_epochs):
for step, d in enumerate(expected2_itr):
expected2.append(d)
if epoch * data1.get_dataset_size() + step + 1 == failure_point:
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
failure = False
for step, d in enumerate(expected2_itr):
expected2.append(d)
with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
for step, d in enumerate(expected2_itr):
expected2.append(d)
np.testing.assert_array_equal(expected, expected2)
ds.config.set_seed(original_seed)
ds.config.set_fast_recovery(original_fast_recovery)
@pytest.mark.parametrize("sampler", (ds.RandomSampler(), ds.SequentialSampler()))
def test_reset_sampler(sampler):
"""
Feature: Dataset recovery
Description: The samplers for source operations reset to the correct internal state.
Expectation: Same dataset after reset
"""
original_seed = ds.config.get_seed()
original_fast_recovery = ds.config.get_fast_recovery()
ds.config.set_seed(1)
ds.config.set_fast_recovery(True)
source = [(np.array([x]),) for x in range(10)]
data1 = ds.NumpySlicesDataset(source, ["data"], sampler=sampler)
data1 = data1.skip(1)
data1 = data1.repeat(2)
data1 = data1.skip(1)
num_epochs = 3
expected_itr = data1.create_tuple_iterator(
num_epochs=num_epochs, output_numpy=True)
expected = []
for epoch in range(num_epochs):
for step, d in enumerate(expected_itr):
expected.append(d)
failure_point = 13
expected2 = []
expected2_itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(expected2_itr) # pylint: disable=W0212
failure = False
for epoch in range(num_epochs):
for step, d in enumerate(expected2_itr):
expected2.append(d)
if epoch * data1.get_dataset_size() + step + 1 == failure_point:
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
failure = False
for step, d in enumerate(expected2_itr):
expected2.append(d)
with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
for step, d in enumerate(expected2_itr):
expected2.append(d)
np.testing.assert_array_equal(expected, expected2)
ds.config.set_seed(original_seed)
ds.config.set_fast_recovery(original_fast_recovery)
@pytest.mark.parametrize("fast_recovery", (False, True))
def test_reset_batch(fast_recovery):
"""
Feature: Dataset recovery
Description: The BatchInfo argument of the batch operation contains the correct information (epoch num)
Expectation: Test succeeds
"""
original_fast_recovery = ds.config.get_fast_recovery()
ds.config.set_fast_recovery(fast_recovery)
num_epochs = 5
repeat_size = 4
skip_size = 12
def get_epoch_num(col1, batch_info):
return (np.array(batch_info.get_epoch_num()),)
data1 = ds.NumpySlicesDataset(np.arange(10).reshape(10, 1))
data1 = data1.repeat(repeat_size)
data1 = data1.skip(skip_size)
data1 = data1.batch(batch_size=7, per_batch_map=get_epoch_num, num_parallel_workers=1, python_multiprocessing=False)
itr = data1.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212
failure = False
failure_point = 25
expected = np.repeat(np.arange(5), 4).reshape((20, 1))
expected2 = []
for epoch in range(num_epochs):
for step, d in enumerate(itr):
expected2.append(d)
if epoch * data1.get_dataset_size() + step + 1 == failure_point:
failure = True
break
if failure:
ds.engine.datasets._reset_training_dataset(failure_point, epoch) # pylint: disable=W0212
failure = False
for step, d in enumerate(itr):
expected2.append(d)
with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
for d in itr:
expected2.append(d)
np.testing.assert_array_equal(expected, expected2)
ds.config.set_fast_recovery(original_fast_recovery)
def test_reset_nonmappable():
"""
Feature: Dataset recovery
Description: The order of rows read in normal and reset runs is identical for a TFRecord dataset.
Expectation: Test succeeds
"""
original_seed = ds.config.get_seed()
original_fast_recovery = ds.config.get_fast_recovery()
num_epochs = 10
num_repeats = 5
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
"../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
# run a pipeline and collect rows
def get_res(shard_id, num_repeats, failure_point):
ds.config.set_seed(1)
ds.config.set_fast_recovery(True)
data1 = ds.TFRecordDataset(tf_files, num_shards=4, shard_id=shard_id, num_samples=5, shuffle=ds.Shuffle.FILES)
data1 = data1.repeat(num_repeats)
itr = data1.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212
dataset_size = data1.get_dataset_size()
res = list()
failure = False
for epoch in range(num_epochs):
for step, item in enumerate(itr):
res.append(item["scalars"][0])
if epoch * dataset_size + step + 1 == failure_point:
failure = True
break
if failure:
# pylint: disable=W0212
ds.engine.datasets._reset_training_dataset(failure_point, (failure_point//dataset_size))
failure = False
# let's collect the remaining rows of this epoch
if failure_point % dataset_size != 0:
for step, item in enumerate(itr):
res.append(item["scalars"][0])
return res
shard_id = 0
expected = get_res(0, 5, -1) # no reset in this run
# try different failure points and compare against 'expected'
for failure_point in range(100):
expected2 = get_res(shard_id, num_repeats, failure_point)
np.testing.assert_array_equal(expected, expected2)
ds.config.set_seed(original_seed)
ds.config.set_fast_recovery(original_fast_recovery)
if __name__ == "__main__":
test_reset_np()
test_reset_cifar1()
@ -380,3 +581,7 @@ if __name__ == "__main__":
test_reset_np_error()
test_repeatable_reset_imagenet()
test_repeatable_reset_distributed()
test_reset_shuffle()
test_reset_sampler(ds.RandomSampler())
test_reset_batch(False)
test_reset_nonmappable()