forked from mindspore-Ecosystem/mindspore
!46971 [MD] fix code check warnings
Merge pull request !46971 from Mohammad Motallebi/fix_code_check_Dec
This commit is contained in:
commit
50fb77d84f
|
@ -54,7 +54,7 @@ Status TensorRow::Clone(TensorRow *new_tr) const {
|
||||||
for (const std::shared_ptr<Tensor> &s : row_) {
|
for (const std::shared_ptr<Tensor> &s : row_) {
|
||||||
std::shared_ptr<Tensor> d;
|
std::shared_ptr<Tensor> d;
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(s, &d));
|
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(s, &d));
|
||||||
new_tr->row_.emplace_back(std::move(d));
|
(void)new_tr->row_.emplace_back(std::move(d));
|
||||||
}
|
}
|
||||||
new_tr->id_ = id_;
|
new_tr->id_ = id_;
|
||||||
new_tr->path_ = path_;
|
new_tr->path_ = path_;
|
||||||
|
|
|
@ -40,7 +40,7 @@ Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
|
||||||
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
|
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
Status CacheLookupOp::ResetSampler(const bool failover_reset) { return Status::OK(); }
|
Status CacheLookupOp::ResetSampler([[maybe_unused]] const bool failover_reset) { return Status::OK(); }
|
||||||
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
|
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
|
||||||
RETURN_UNEXPECTED_IF_NULL(op);
|
RETURN_UNEXPECTED_IF_NULL(op);
|
||||||
// We act like a sampler and as a dataset op. During handshake with leaf op,
|
// We act like a sampler and as a dataset op. During handshake with leaf op,
|
||||||
|
|
|
@ -41,8 +41,8 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
|
||||||
Status operator()() override;
|
Status operator()() override;
|
||||||
Status WorkerEntry(int32_t worker_id) override;
|
Status WorkerEntry(int32_t worker_id) override;
|
||||||
// As a sampler, we override the following functions
|
// As a sampler, we override the following functions
|
||||||
Status ResetSampler(const bool failover_reset = false) override;
|
Status ResetSampler(const bool failover_reset) override;
|
||||||
Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0) override;
|
Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) override;
|
||||||
Status InitSampler() override;
|
Status InitSampler() override;
|
||||||
Status GetNextSample(TensorRow *out) override;
|
Status GetNextSample(TensorRow *out) override;
|
||||||
void Print(std::ostream &out, bool show_all) const override;
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
|
|
|
@ -49,8 +49,8 @@ class ParallelOp : public DatasetOp {
|
||||||
epoch_sync_flag_(false),
|
epoch_sync_flag_(false),
|
||||||
num_workers_(num_workers),
|
num_workers_(num_workers),
|
||||||
next_worker_id_(0),
|
next_worker_id_(0),
|
||||||
strategy_{nullptr},
|
worker_connector_size_(op_connector_size),
|
||||||
worker_connector_size_(op_connector_size) {
|
strategy_{nullptr} {
|
||||||
// reduce excessive memory usage with high parallelism
|
// reduce excessive memory usage with high parallelism
|
||||||
constexpr int32_t worker_limit = 4;
|
constexpr int32_t worker_limit = 4;
|
||||||
if (num_workers_ > worker_limit) {
|
if (num_workers_ > worker_limit) {
|
||||||
|
@ -168,17 +168,18 @@ class ParallelOp : public DatasetOp {
|
||||||
class RowHandlingStrategy {
|
class RowHandlingStrategy {
|
||||||
public:
|
public:
|
||||||
explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}
|
explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}
|
||||||
|
virtual ~RowHandlingStrategy() = default;
|
||||||
|
|
||||||
virtual Status HandleHealthyRow(TensorRow *row) {
|
virtual Status HandleHealthyRow([[maybe_unused]] TensorRow *row) {
|
||||||
++this->op_->ep_step_;
|
++this->op_->ep_step_;
|
||||||
++this->op_->total_step_;
|
++this->op_->total_step_;
|
||||||
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
|
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
|
||||||
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
|
static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
|
||||||
return this->op_->out_connector_->Add(std::move(*row));
|
return this->op_->out_connector_->Add(std::move(*row));
|
||||||
}
|
}
|
||||||
virtual Status HandleErrorRow(TensorRow *row) = 0;
|
virtual Status HandleErrorRow([[maybe_unused]] TensorRow *row) = 0;
|
||||||
|
|
||||||
virtual Status HandleEOE(TensorRow *row) {
|
virtual Status HandleEOE([[maybe_unused]] TensorRow *row) {
|
||||||
this->op_->current_repeats_++;
|
this->op_->current_repeats_++;
|
||||||
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
|
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
|
||||||
if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
|
if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
|
||||||
|
@ -189,9 +190,9 @@ class ParallelOp : public DatasetOp {
|
||||||
}
|
}
|
||||||
return op_->out_connector_->Add(std::move(*row));
|
return op_->out_connector_->Add(std::move(*row));
|
||||||
}
|
}
|
||||||
virtual Status HandleEOF(TensorRow *row) {
|
virtual Status HandleEOF([[maybe_unused]] TensorRow *row) {
|
||||||
RETURN_IF_NOT_OK(this->op_->callback_manager_.End(
|
RETURN_IF_NOT_OK(this->op_->callback_manager_.End(CallbackParam(
|
||||||
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
|
static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
|
||||||
return op_->out_connector_->Add(std::move(*row));
|
return op_->out_connector_->Add(std::move(*row));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,7 +203,7 @@ class ParallelOp : public DatasetOp {
|
||||||
class ErrorStrategy : public RowHandlingStrategy {
|
class ErrorStrategy : public RowHandlingStrategy {
|
||||||
public:
|
public:
|
||||||
using RowHandlingStrategy::RowHandlingStrategy;
|
using RowHandlingStrategy::RowHandlingStrategy;
|
||||||
Status HandleErrorRow(TensorRow *row) override {
|
Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
|
||||||
return Status(StatusCode::kMDUnexpectedError,
|
return Status(StatusCode::kMDUnexpectedError,
|
||||||
"[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
|
"[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
|
||||||
}
|
}
|
||||||
|
@ -211,14 +212,14 @@ class ParallelOp : public DatasetOp {
|
||||||
class SkipStrategy : public RowHandlingStrategy {
|
class SkipStrategy : public RowHandlingStrategy {
|
||||||
public:
|
public:
|
||||||
using RowHandlingStrategy::RowHandlingStrategy;
|
using RowHandlingStrategy::RowHandlingStrategy;
|
||||||
Status HandleErrorRow(TensorRow *row) override { return Status::OK(); }
|
Status HandleErrorRow([[maybe_unused]] TensorRow *row) override { return Status::OK(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
class ReplaceStrategy : public RowHandlingStrategy {
|
class ReplaceStrategy : public RowHandlingStrategy {
|
||||||
public:
|
public:
|
||||||
using RowHandlingStrategy::RowHandlingStrategy;
|
using RowHandlingStrategy::RowHandlingStrategy;
|
||||||
|
|
||||||
Status HandleHealthyRow(TensorRow *row) override {
|
Status HandleHealthyRow([[maybe_unused]] TensorRow *row) override {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
|
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
|
||||||
"[Internal Error] Number of cached rows is beyond the number set.");
|
"[Internal Error] Number of cached rows is beyond the number set.");
|
||||||
if (backup_index_ < kCachedRowsSize - 1) { // cache has used row(s) or is not full
|
if (backup_index_ < kCachedRowsSize - 1) { // cache has used row(s) or is not full
|
||||||
|
@ -239,12 +240,12 @@ class ParallelOp : public DatasetOp {
|
||||||
// send the healthy row to next op
|
// send the healthy row to next op
|
||||||
++this->op_->ep_step_;
|
++this->op_->ep_step_;
|
||||||
++this->op_->total_step_;
|
++this->op_->total_step_;
|
||||||
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
|
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
|
||||||
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
|
static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
|
||||||
return this->op_->out_connector_->Add(std::move(*row));
|
return this->op_->out_connector_->Add(std::move(*row));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HandleErrorRow(TensorRow *row) override {
|
Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
|
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
|
||||||
"[Internal Error] Number of cached rows is beyond the number set.");
|
"[Internal Error] Number of cached rows is beyond the number set.");
|
||||||
// cache is not full of unused rows
|
// cache is not full of unused rows
|
||||||
|
@ -256,7 +257,7 @@ class ParallelOp : public DatasetOp {
|
||||||
return AddFromCache();
|
return AddFromCache();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HandleEOE(TensorRow *row) override {
|
Status HandleEOE([[maybe_unused]] TensorRow *row) override {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
|
CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
|
||||||
"All data is garbage and cannot be replaced.");
|
"All data is garbage and cannot be replaced.");
|
||||||
// send outstanding rows first and then send eoe
|
// send outstanding rows first and then send eoe
|
||||||
|
@ -267,19 +268,23 @@ class ParallelOp : public DatasetOp {
|
||||||
return RowHandlingStrategy::HandleEOE(row);
|
return RowHandlingStrategy::HandleEOE(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status HandleEOF([[maybe_unused]] TensorRow *row) override {
|
||||||
|
// release memory
|
||||||
|
std::deque<TensorRow>().swap(backup_rows);
|
||||||
|
return RowHandlingStrategy::HandleEOF(row);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status AddFromCache() {
|
Status AddFromCache() {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
|
CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
|
||||||
// Note: If backup_index_ is negative (error samples at the end of data),
|
const TensorRow &cached_row = backup_rows[static_cast<size_t>(backup_index_) % backup_rows.size()];
|
||||||
// the modulo division with size_t will be result in 0, and backup_rows[0] accessed.
|
|
||||||
const TensorRow &cached_row = backup_rows[backup_index_ % backup_rows.size()];
|
|
||||||
TensorRow copy_row;
|
TensorRow copy_row;
|
||||||
RETURN_IF_NOT_OK(cached_row.Clone(©_row));
|
RETURN_IF_NOT_OK(cached_row.Clone(©_row));
|
||||||
backup_index_--;
|
backup_index_--;
|
||||||
++this->op_->ep_step_;
|
++this->op_->ep_step_;
|
||||||
++this->op_->total_step_;
|
++this->op_->total_step_;
|
||||||
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
|
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
|
||||||
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
|
static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
|
||||||
return this->op_->out_connector_->Add(std::move(copy_row));
|
return this->op_->out_connector_->Add(std::move(copy_row));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,7 +296,7 @@ class ParallelOp : public DatasetOp {
|
||||||
"[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
|
"[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
|
||||||
TensorRow copy_row;
|
TensorRow copy_row;
|
||||||
RETURN_IF_NOT_OK(row.Clone(©_row));
|
RETURN_IF_NOT_OK(row.Clone(©_row));
|
||||||
backup_rows.emplace_front(std::move(copy_row));
|
(void)backup_rows.emplace_front(std::move(copy_row));
|
||||||
backup_index_++;
|
backup_index_++;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ Status ShuffleOp::PrepareOperator() {
|
||||||
// in reset mode, we need to move forward the random generator seed.
|
// in reset mode, we need to move forward the random generator seed.
|
||||||
if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) {
|
if (GlobalContext::config_manager()->fast_recovery() && op_current_repeats_ > 0) {
|
||||||
for (auto i = 0; i < op_current_repeats_; i++) {
|
for (auto i = 0; i < op_current_repeats_; i++) {
|
||||||
SelfReset();
|
RETURN_IF_NOT_OK(SelfReset());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -68,8 +68,8 @@ Status MappableLeafOp::operator()() {
|
||||||
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
|
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
|
||||||
TensorRow sample_row;
|
TensorRow sample_row;
|
||||||
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
|
||||||
while (true) { // each iteration is 1 repeat (usually =1 epoch, unless we have a repeat node above us), breaks when
|
for (;;) { // each iteration is 1 repeat (usually =1 epoch, unless we have a repeat node above us), breaks when
|
||||||
// IsLastIteration() is true
|
// IsLastIteration() is true
|
||||||
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
|
if (op_current_repeats_ % GetOpNumRepeatsPerEpoch() == 0) {
|
||||||
ep_step = 0;
|
ep_step = 0;
|
||||||
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||||
|
|
|
@ -41,8 +41,8 @@ NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connect
|
||||||
load_io_block_queue_(true),
|
load_io_block_queue_(true),
|
||||||
shuffle_files_(shuffle_files),
|
shuffle_files_(shuffle_files),
|
||||||
num_rows_per_shard_(0),
|
num_rows_per_shard_(0),
|
||||||
num_rows_(0),
|
|
||||||
compression_type_(compression_type),
|
compression_type_(compression_type),
|
||||||
|
num_rows_(0),
|
||||||
shuffled_keys_({}),
|
shuffled_keys_({}),
|
||||||
seed_(0) {
|
seed_(0) {
|
||||||
worker_connector_size_ = worker_connector_size;
|
worker_connector_size_ = worker_connector_size;
|
||||||
|
|
|
@ -58,7 +58,7 @@ class DistributedSamplerRT : public SamplerRT {
|
||||||
/// \brief Reset for next epoch.
|
/// \brief Reset for next epoch.
|
||||||
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
||||||
/// \return Status The status code returned
|
/// \return Status The status code returned
|
||||||
Status ResetSampler(const bool failover_reset = false) override;
|
Status ResetSampler(const bool failover_reset) override;
|
||||||
|
|
||||||
int64_t GetDeviceID() { return device_id_; }
|
int64_t GetDeviceID() { return device_id_; }
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ class PythonSamplerRT : public SamplerRT {
|
||||||
/// \brief Reset for next epoch.
|
/// \brief Reset for next epoch.
|
||||||
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
||||||
/// \return Status The status code returned
|
/// \return Status The status code returned
|
||||||
Status ResetSampler(const bool failover_reset = false) override;
|
Status ResetSampler(const bool failover_reset) override;
|
||||||
|
|
||||||
// Op calls this to get next Sample that contains all the sampleIds
|
// Op calls this to get next Sample that contains all the sampleIds
|
||||||
// @param TensorRow to be returned to corresponding Dataset Op
|
// @param TensorRow to be returned to corresponding Dataset Op
|
||||||
|
|
|
@ -42,7 +42,7 @@ class SequentialSamplerRT : public SamplerRT {
|
||||||
/// \brief Reset for next epoch.
|
/// \brief Reset for next epoch.
|
||||||
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
/// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
|
||||||
/// \return Status The status code returned
|
/// \return Status The status code returned
|
||||||
Status ResetSampler(const bool failover_reset = false) override;
|
Status ResetSampler(const bool failover_reset) override;
|
||||||
|
|
||||||
// Op calls this to get next Sample that contains all the sampleIds
|
// Op calls this to get next Sample that contains all the sampleIds
|
||||||
// @param TensorRow to be returned to corresponding Dataset Op
|
// @param TensorRow to be returned to corresponding Dataset Op
|
||||||
|
|
|
@ -240,8 +240,8 @@ Status TreeAdapter::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_
|
||||||
Status TreeAdapter::AdjustReset(const int64_t epoch_num) {
|
Status TreeAdapter::AdjustReset(const int64_t epoch_num) {
|
||||||
if (GlobalContext::config_manager()->fast_recovery() && epoch_num > 0) {
|
if (GlobalContext::config_manager()->fast_recovery() && epoch_num > 0) {
|
||||||
MS_LOG(INFO) << "Adjusting dataset pipeline for failover reset to start on epoch: " << (epoch_num + 1);
|
MS_LOG(INFO) << "Adjusting dataset pipeline for failover reset to start on epoch: " << (epoch_num + 1);
|
||||||
for (auto op = tree_->begin(); op != tree_->end(); op++) {
|
for (auto op = tree_->begin(); op != tree_->end(); ++op) {
|
||||||
op->SetEpoch(epoch_num);
|
RETURN_IF_NOT_OK(op->SetEpoch(epoch_num));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
||||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_DATASETS_H_
|
||||||
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
|
@ -124,7 +124,7 @@ def _reset_training_dataset(step, epoch):
|
||||||
"""
|
"""
|
||||||
dataset = _get_training_dataset()
|
dataset = _get_training_dataset()
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
dataset._reset(step, epoch) # pylint: disable=W0212
|
dataset._reset(step, epoch) # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Training dataset is not set.")
|
raise RuntimeError("Training dataset is not set.")
|
||||||
|
|
||||||
|
@ -3698,9 +3698,6 @@ class _ToDevice:
|
||||||
def send(self):
|
def send(self):
|
||||||
self._to_device.Send()
|
self._to_device.Send()
|
||||||
|
|
||||||
def _reset(self, step, epoch):
|
|
||||||
self._to_device.Reset(step, epoch)
|
|
||||||
|
|
||||||
def stop_send(self):
|
def stop_send(self):
|
||||||
"""
|
"""
|
||||||
send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end.
|
send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end.
|
||||||
|
@ -3739,6 +3736,9 @@ class _ToDevice:
|
||||||
offload_model = GetOffloadModel(self._to_device, col_names)
|
offload_model = GetOffloadModel(self._to_device, col_names)
|
||||||
return offload_model
|
return offload_model
|
||||||
|
|
||||||
|
def _reset(self, step, epoch):
|
||||||
|
self._to_device.Reset(step, epoch)
|
||||||
|
|
||||||
|
|
||||||
class TransferDataset(Dataset):
|
class TransferDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
@ -3809,11 +3809,6 @@ class TransferDataset(Dataset):
|
||||||
if self._to_device is not None:
|
if self._to_device is not None:
|
||||||
self._to_device.continue_send()
|
self._to_device.continue_send()
|
||||||
|
|
||||||
def _reset(self, step, epoch):
|
|
||||||
if self._to_device is not None:
|
|
||||||
logger.info("Reset the dataset pipeline to step: " + str(step) + ", epoch: " + str(epoch))
|
|
||||||
self._to_device._reset(step, epoch) # pylint: disable=W0212
|
|
||||||
|
|
||||||
def get_data_info(self):
|
def get_data_info(self):
|
||||||
"""
|
"""
|
||||||
Get type and shape of current batch
|
Get type and shape of current batch
|
||||||
|
@ -3835,6 +3830,11 @@ class TransferDataset(Dataset):
|
||||||
if self._to_device is not None:
|
if self._to_device is not None:
|
||||||
self._to_device.release()
|
self._to_device.release()
|
||||||
|
|
||||||
|
def _reset(self, step, epoch):
|
||||||
|
if self._to_device is not None:
|
||||||
|
logger.info("Reset the dataset pipeline to step: " + str(step) + ", epoch: " + str(epoch))
|
||||||
|
self._to_device._reset(step, epoch) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
class Schema:
|
class Schema:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue