!46291 [MD] Process Error Samples Replace and Skip Support for Map op
Merge pull request !46291 from cathwong/ckw_error_samples
commit 80bd3c0f86
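For orientation before the diff: a minimal usage sketch of the feature this PR adds (the dataset path is hypothetical; the API names are the ones introduced below):

import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# Decide how Map should treat samples whose transforms raise an error:
# RETURN (default: raise), REPLACE (substitute an internally cached healthy
# sample), or SKIP (drop the sample).
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)

data = ds.ImageFolderDataset("/path/to/images")  # hypothetical path
data = data.map(operations=[vision.Decode()], input_columns=["image"])
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    pass  # erroneous samples are replaced instead of failing the pipeline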
@@ -79,6 +79,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
      .def("get_fast_recovery", &ConfigManager::fast_recovery)
      .def("set_debug_mode", &ConfigManager::set_debug_mode)
      .def("get_debug_mode", &ConfigManager::get_debug_mode)
      .def("set_error_samples_mode", &ConfigManager::set_error_samples_mode)
      .def("get_error_samples_mode", &ConfigManager::get_error_samples_mode)
      .def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
  }));
@@ -199,5 +201,13 @@ PYBIND_REGISTER(ImageReadMode, 0, ([](const py::module *m) {
      .value("DE_IMAGE_READ_MODE_COLOR", ImageReadMode::kCOLOR)
      .export_values();
  }));

PYBIND_REGISTER(ErrorSamplesMode, 0, ([](const py::module *m) {
  (void)py::enum_<ErrorSamplesMode>(*m, "ErrorSamplesMode", py::arithmetic())
    .value("DE_ERROR_SAMPLES_MODE_RETURN", ErrorSamplesMode::kReturn)
    .value("DE_ERROR_SAMPLES_MODE_REPLACE", ErrorSamplesMode::kReplace)
    .value("DE_ERROR_SAMPLES_MODE_SKIP", ErrorSamplesMode::kSkip)
    .export_values();
}));
}  // namespace dataset
}  // namespace mindspore
@@ -89,6 +89,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
  set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
  set_seed(j.value("seed", seed_));
  set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
  set_fast_recovery(j.value("fast_recovery", fast_recovery_));
  set_error_samples_mode(j.value("error_samples_mode", error_samples_mode_));
  set_cache_host(j.value("cacheHost", cache_host_));
  set_cache_port(j.value("cachePort", cache_port_));
  set_num_connections(j.value("numConnections", num_connections_));
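FromJson above makes `error_samples_mode` loadable from a pipeline config file alongside `fast_recovery`. A minimal sketch of exercising it through the Python API (file name hypothetical; keys as in the updated declient.cfg further down, where the integer 1 maps to kReplace; a partial key set should be tolerated, as the j.value(...) defaults suggest):

import json
import mindspore.dataset as ds

cfg = {
    "opConnectorSize": 16,
    "seed": 5489,
    "monitorSamplingInterval": 15,
    "fast_recovery": True,
    "debug_mode_flag": True,
    "error_samples_mode": 1  # 1 -> ErrorSamplesMode::kReplace
}
with open("my_declient.cfg", "w") as f:  # hypothetical file name
    json.dump(cfg, f)
ds.config.load("my_declient.cfg")
assert ds.config.get_error_samples_mode() == ds.config.ErrorSamplesMode.REPLACE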
@@ -38,6 +38,7 @@ namespace mindspore {
namespace dataset {
const char kEmptyString[] = "";
const char kJsonExtension[] = ".json";

// The ConfigManager is a class for managing default values. When a user is constructing any objects
// in the framework, often they may choose to omit some settings instead of overriding them.
// This class manages some of the default values, for cases when the user does not manually specify
@@ -285,7 +286,8 @@ class ConfigManager {

  // setter function
  // @notes User must also set the seed to be able to get same augmentations
  // @notes Fast recovery can cause slightly different random augmentations than original run (default=true)
  // @notes Fast recovery can cause slightly different random augmentations than original run
  // (System default = true)
  // @param fast_recovery - Set whether MD pipeline recovers fast in failover reset
  void set_fast_recovery(const bool fast_recovery) { fast_recovery_ = fast_recovery; }
@@ -301,6 +303,22 @@ class ConfigManager {
  // @return - Flag to indicate whether the debug mode is on
  bool get_debug_mode() const { return debug_mode_flag_; }

  // setter function
  // @param error_samples_mode - Set the method in which erroneous samples should be processed
  // (System default = ErrorSamplesMode::kReturn)
  // @notes For replacement of erroneous samples, MD will select a deterministic but "random" sample.
  void set_error_samples_mode(const ErrorSamplesMode error_samples_mode) { error_samples_mode_ = error_samples_mode; }

  // getter function
  // @return - The method in which erroneous samples should be processed in a dataset pipeline
  // @notes This method is used for external configuration API which returns integer type
  int32_t get_error_samples_mode() const { return static_cast<int>(error_samples_mode_); }

  // getter function
  // @return - The method in which erroneous samples should be processed in a dataset pipeline
  // @notes This method is used for internal processing, using enum type
  ErrorSamplesMode error_samples_mode() const { return error_samples_mode_; }

 private:
  // Private helper function that takes a nlohmann json format and populates the settings
  // @param j - The json nlohmann json info
@@ -337,6 +355,7 @@ class ConfigManager {
  bool dynamic_shape_{false};
  bool fast_recovery_{true};  // Used for failover scenario to recover quickly or produce same augmentations
  bool debug_mode_flag_{false};  // Indicator for debug mode
  ErrorSamplesMode error_samples_mode_{ErrorSamplesMode::kReturn};  // The method to process erroneous samples
};
}  // namespace dataset
}  // namespace mindspore
@@ -41,7 +41,9 @@ class TensorRow {
    kFlagEOE = 1u << 1,   // The row is an eoe end-of-epoch msg
    kFlagWait = 1u << 2,  // The row is a control signal for workers to suspend operations
    kFlagQuit = 1u << 3,  // The row is a control signal for workers to quit
    kFlagSkip = 1u << 4   // The row is a control signal for workers to skip this row
    kFlagSkip = 1u << 4,  // The row is a control signal for workers to skip this row
    kFlagError = 1u << 5  // The row is an error row (needs to be replaced with another row or skipped, as per
                          // ErrorSamplesMode config)
  };

  // Type definitions
@@ -235,6 +237,10 @@ class TensorRow {
    return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagSkip));
  }

  bool error() const {
    return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagError));
  }

  const TensorRowFlags Flags() { return tensor_row_flag_; }

  explicit TensorRow(TensorRowFlags flag);
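The new error() accessor mirrors the existing skip() check: a single-bit test against the flag word. A tiny Python sketch of the same arithmetic (flag values taken from the enum above):

kFlagSkip = 1 << 4
kFlagError = 1 << 5

def is_error(flags):
    # same idea as TensorRow::error(): test one bit of the flag word
    return bool(flags & kFlagError)

assert is_error(kFlagError | kFlagSkip)
assert not is_error(kFlagSkip)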
@@ -24,6 +24,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor_row.h"

#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
@@ -256,7 +257,23 @@ Status MapOp::WorkerCompute(const TensorRow &in_row, TensorRow *out_row,
  for (size_t i = 0; i < job_list.size(); i++) {
    RETURN_IF_INTERRUPTED();
    // Execute MapWorkerJob.
    RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
    Status rc = job_list[i]->Run(job_input_table, &result_table);
    if (rc.IsError()) {
      if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
        MS_LOG(WARNING)
          << "Detected an erroneous sample in MindData Map operation, and will replace with a healthy sample: " +
               rc.GetErrDescription();
        *out_row = TensorRow(TensorRow::kFlagError);
        return Status::OK();
      } else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
        MS_LOG(WARNING) << "Detected an erroneous sample in MindData Map operation, and will skip this sample: " +
                             rc.GetErrDescription();
        *out_row = TensorRow(TensorRow::kFlagError);
        return Status::OK();
      } else {
        return rc;
      }
    }
    // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
    if (i + 1 < job_list.size()) {
      job_input_table = std::move(result_table);
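In short, the added branch consults the global error-samples mode when a job fails: under REPLACE or SKIP it emits a row tagged kFlagError and reports success (the collector downstream decides the row's fate); otherwise it propagates the error. A condensed Python sketch of that decision (names are illustrative, not the C++ API):

from enum import IntEnum

class ErrorMode(IntEnum):
    RETURN = 0
    REPLACE = 1
    SKIP = 2

def worker_compute(run_job, mode):
    """Sketch of the branch added to MapOp::WorkerCompute (illustrative names)."""
    ok, err = run_job()  # run_job returns (success flag, error description)
    if not ok:
        if mode in (ErrorMode.REPLACE, ErrorMode.SKIP):
            # Tag the output row with kFlagError and report success;
            # the collector decides whether to replace or drop it.
            return "error-flagged row"
        raise RuntimeError(err)  # RETURN mode: surface the failure
    return "healthy row"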
@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_

#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <string>
@@ -30,6 +31,7 @@

namespace mindspore {
namespace dataset {
constexpr int64_t kCachedRowsSize = 16;

class ExecutionTree;
@@ -162,16 +164,155 @@ class ParallelOp : public DatasetOp {
    return Status::OK();
  }

  class RowHandlingStrategy {
   public:
    explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}

    virtual Status HandleHealthyRow(const TensorRow &row) {
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
        CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(row));
    }
    virtual Status HandleErrorRow(const TensorRow &row) = 0;

    virtual Status HandleEOE(const TensorRow &row) {
      this->op_->current_repeats_++;
      // check whether this is the end of a real epoch (not all eoe signals end of epoch)
      if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
        this->op_->current_epochs_++;
        RETURN_IF_NOT_OK(this->op_->callback_manager_.EpochEnd(
          CallbackParam(this->op_->current_epochs_, this->op_->ep_step_, this->op_->total_step_)));
        this->op_->ep_step_ = 0;
      }
      return op_->out_connector_->Add(std::move(row));
    }
    virtual Status HandleEOF(const TensorRow &row) {
      RETURN_IF_NOT_OK(this->op_->callback_manager_.End(
        CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
      return op_->out_connector_->Add(std::move(row));
    }

   protected:
    ParallelOp *op_;
  };

  class ErrorStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;
    Status HandleErrorRow(const TensorRow &row) override {
      return Status(StatusCode::kMDUnexpectedError,
                    "[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
    }
  };

  class SkipStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;
    Status HandleErrorRow(const TensorRow &row) override { return Status::OK(); }
  };

  class ReplaceStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;

    Status HandleHealthyRow(const TensorRow &row) override {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
                                   "[Internal Error] Number of cached rows is beyond the number set.");
      if (backup_index_ < kCachedRowsSize - 1) {  // cache has used row(s)
        if (IsCacheFull()) {
          // remove the last element from cache (a used row)
          PopFromCache();
        }
        AddToCache(row);
      } else {  // cache is full of unused rows
        if (missing_errors_ > 0) {
          // send a cached row to next op and cache the current row
          RETURN_IF_NOT_OK(AddFromCache());
          missing_errors_--;
          AddToCache(row);
        }
      }
      // send the healthy row to next op
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
        CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(row));
    }

    Status HandleErrorRow(const TensorRow &row) override {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
                                   "[Internal Error] Number of cached rows is beyond the number set.");
      // cache is not full of unused rows
      if (backup_index_ != kCachedRowsSize - 1) {
        missing_errors_++;
        return Status::OK();
      }
      // cache is full of unused rows and we have an error row
      return AddFromCache();
    }

    Status HandleEOE(const TensorRow &row) {
      CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
                                   "All data is garbage and cannot be replaced.");
      // send outstanding rows first and then send eoe
      while (missing_errors_ > 0) {
        RETURN_IF_NOT_OK(AddFromCache());
        missing_errors_--;
      }
      return RowHandlingStrategy::HandleEOE(row);
    }

   private:
    Status AddFromCache() {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
      // Note: If backup_index_ is negative (error samples at the end of data),
      // the modulo division with size_t will result in 0, and backup_rows[0] is accessed.
      auto cached_row = backup_rows[backup_index_ % backup_rows.size()];
      backup_index_--;
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
        CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(cached_row));
    }

    Status AddToCache(TensorRow row) {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() < kCachedRowsSize,
                                   "[Internal Error] Inserting another row to cache while cache is already full.");
      CHECK_FAIL_RETURN_UNEXPECTED(
        backup_index_ < kCachedRowsSize - 1,
        "[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
      backup_rows.push_front(row);
      backup_index_++;
      return Status::OK();
    }

    void PopFromCache() { backup_rows.pop_back(); }
    bool IsCacheFull() const { return backup_rows.size() == kCachedRowsSize; }
    bool IsCacheEmpty() const { return backup_rows.size() == 0; }
    std::deque<TensorRow> backup_rows{};  // will hold a copy of some healthy rows collected (NOT error, skip, eoe, eof)
    int32_t backup_index_{-1};  // index of the backup we should pick next time (can be negative if we run out of
                                // unused cached rows)
    int32_t missing_errors_{0};  // the number of unaddressed error rows (that we need to send a replacement to output)
  };

  virtual Status Collector() {
    TaskManager::FindMe()->Post();
    // num_rows received, including eoe, num_step of current epoch
    int64_t num_rows = 0, ep_step = 0, total_step = 0;
    int32_t current_repeats = 0, current_epochs = 0;
    // num_rows received, including eoe,
    int64_t num_rows = 0;
    current_repeats_ = 0;
    current_epochs_ = 0;
    SetStrategy();
    // num_step of current epoch and the total
    ep_step_ = 0, total_step_ = 0;
    TensorRow row;
    do {
      RETURN_IF_NOT_OK(worker_out_queues_[static_cast<const int>(num_rows++ % num_workers_)]->PopFront(&row));
      if (row.wait()) {
        // When collector receives the signal from workere thread, it increments a atomic int
        // When collector receives the signal from worker thread, it increments an atomic int
        // If num_worker signals are received, wakes up the main thread
        if (++num_workers_paused_ == num_workers_) {
          wait_for_workers_post_.Set();
@@ -179,23 +320,16 @@ class ParallelOp : public DatasetOp {
        }
        continue;
      } else if (row.eoe()) {
        current_repeats++;
        // check whether this is the end of a real epoch (not all eoe signals end of epoch)
        if (current_repeats % GetOpNumRepeatsPerEpoch() == 0) {
          current_epochs++;
          RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(current_epochs, ep_step, total_step)));
          ep_step = 0;
        }
        RETURN_IF_NOT_OK(strategy_->HandleEOE(row));
      } else if (row.eof()) {
        RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(current_epochs + 1, ep_step, total_step)));
        RETURN_IF_NOT_OK(strategy_->HandleEOF(row));
      } else if (row.skip()) {
        continue;
      } else if (row.error()) {
        RETURN_IF_NOT_OK(strategy_->HandleErrorRow(row));
      } else if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
        ++ep_step;
        ++total_step;
        RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(current_epochs + 1, ep_step, total_step)));
        RETURN_IF_NOT_OK(strategy_->HandleHealthyRow(row));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
    } while (!row.eof());
    return Status::OK();
  }
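The ReplaceStrategy above is the heart of the replace mode: it keeps up to kCachedRowsSize (16) healthy rows, counts error rows seen before the cache fills in missing_errors_, and pays that debt back from the cache, flushing the remainder at end of epoch. A small self-contained Python model of that bookkeeping, handy for predicting the expected lists in the new tests (a sketch under the assumption that AddToCache failures are silently discarded, as the call sites above suggest; Python's negative modulo differs from the C++ size_t conversion but agrees for the cases exercised here):

from collections import deque

CACHE = 16  # kCachedRowsSize

class ReplaceModel:
    def __init__(self):
        self.rows = deque()  # backup_rows: healthy-row cache, newest at the front
        self.idx = -1        # backup_index_: next cache slot to replay
        self.missing = 0     # missing_errors_: replacements still owed
        self.out = []        # rows sent to the output connector

    def _from_cache(self):
        self.out.append(self.rows[self.idx % len(self.rows)])
        self.idx -= 1

    def _to_cache(self, row):
        if len(self.rows) >= CACHE or self.idx >= CACHE - 1:
            return  # in C++ the CHECK fails and the status is discarded
        self.rows.appendleft(row)
        self.idx += 1

    def healthy(self, row):
        if self.idx < CACHE - 1:   # cache still has used slots
            if len(self.rows) == CACHE:
                self.rows.pop()    # drop the oldest, already-used row
            self._to_cache(row)
        elif self.missing > 0:     # cache is full of unused rows
            self._from_cache()     # pay back one owed replacement
            self.missing -= 1
            self._to_cache(row)    # no-op while the cache is still full
        self.out.append(row)

    def error(self):
        if self.idx != CACHE - 1:
            self.missing += 1      # pay it back later
        else:
            self._from_cache()     # replace immediately from the cache

    def eoe(self):
        while self.missing > 0:    # settle the debt before the epoch ends
            self._from_cache()
            self.missing -= 1

m = ReplaceModel()
m.error()      # first sample raises
m.healthy(1)   # second sample is fine
m.eoe()
assert m.out == [1, 1]  # same as run_replace_test(raise_first, 2, ...) in the new tests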
@@ -223,6 +357,17 @@ class ParallelOp : public DatasetOp {
 public:
  int32_t NumWorkers() override { return num_workers_; }

 private:
  void SetStrategy() {
    if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
      strategy_ = std::make_shared<SkipStrategy>(this);
    } else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
      strategy_ = std::make_shared<ReplaceStrategy>(this);
    } else {
      strategy_ = std::make_shared<ErrorStrategy>(this);
    }
  }

 protected:
  std::atomic_int next_worker_id_;
@@ -234,6 +379,13 @@ class ParallelOp : public DatasetOp {
  QueueList<T> worker_in_queues_;
  /// queues to hold the output from workers
  QueueList<S> worker_out_queues_;

 private:
  std::shared_ptr<RowHandlingStrategy> strategy_;
  int32_t ep_step_{0};
  int32_t total_step_{0};
  int32_t current_epochs_{0};
  int32_t current_repeats_{0};
};
}  // namespace dataset
}  // namespace mindspore
@@ -291,6 +291,13 @@ enum class DATASET_API ResampleMethod {
  kKaiserWindow = 1,  ///< Resample audio by Kaiser window
};

/// \brief Possible configuration methods for processing error samples.
enum class DATASET_API ErrorSamplesMode {
  kReturn = 0,   ///< Erroneous sample results in error raised and returned
  kReplace = 1,  ///< Erroneous sample is replaced with an internally determined sample
  kSkip = 2      ///< Erroneous sample is skipped
};

/// \brief Convenience function to check bitmask for a 32bit int
/// \param[in] bits a 32bit int to be tested
/// \param[in] bitMask a 32bit int representing bit mask
@@ -23,14 +23,14 @@ Common imported modules in corresponding API examples are as follows:
    import mindspore.dataset as ds
"""
from __future__ import absolute_import

from enum import IntEnum
import os
import platform
import random
import numpy
import mindspore._c_dataengine as cde
from mindspore import log as logger
from mindspore.dataset.core.validator_helpers import replace_none
from mindspore.dataset.core.validator_helpers import replace_none, type_check

__all__ = ['set_sending_batches', 'load', '_init_device_info',
           'set_seed', 'get_seed',
@@ -46,8 +46,9 @@ __all__ = ['set_sending_batches', 'load', '_init_device_info',
           'set_auto_offload', 'get_auto_offload',
           'set_enable_watchdog', 'get_enable_watchdog',
           'set_fast_recovery', 'get_fast_recovery',
           'set_multiprocessing_timeout_interval', 'get_multiprocessing_timeout_interval',
           'set_debug_mode', 'get_debug_mode']
           'set_debug_mode', 'get_debug_mode',
           'set_error_samples_mode', 'get_error_samples_mode', 'ErrorSamplesMode',
           'set_multiprocessing_timeout_interval', 'get_multiprocessing_timeout_interval']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295
@@ -809,7 +810,7 @@ def set_fast_recovery(fast_recovery):
    (yet with slightly different random augmentations).

    Args:
        fast_recovery (bool): Whether the dataset pipeline recovers in fast mode.
        fast_recovery (bool): Whether the dataset pipeline recovers in fast mode. System default: True.

    Raises:
        TypeError: If `fast_recovery` is not a boolean data type.
@@ -873,3 +874,82 @@ def get_debug_mode():
        >>> debug_mode = ds.config.get_debug_mode()
    """
    return _config.get_debug_mode()


class ErrorSamplesMode(IntEnum):
    """
    An enumeration for `error_samples_mode`.

    Possible enumeration values are: ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP.

    - ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.
    - ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.
    - ErrorSamplesMode.SKIP: means erroneous sample is skipped.
    """

    RETURN = 0
    REPLACE = 1
    SKIP = 2


# Convert ErrorSamplesMode from Python enum format to CDE enum format
_PYTHON_TO_CDE_ERROR_SAMPLES_MODE = {
    ErrorSamplesMode.RETURN: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_RETURN,
    ErrorSamplesMode.REPLACE: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_REPLACE,
    ErrorSamplesMode.SKIP: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_SKIP
}

# Convert ErrorSamplesMode from CDE int format to Python enum format
_CDE_TO_PYTHON_ERROR_SAMPLES_MODE = {
    0: ErrorSamplesMode.RETURN,
    1: ErrorSamplesMode.REPLACE,
    2: ErrorSamplesMode.SKIP
}


def set_error_samples_mode(error_samples_mode):
    """
    Set the method in which erroneous samples should be processed in a dataset pipeline.

    Note:
        1. This error samples feature is only applicable to the Map operation in a dataset pipeline.
        2. For replacement mode, a cache of internally determined samples will be used.
        3. If skip mode is used in a distributed setting, take care to manually ensure that the
           number of valid samples is the same for each shard (otherwise one may encounter hangs).
           One technique is to manually concat a dataset of all valid samples plus a
           take operation for the number of skipped erroneous samples.

    Args:
        error_samples_mode (ErrorSamplesMode): The method in which erroneous samples should be processed in a dataset
            pipeline. It can be any of [ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP].
            System default: ErrorSamplesMode.RETURN.

            - ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.

            - ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.

            - ErrorSamplesMode.SKIP: means erroneous sample is skipped.

    Raises:
        TypeError: If `error_samples_mode` is not of type ErrorSamplesMode.

    Examples:
        >>> ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.SKIP)
    """
    type_check(error_samples_mode, (ErrorSamplesMode,), "error_samples_mode")
    _config.set_error_samples_mode(_PYTHON_TO_CDE_ERROR_SAMPLES_MODE.get(error_samples_mode))


def get_error_samples_mode():
    """
    Get the current configuration for the method of processing erroneous samples in a dataset pipeline.

    Returns:
        ErrorSamplesMode, the method in which erroneous samples are processed in a dataset pipeline.

        - ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.
        - ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.
        - ErrorSamplesMode.SKIP: means erroneous sample is skipped.

    Examples:
        >>> error_samples_mode = ds.config.get_error_samples_mode()
    """
    return _CDE_TO_PYTHON_ERROR_SAMPLES_MODE.get(_config.get_error_samples_mode())
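A usage sketch combining the two functions above; the try/finally restore is our own hygiene, mirroring how the new tests save and restore the mode:

import mindspore.dataset as ds

previous = ds.config.get_error_samples_mode()
try:
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
    assert ds.config.get_error_samples_mode() == ds.config.ErrorSamplesMode.REPLACE
    # ... build and run the pipeline here ...
finally:
    ds.config.set_error_samples_mode(previous)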
@@ -2764,7 +2764,7 @@ def _worker_loop(operations, pipe, seed=get_seed()):
            pipe.worker_send(output_tensors)
        except Exception:
            pipe.worker_send(ExceptionHandler(where="in map(or batch) worker and execute Python function"))
            return
            # Do not return
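            # (Presumably the loop must keep running here so the same worker can still
            # serve later samples when error_samples_mode is REPLACE or SKIP.)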


def worker_target(operations, seed=get_seed()):
@@ -6,5 +6,7 @@
    "opConnectorSize": 16,
    "seed": 5489,
    "monitorSamplingInterval": 15,
    "debug_mode_flag": true
    "fast_recovery": true,
    "debug_mode_flag": true,
    "error_samples_mode": 1
}
@@ -0,0 +1,16 @@
{
  "datasetType": "IMAGENET",
  "numRows": 3,
  "columns": {
    "image": {
      "type": "uint8",
      "rank": 1,
      "t_impl": "cvmat"
    },
    "label" : {
      "type": "uint32",
      "rank": 0,
      "t_impl" : "flex"
    }
  }
}

Binary file not shown.
@@ -0,0 +1 @@
Hello world!
@@ -25,6 +25,7 @@ import mindspore.dataset as ds
import mindspore.dataset.engine.iterators as it
import mindspore.dataset.transforms
import mindspore.dataset.vision as vision
import mindspore.dataset.core.config as config
from mindspore import log as logger
from util import dataset_equal
@@ -42,6 +43,7 @@ def config_error_func(config_interface, input_args, err_type, except_err_msg):
    assert except_err_msg in err_msg


@pytest.mark.forked
def test_basic():
    """
    Feature: Config
@@ -55,6 +57,7 @@ def test_basic():
    monitor_sampling_interval_original = ds.config.get_monitor_sampling_interval()
    fast_recovery_original = ds.config.get_fast_recovery()
    debug_mode = ds.config.get_debug_mode()
    error_samples_mode_original = ds.config.get_error_samples_mode()

    ds.config.load('../data/dataset/declient.cfg')
@@ -65,6 +68,7 @@ def test_basic():
    assert ds.config.get_monitor_sampling_interval() == 15
    assert ds.config.get_fast_recovery()
    assert ds.config.get_debug_mode()
    assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.REPLACE

    ds.config.set_num_parallel_workers(2)
    # ds.config.set_worker_connector_size(3)
@@ -73,6 +77,7 @@ def test_basic():
    ds.config.set_monitor_sampling_interval(45)
    ds.config.set_fast_recovery(False)
    ds.config.set_debug_mode(False)
    ds.config.set_error_samples_mode(config.ErrorSamplesMode.RETURN)

    assert ds.config.get_num_parallel_workers() == 2
    # assert ds.config.get_worker_connector_size() == 3
@@ -81,6 +86,13 @@ def test_basic():
    assert ds.config.get_monitor_sampling_interval() == 45
    assert not ds.config.get_fast_recovery()
    assert not ds.config.get_debug_mode()
    assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.RETURN

    ds.config.set_fast_recovery(True)
    ds.config.set_error_samples_mode(config.ErrorSamplesMode.SKIP)

    assert ds.config.get_fast_recovery()
    assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.SKIP

    # Restore original configuration values
    ds.config.set_num_parallel_workers(num_parallel_workers_original)
@@ -89,6 +101,7 @@ def test_basic():
    ds.config.set_monitor_sampling_interval(monitor_sampling_interval_original)
    ds.config.set_fast_recovery(fast_recovery_original)
    ds.config.set_debug_mode(debug_mode)
    ds.config.set_error_samples_mode(error_samples_mode_original)


def test_get_seed():
|
|||
assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
def test_debug_mode():
|
||||
"""
|
||||
Feature: Test the debug mode setter/getter function
|
||||
|
@ -548,6 +562,40 @@ def test_debug_mode():
|
|||
ds.config.set_debug_mode(origin_debug_mode_flag)
|
||||
|
||||
|
||||
def test_error_samples_mode():
|
||||
"""
|
||||
Feature: Test the get_error_samples_mode function
|
||||
Description: This function only accepts ErrorSamplesMode enum values as input and outputs error otherwise
|
||||
Expectation: For error input, error is raised as expected
|
||||
"""
|
||||
# set_error_samples_mode will raise TypeError if input is boolean
|
||||
config_error_func(config.set_error_samples_mode, False, TypeError,
|
||||
"is not of type [<enum 'ErrorSamplesMode'>]")
|
||||
# set_error_samples_mode will raise TypeError if input is int
|
||||
config_error_func(config.set_error_samples_mode, 1, TypeError,
|
||||
"is not of type [<enum 'ErrorSamplesMode'>]")
|
||||
# set_error_samples_mode will raise TypeError if input is a string
|
||||
config_error_func(config.set_error_samples_mode, "Zero", TypeError,
|
||||
"is not of type [<enum 'ErrorSamplesMode'>]")
|
||||
# set_error_samples_mode will raise TypeError if input is a tuple
|
||||
config_error_func(config.set_error_samples_mode, (1,), TypeError,
|
||||
"is not of type [<enum 'ErrorSamplesMode'>]")
|
||||
# set_error_samples_mode will raise TypeError if input is None
|
||||
config_error_func(config.set_error_samples_mode, None, TypeError,
|
||||
"is not of type [<enum 'ErrorSamplesMode'>]")
|
||||
|
||||
# set_error_samples_mode will raise TypeError if no input is provided
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
config.set_error_samples_mode()
|
||||
assert "set_error_samples_mode() missing 1 required positional argument: 'error_samples_mode'" in \
|
||||
str(error_info.value)
|
||||
|
||||
# set_error_samples_mode will raise TypeError if too many parameters are provided
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
config.set_error_samples_mode(config.ErrorSamplesMode.REPLACE, 10)
|
||||
assert "set_error_samples_mode() takes 1 positional argument but 2 were given" in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_basic()
|
||||
test_get_seed()
|
||||
|
@@ -565,3 +613,4 @@ if __name__ == '__main__':
    test_config_bool_type_error()
    test_fast_recovery()
    test_debug_mode()
    test_error_samples_mode()
@@ -1022,20 +1022,19 @@ def test_imagefolder_error_sample_sourceop():
    Expectation: The dataset is processed as expected.
    """

    def test_config(my_seed, my_error_sample_data_file):
        # Set configuration
        # Note: This test depends on the seed value. An expected exception is not raised with some other seed values.
    def test_config(my_seed, my_error_sample_data_file, my_total_samples):
        # Set configuration since default Random Sampler is used
        original_seed = config_get_set_seed(my_seed)

        # For ImageFolderDataset, use decode=False
        data1 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=3, num_parallel_workers=1, decode=False)
        data1 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, decode=False)
        count = 0
        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
        assert count == 3
        assert count == my_total_samples

        # For ImageFolderDataset, use decode=True
        data2 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=3, num_parallel_workers=1, decode=True)
        data2 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, decode=True)
        with pytest.raises(RuntimeError) as error_info:
            for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
                pass
@@ -1045,11 +1044,11 @@ def test_imagefolder_error_sample_sourceop():
    ds.config.set_seed(original_seed)

    # Test empty sample
    test_config(2, "../data/dataset/testImageNetError/Sample1_empty/train")
    test_config(2, "../data/dataset/testImageNetError/Sample1_empty/train", 3)
    # Test corrupt sample
    test_config(3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train")
    test_config(3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 6)
    # Test text sample, instead of image sample
    test_config(1, "../data/dataset/testImageNetError/Sample3_text/train")
    test_config(1, "../data/dataset/testImageNetError/Sample3_text/train", 3)


def test_imagefolder_error_sample_mapop():
@@ -1061,12 +1060,11 @@ def test_imagefolder_error_sample_mapop():
    """

    def test_config(my_seed, my_error_sample_data_file):
        # Set configuration
        # Note: This test depends on the seed value. An expected exception is not raised with some other seed values.
        # Set configuration since default Random Sampler is used
        original_seed = config_get_set_seed(my_seed)

        # For ImageFolderDataset, use decode default (of False)
        data3 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=3, num_parallel_workers=1)
        data3 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None)
        # Add map op to the pipeline. Use C++ implemented ops
        data3 = data3.map(operations=[vision.Decode(),
                                      vision.HorizontalFlip()],
@@ -1077,7 +1075,7 @@ def test_imagefolder_error_sample_mapop():
    assert "map operation: [Decode] failed" in str(error_info.value)

    # For ImageFolderDataset, use decode default (of False)
    data4 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=3, num_parallel_workers=1)
    data4 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None)
    # Add map op to the pipeline. Use Python implemented ops
    data4 = data4.map(operations=[vision.Decode(to_pil=True),
                                  vision.RandomHorizontalFlip(0.7)],
@@ -0,0 +1,675 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Map operation's handling of rows with errors
"""
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms as data_trans
import mindspore.dataset.vision as vision
from mindspore import log as logger
from mindspore.dataset.core.config import ErrorSamplesMode
from util import config_get_set_seed

# Need to run all these tests in separate processes since we are modifying a config flag
pytestmark = pytest.mark.forked

# Set global variable
TOTAL_SIZE = 100

def my_generator(ds_size):
    def generator_func():
        for i in range(ds_size):
            yield i

    return generator_func


def raise_none(x):
    return x


def raise_all(x):
    raise ZeroDivisionError


def raise_first(x):
    if x == 0:
        raise ZeroDivisionError
    return x


def raise_first_10(x):
    if x < 10:
        raise ZeroDivisionError
    return x


def raise_first_100(x):
    if x < 100:
        raise ZeroDivisionError
    return x


def raise_first_101(x):
    if x < 101:
        raise ZeroDivisionError
    return x


def raise_first_n(x):
    if x < TOTAL_SIZE // 2 - 2:
        raise ZeroDivisionError
    return x


def raise_first_m(x):
    if x < TOTAL_SIZE // 2 + 2:
        raise ZeroDivisionError
    return x


def raise_all_but_last(x):
    if x < TOTAL_SIZE - 1:
        raise ZeroDivisionError
    return x


def raise_all_but_first(x):
    if x > 0:
        raise ZeroDivisionError
    return x


def raise_last_n(x):
    if x > TOTAL_SIZE // 2 + 2:
        raise ZeroDivisionError
    return x


def raise_last_m(x):
    if x > TOTAL_SIZE // 2 - 2:
        raise ZeroDivisionError
    return x


def raise_all_odds(x):
    if x % 2 != 0:
        raise ZeroDivisionError
    return x


def raise_all_3_remainders(x):
    if x % 3 != 0:
        raise ZeroDivisionError
    return x

def run_replace_test(transforms, dataset_size, num_parallel_workers, python_multiprocessing, expected=None, epochs=1):
    """ Function to run a replace error samples mode test based on the input configuration. """
    data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
    data1 = data1.map(operations=transforms,
                      num_parallel_workers=num_parallel_workers,
                      python_multiprocessing=python_multiprocessing)

    global TOTAL_SIZE
    TOTAL_SIZE = dataset_size

    itr = data1.create_dict_iterator(num_epochs=epochs, output_numpy=True)
    count = 0
    result = []
    for _ in range(epochs):
        for _, data in enumerate(itr):
            count += 1
            if expected is not None:
                result.append(data["data"].item(0))
    assert count == dataset_size * epochs
    if expected is not None:
        assert result == expected


def run_skip_test(transforms, dataset_size, num_parallel_workers, python_multiprocessing, expected=None, epochs=1):
    """ Function to run a skip error samples mode test based on the input configuration. """
    data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
    data1 = data1.map(operations=transforms,
                      num_parallel_workers=num_parallel_workers,
                      python_multiprocessing=python_multiprocessing)

    global TOTAL_SIZE
    TOTAL_SIZE = dataset_size

    itr = data1.create_dict_iterator(num_epochs=epochs, output_numpy=True)
    count = 0
    result = []
    for _ in range(epochs):
        for _, data in enumerate(itr):
            count += 1
            if expected is not None:
                result.append(data["data"].item(0))

    if expected is not None:
        assert count == len(expected) * epochs
        assert result == expected

def test_map_replace_errors_failure():
    """
    Feature: Process Error Samples
    Description: Simple replace tests of data pipeline with error rows in Map operation
    Expectation: Exceptions are raised due to numerous error samples
    """
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)

    # failure cases:
    with pytest.raises(RuntimeError) as error_info:
        run_replace_test(raise_all, 1, 1, False)
    assert "All data is garbage" in str(error_info.value)

    with pytest.raises(RuntimeError) as error_info:
        run_replace_test(raise_all, 10, 1, False)
    assert "All data is garbage" in str(error_info.value)

    ds.config.set_error_samples_mode(error_samples_mode_original)

def test_map_replace_errors_success1():
    """
    Feature: Process Error Samples
    Description: Simple replace tests of data pipeline with various number of error rows in different indexes
    Expectation: Data pipeline replaces error rows successfully
    """
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)

    # no error rows
    run_replace_test(raise_none, 10, 1, False, list(range(10)))
    run_replace_test(raise_none, 100, 1, False, list(range(100)))
    run_replace_test(raise_none, 1000, 1, False, list(range(1000)))

    # 1 error row in the beginning of dataset
    run_replace_test(raise_first, 2, 1, False, [1, 1])
    run_replace_test(raise_first, 3, 1, False, [1, 2, 1])
    run_replace_test(raise_first, 10, 1, False, list(range(1, 10)) + [1])
    run_replace_test(raise_first, 16, 1, False, list(range(1, 16)) + [1])
    run_replace_test(raise_first, 17, 1, False, list(range(1, 17)) + [1])
    run_replace_test(raise_first, 20, 1, False, list(range(1, 17)) + [1] + list(range(17, 20)))
    run_replace_test(raise_first, 100, 1, False, list(range(1, 17)) + [1] + list(range(17, 100)))

    # multiple error rows in beginning of dataset
    run_replace_test(raise_first_10, 11, 1, False, [10] * 11)
    run_replace_test(raise_first_10, 12, 1, False, [10, 11] * 6)
    run_replace_test(raise_first_10, 20, 1, False, list(range(10, 20)) * 2)
    run_replace_test(raise_first_10, 30, 1, False,
                     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25] +
                     [10, 26, 27, 11, 28, 29, 12, 13, 14, 15, 16, 17, 18, 19])
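    # Note: The pattern above follows from the 16-row replacement cache (kCachedRowsSize):
    # rows 10..25 fill the cache while 10 replacements are owed; from then on, healthy rows
    # alternate between paying back one owed replacement from the cache (10, 11, ...) and
    # refilling the cache slot just spent, and the epoch end flushes the remaining debt (12..19).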
    run_replace_test(raise_first_100, 1000, 1, False)
    run_replace_test(raise_first_n, 20, 1, False, list(range(8, 20)) + list(range(8, 16)))  # ~first half (n < half)
    run_replace_test(raise_first_n, 40, 1, False, [18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] +
                     [18, 34, 35, 19, 36, 37, 20, 38, 39] +
                     [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35,
                      37])  # ~first half (n < half)
    run_replace_test(raise_first_n, 100, 1, False)  # ~first half (n < half)
    run_replace_test(raise_first_m, 100, 1, False)  # ~first half (m > half)
    run_replace_test(raise_all_but_last, 2, 1, False, [1, 1])
    run_replace_test(raise_all_but_last, 3, 1, False, [2, 2, 2])
    run_replace_test(raise_all_but_last, 4, 1, False, [3] * 4)
    run_replace_test(raise_all_but_last, 16, 1, False, [15] * 16)
    run_replace_test(raise_all_but_last, 100, 1, False, [99] * 100)
    run_replace_test(raise_all_but_first, 10, 1, False, [0] * 10)
    run_replace_test(raise_all_but_first, 100, 1, False, [0] * 100)

    # error rows in the end of dataset
    run_replace_test(raise_last_n, 10, 1, False, list(range(0, 8)) + [0, 1])
    run_replace_test(raise_last_n, 20, 1, False, list(range(0, 13)) + list(range(0, 7)))
    run_replace_test(raise_last_n, 40, 1, False, list(range(0, 23)) + list(range(0, 16)) + [0])
    run_replace_test(raise_last_n, 100, 1, False)
    run_replace_test(raise_last_m, 100, 1, False)

    # error rows in different places
    run_replace_test(raise_all_odds, 10, 1, False, [0, 2, 4, 6, 8] * 2)
    run_replace_test(raise_all_odds, 40, 1, False, [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] +
                     [0, 32, 2, 34, 4, 36, 6, 38, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38])
    run_replace_test(raise_all_odds, 100, 1, False)
    run_replace_test(raise_all_3_remainders, 12, 1, False, [0, 3, 6, 9] * 3)
    run_replace_test(raise_all_3_remainders, 100, 1, False)

    ds.config.set_error_samples_mode(error_samples_mode_original)

@pytest.mark.parametrize("my_mp", (False, True))
def test_map_replace_errors_success2(my_mp):
    """
    Feature: Process Error Samples
    Description: Simple replace tests of data pipeline with error rows in different settings
    Expectation: Data pipeline replaces error rows successfully
    """
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
    # Check if python_multiprocessing is to be enabled
    if my_mp:
        # Reduce memory required by disabling the shared memory optimization
        mem_original = ds.config.get_enable_shared_mem()
        ds.config.set_enable_shared_mem(False)

    # multiple pyfuncs
    run_replace_test([raise_last_n, raise_last_n, raise_first_m, raise_first_n], 100, 1, my_mp)
    run_replace_test([raise_last_n, raise_last_n, raise_first_n], 100, 2, my_mp)  # n<50

    # parallel workers
    run_replace_test(raise_first, 100, 2, my_mp)
    run_replace_test(raise_last_n, 100, 2, my_mp)
    run_replace_test(raise_first_n, 100, 4, my_mp)
    run_replace_test(raise_all_odds, 100, 2, my_mp)
    run_replace_test(raise_all_odds, 100, 3, my_mp)
    run_replace_test(raise_all_3_remainders, 100, 5, my_mp)

    # multiple epochs
    run_replace_test(raise_last_m, 100, 1, my_mp, epochs=3)
    run_replace_test(raise_all_odds, 100, 3, my_mp, epochs=3)

    ds.config.set_error_samples_mode(error_samples_mode_original)
    if my_mp:
        # Restore configuration for shared memory
        ds.config.set_enable_shared_mem(mem_original)

@pytest.mark.parametrize("my_num_workers, my_mp",
                         [(1, False), (4, False), (3, True)])
def test_map_replace_errors_success3(my_num_workers, my_mp):
    """
    Feature: Process Error Samples
    Description: Simple replace tests of data pipeline with error rows in different pipelines
    Expectation: Data pipeline replaces error rows successfully
    """
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
    # Check if python_multiprocessing is to be enabled
    if my_mp:
        # Reduce memory required by disabling the shared memory optimization
        mem_original = ds.config.get_enable_shared_mem()
        ds.config.set_enable_shared_mem(False)

    dataset_size = 100
    global TOTAL_SIZE
    TOTAL_SIZE = dataset_size

    # multiple maps
    transforms = [raise_all_but_last]
    data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
    data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)
    data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)

    count = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == dataset_size

    # repeat op
    transforms = [raise_all_but_first]
    data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
    data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)
    data1 = data1.repeat(3)

    count = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == dataset_size * 3

    ds.config.set_error_samples_mode(error_samples_mode_original)
    if my_mp:
        # Restore configuration for shared memory
        ds.config.set_enable_shared_mem(mem_original)

@pytest.mark.parametrize("my_num_workers, my_mp",
                         [(1, False), (3, False), (2, True)])
def test_map_skip_errors_success1(my_num_workers, my_mp):
    """
    Feature: Process Error Samples
    Description: Simple skip tests of data pipeline with various number of error rows in different indexes
    Expectation: Data pipeline skips error rows successfully
    """
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.SKIP)
    # Check if python_multiprocessing is to be enabled
    if my_mp:
        # Reduce memory required by disabling the shared memory optimization
        mem_original = ds.config.get_enable_shared_mem()
        ds.config.set_enable_shared_mem(False)

    # no error rows
    run_skip_test(raise_none, 10, my_num_workers, my_mp, list(range(10)))
    run_skip_test(raise_none, 100, my_num_workers, my_mp, list(range(100)))
    run_skip_test(raise_none, 1000, my_num_workers, my_mp, list(range(1000)))

    # 1 error row in the beginning of dataset
    run_skip_test(raise_first, 2, my_num_workers, my_mp, [1])
    run_skip_test(raise_first, 3, my_num_workers, my_mp, [1, 2])
    run_skip_test(raise_first, 10, my_num_workers, my_mp, list(range(1, 10)))
    run_skip_test(raise_first, 16, my_num_workers, my_mp, list(range(1, 16)))
    run_skip_test(raise_first, 17, my_num_workers, my_mp, list(range(1, 17)))
    run_skip_test(raise_first, 20, my_num_workers, my_mp, list(range(1, 20)))
    run_skip_test(raise_first, 100, my_num_workers, my_mp, list(range(1, 100)))

    # multiple error rows in beginning of dataset
    run_skip_test(raise_first_10, 11, my_num_workers, my_mp, [10])
    run_skip_test(raise_first_10, 12, my_num_workers, my_mp, [10, 11])
    run_skip_test(raise_first_10, 20, my_num_workers, my_mp, list(range(10, 20)))
    run_skip_test(raise_first_10, 30, my_num_workers, my_mp, list(range(10, 30)))
    run_skip_test(raise_first_100, 250, my_num_workers, my_mp, list(range(100, 250)))
    run_skip_test(raise_first_100, 1000, my_num_workers, my_mp, list(range(100, 1000)))
    run_skip_test(raise_first_n, 20, my_num_workers, my_mp, list(range(8, 20)))  # ~first half (n < half)
    run_skip_test(raise_first_n, 40, my_num_workers, my_mp, list(range(18, 40)))  # ~first half (n < half)
    run_skip_test(raise_first_n, 100, my_num_workers, my_mp, list(range(48, 100)))  # ~first half (n < half)
    run_skip_test(raise_first_m, 100, my_num_workers, my_mp, list(range(52, 100)))  # ~first half (m > half)
    run_skip_test(raise_all_but_last, 2, my_num_workers, my_mp, [1])
    run_skip_test(raise_all_but_last, 3, my_num_workers, my_mp, [2])
    run_skip_test(raise_all_but_last, 4, my_num_workers, my_mp, [3])
    run_skip_test(raise_all_but_last, 16, my_num_workers, my_mp, [15])
    run_skip_test(raise_all_but_last, 100, my_num_workers, my_mp, [99])
    run_skip_test(raise_all_but_first, 10, my_num_workers, my_mp, [0])
    run_skip_test(raise_all_but_first, 100, my_num_workers, my_mp, [0])

    # error rows in the end of dataset
    run_skip_test(raise_last_n, 10, my_num_workers, my_mp, list(range(0, 8)))
    run_skip_test(raise_last_n, 20, my_num_workers, my_mp, list(range(0, 13)))
    run_skip_test(raise_last_n, 40, my_num_workers, my_mp, list(range(0, 23)))
    run_skip_test(raise_last_n, 100, my_num_workers, my_mp, list(range(0, 53)))
    run_skip_test(raise_last_m, 100, my_num_workers, my_mp, list(range(0, 49)))

    # error rows in different places
    run_skip_test(raise_all_odds, 10, my_num_workers, my_mp, [0, 2, 4, 6, 8])
    run_skip_test(raise_all_odds, 40, my_num_workers, my_mp,
                  [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38])
    run_skip_test(raise_all_odds, 100, 1, False)
    run_skip_test(raise_all_3_remainders, 12, my_num_workers, my_mp, [0, 3, 6, 9])
    run_skip_test(raise_all_3_remainders, 100, 1, False)

    # error rows in entire dataset
    run_skip_test(raise_all, 1, my_num_workers, my_mp, [])
    run_skip_test(raise_all, 3, my_num_workers, my_mp, [])
    run_skip_test(raise_all, 10, my_num_workers, my_mp, [])
    run_skip_test(raise_all, 100, my_num_workers, my_mp, [])

    ds.config.set_error_samples_mode(error_samples_mode_original)
    if my_mp:
        # Restore configuration for shared memory
        ds.config.set_enable_shared_mem(mem_original)

@pytest.mark.parametrize("my_error_samples_mode",
                         (ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP))
def test_map_error_samples_imagefolder1_basic(my_error_samples_mode):
    """
    Feature: Process Error Samples
    Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with map op and sample errors.
    Expectation: The dataset is processed as expected.
    """

    def test_config(my_error_samples_mode, my_error_sample_data_file, my_num_classes, my_total_samples,
                    my_unskipped_samples):
        # For ImageFolderDataset:
        # - use num_samples=None to read all samples
        # - use num_parallel_workers=1
        # - use shuffle=False which will result in sequential order of samples
        # - use decode default of False
        data3 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, num_parallel_workers=1,
                                      shuffle=False)
        # Use multiple map ops in pipeline.
        data3 = data3.map(operations=[data_trans.OneHot(my_num_classes)],
                          input_columns=["label"],
                          num_parallel_workers=1)
        # 2nd map op in pipeline is Decode op, which uses C++ implementation
        data3 = data3.map(operations=[vision.Decode()], input_columns=["image"], num_parallel_workers=1)
        data3 = data3.map(operations=[vision.Crop((0, 0), 32)], input_columns=["image"], num_parallel_workers=1)

        if my_error_samples_mode == ErrorSamplesMode.REPLACE:
            # Error samples are to be replaced
            count = 0
            for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_total_samples
        elif my_error_samples_mode == ErrorSamplesMode.SKIP:
            # Error samples are to be skipped
            count = 0
            for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_unskipped_samples
        else:
            with pytest.raises(RuntimeError) as error_info:
                for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
                    pass
            assert "map operation: [Decode] failed" in str(error_info.value)

        data4 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, num_parallel_workers=1,
                                      shuffle=False)
        # Use multiple map ops in pipeline.
        data4 = data4.map(operations=[data_trans.OneHot(my_num_classes)],
                          input_columns=["label"],
                          num_parallel_workers=1)
        # 2nd map op in pipeline is Decode op, which uses Python implementation
        data4 = data4.map(operations=[vision.Decode(to_pil=True),
                                      vision.RandomHorizontalFlip(0.7)],
                          input_columns=["image"], num_parallel_workers=1)
        # Note: ToPIL op added so that Python implementation of RandomVerticalFlip is selected
        data4 = data4.map(operations=[vision.ToPIL(),
                                      vision.RandomVerticalFlip(0.6)],
                          input_columns=["image"], num_parallel_workers=1)

        if my_error_samples_mode == ErrorSamplesMode.REPLACE:
            # Error samples are to be replaced
            count = 0
            for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_total_samples
        elif my_error_samples_mode == ErrorSamplesMode.SKIP:
            # Error samples are to be skipped
            count = 0
            for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_unskipped_samples
        else:
            with pytest.raises(RuntimeError) as error_info:
                for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
                    pass
            assert "map operation: [PyFunc] failed" in str(error_info.value)

    # Set configuration for error_samples_mode
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(my_error_samples_mode)

    # Test empty sample (which is first error sample when samples are read sequentially)
    test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample1_empty/train", 1, 3, 2)
    # Test corrupt sample (which is a middle error sample when samples are read sequentially)
    test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 3, 6, 5)
    # Test text sample, instead of image sample (which is a final error sample when samples are read sequentially)
    test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample3_text/train", 1, 3, 2)

    # Restore configuration for error_samples_mode
    ds.config.set_error_samples_mode(error_samples_mode_original)

@pytest.mark.parametrize("my_error_samples_mode, my_num_workers, my_mp",
                         [(ErrorSamplesMode.RETURN, 3, False), (ErrorSamplesMode.RETURN, 4, True),
                          (ErrorSamplesMode.REPLACE, 2, False), (ErrorSamplesMode.REPLACE, 3, True),
                          (ErrorSamplesMode.SKIP, 3, False), (ErrorSamplesMode.SKIP, 2, True)])
def test_map_error_samples_imagefolder2_parallel(my_error_samples_mode, my_num_workers, my_mp):
    """
    Feature: Process Error Samples
    Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with a map op
        with num_parallel_workers and python_multiprocessing set, plus sample errors.
    Expectation: The dataset is processed as expected.
    """

    def test_config(my_error_samples_mode, my_seed, my_error_sample_data_file, my_total_samples, my_unskipped_samples):
        # Set the seed configuration since the default RandomSampler is used
        original_seed = config_get_set_seed(my_seed)

        # Create a dataset pipeline which includes at least one error sample
        my_sampler = ds.RandomSampler(replacement=False, num_samples=None)
        data1 = ds.ImageFolderDataset(my_error_sample_data_file, sampler=my_sampler,
                                      num_parallel_workers=my_num_workers)
        # Add a map op to the pipeline which will encounter error samples. Use Python implemented ops.
        # Note: Decode is not the first op in the list of operations
        data1 = data1.map(operations=[(lambda x: x),
                                      vision.Decode(to_pil=True),
                                      (lambda y: y),
                                      vision.RandomVerticalFlip(0.8)],
                          input_columns=["image"],
                          num_parallel_workers=my_num_workers,
                          python_multiprocessing=my_mp)

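        # Note (an assumption, consistent with the "[PyFunc]" assert below): because
        # the op list mixes lambdas with Python vision ops, the ops execute inside a
        # Python function wrapper in each worker, so a decode failure on an error
        # sample is surfaced by PyFunc rather than by Decode itself.
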
        if my_error_samples_mode == ErrorSamplesMode.REPLACE:
            # Error samples are to be replaced
            count = 0
            for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_total_samples
        elif my_error_samples_mode == ErrorSamplesMode.SKIP:
            # Error samples are to be skipped
            count = 0
            for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == my_unskipped_samples
        else:
            with pytest.raises(RuntimeError) as error_info:
                for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                    pass
            assert "map operation: [PyFunc] failed" in str(error_info.value)

        # Restore configuration
        ds.config.set_seed(original_seed)

    # Set configuration for error_samples_mode
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(my_error_samples_mode)

    # Check if python_multiprocessing is to be enabled
    if my_mp:
        # Reduce the memory required by disabling the shared memory optimization
        mem_original = ds.config.get_enable_shared_mem()
        ds.config.set_enable_shared_mem(False)

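    # Note (an assumption about the fallback behavior): with shared memory disabled,
    # worker subprocesses return rows through regular multiprocessing queues instead
    # of shared-memory blocks, lowering the test's memory footprint at some cost in
    # throughput.
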
    # Test empty sample
    test_config(my_error_samples_mode, 2, "../data/dataset/testImageNetError/Sample1_empty/train", 3, 2)
    # Test corrupt sample
    test_config(my_error_samples_mode, 3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 6, 5)
    # Test text sample, instead of an image sample
    test_config(my_error_samples_mode, 1, "../data/dataset/testImageNetError/Sample3_text/train", 3, 2)

    # Restore configuration for error_samples_mode
    ds.config.set_error_samples_mode(error_samples_mode_original)

    if my_mp:
        # Restore configuration for shared memory
        ds.config.set_enable_shared_mem(mem_original)


@pytest.mark.parametrize("my_error_samples_mode, my_num_workers, my_mp",
                         [(ErrorSamplesMode.RETURN, 4, False), (ErrorSamplesMode.RETURN, 3, True),
                          (ErrorSamplesMode.REPLACE, 3, False), (ErrorSamplesMode.REPLACE, 4, True),
                          (ErrorSamplesMode.SKIP, 5, False), (ErrorSamplesMode.SKIP, 3, True)])
def test_map_error_samples_imagefolder3(my_error_samples_mode, my_num_workers, my_mp):
    """
    Feature: Process Error Samples
    Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with a map op
        with num_parallel_workers and python_multiprocessing set, plus multiple sample errors.
    Expectation: The dataset is processed as expected.
    """

    def test_config(my_error_samples_mode, my_seed, my_take, my_repeat, my_total_unskipped_rows):
        # Set the seed configuration since the default RandomSampler is used
        original_seed = config_get_set_seed(my_seed)

        # Create dataset pipelines with multiple error samples
        data1 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample1_empty/train",
                                      num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
        data2 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample2_corrupt1mid/train",
                                      num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
        data3 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample3_text/train",
                                      num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
        data4 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample4_corruptall/train",
                                      num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
        # Concat the multiple datasets together
        datafinal = data4 + data1 + data2 + data3
        datafinal = datafinal.take(my_take).repeat(my_repeat)
        total_rows = my_take * my_repeat
        # Add a map op to the pipeline. Use Python implemented ops.
        datafinal = datafinal.map(operations=[vision.Decode(to_pil=True),
                                              vision.RandomHorizontalFlip(0.9)],
                                  input_columns=["image"],
                                  num_parallel_workers=my_num_workers,
                                  python_multiprocessing=my_mp)
        # Apply dataset ops
        datafinal = datafinal.shuffle(buffer_size=100)

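        # Worked example of the row accounting (values from the test_config calls
        # below): with my_take=15 and my_repeat=2, total_rows = 15 * 2 = 30, which is
        # what REPLACE must produce; under SKIP only my_total_unskipped_rows (18 in
        # that scenario) survive, i.e. the error rows in the taken window are dropped
        # on every repeat.
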
        if my_error_samples_mode == ErrorSamplesMode.REPLACE:
            # Error samples are to be replaced
            count = 0
            for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            assert count == total_rows
        elif my_error_samples_mode == ErrorSamplesMode.SKIP:
            # Error samples are to be skipped
            count = 0
            for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
                count += 1
            logger.info("Number of data in datafinal: {}".format(count))
            assert count == my_total_unskipped_rows
        else:
            with pytest.raises(RuntimeError) as error_info:
                for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
                    pass
            assert "map operation: [PyFunc] failed" in str(error_info.value)

        # Restore configuration
        ds.config.set_seed(original_seed)

    # Set configuration for error_samples_mode
    error_samples_mode_original = ds.config.get_error_samples_mode()
    ds.config.set_error_samples_mode(my_error_samples_mode)

    # Check if python_multiprocessing is to be enabled
    if my_mp:
        # Reduce the memory required by disabling the shared memory optimization
        mem_original = ds.config.get_enable_shared_mem()
        ds.config.set_enable_shared_mem(False)

    # Test different scenarios
    test_config(my_error_samples_mode, 5001, 15, 1, 9)
    test_config(my_error_samples_mode, 5002, 15, 2, 18)
    test_config(my_error_samples_mode, 5003, 10, 5, 27)
    test_config(my_error_samples_mode, 5004, 12, 4, 28)

    # Restore configuration for error_samples_mode
    ds.config.set_error_samples_mode(error_samples_mode_original)

    if my_mp:
        # Restore configuration for shared memory
        ds.config.set_enable_shared_mem(mem_original)


if __name__ == '__main__':
    test_map_replace_errors_failure()
    test_map_replace_errors_success1()
    test_map_replace_errors_success2(True)
    test_map_replace_errors_success3(3, False)
    test_map_skip_errors_success1(3, True)
    test_map_error_samples_imagefolder1_basic(ErrorSamplesMode.REPLACE)
    test_map_error_samples_imagefolder2_parallel(ErrorSamplesMode.REPLACE, 4, True)
    test_map_error_samples_imagefolder3(ErrorSamplesMode.SKIP, 3, True)