!46291 [MD] Process Error Samples Replace and Skip Support for Map op

Merge pull request !46291 from cathwong/ckw_error_samples
i-robot 2022-12-01 23:47:15 +00:00 committed by Gitee
commit 80bd3c0f86
17 changed files with 1073 additions and 39 deletions

View File

@ -79,6 +79,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
.def("get_fast_recovery", &ConfigManager::fast_recovery)
.def("set_debug_mode", &ConfigManager::set_debug_mode)
.def("get_debug_mode", &ConfigManager::get_debug_mode)
.def("set_error_samples_mode", &ConfigManager::set_error_samples_mode)
.def("get_error_samples_mode", &ConfigManager::get_error_samples_mode)
.def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));
@ -199,5 +201,13 @@ PYBIND_REGISTER(ImageReadMode, 0, ([](const py::module *m) {
.value("DE_IMAGE_READ_MODE_COLOR", ImageReadMode::kCOLOR)
.export_values();
}));
PYBIND_REGISTER(ErrorSamplesMode, 0, ([](const py::module *m) {
(void)py::enum_<ErrorSamplesMode>(*m, "ErrorSamplesMode", py::arithmetic())
.value("DE_ERROR_SAMPLES_MODE_RETURN", ErrorSamplesMode::kReturn)
.value("DE_ERROR_SAMPLES_MODE_REPLACE", ErrorSamplesMode::kReplace)
.value("DE_ERROR_SAMPLES_MODE_SKIP", ErrorSamplesMode::kSkip)
.export_values();
}));
} // namespace dataset
} // namespace mindspore
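On the Python side, these registrations are exposed through the mindspore._c_dataengine extension module. As a rough smoke check (assuming a locally built package; this snippet is ours, not part of the commit), the bound values line up with the C++ enum in constants.h:

import mindspore._c_dataengine as cde

# py::arithmetic() above makes the bound enum values int-convertible.
assert int(cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_RETURN) == 0
assert int(cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_REPLACE) == 1
assert int(cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_SKIP) == 2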

View File

@ -89,6 +89,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
set_seed(j.value("seed", seed_));
set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
set_fast_recovery(j.value("fast_recovery", fast_recovery_));
set_error_samples_mode(j.value("error_samples_mode", error_samples_mode_));
set_cache_host(j.value("cacheHost", cache_host_));
set_cache_port(j.value("cachePort", cache_port_));
set_num_connections(j.value("numConnections", num_connections_));

View File

@ -38,6 +38,7 @@ namespace mindspore {
namespace dataset {
const char kEmptyString[] = "";
const char kJsonExtension[] = ".json";
// The ConfigManager is a class for managing default values. When a user is constructing any objects
// in the framework, often they may choose to omit some settings instead of overriding them.
// This class manages some of the default values, for cases when the user does not manually specify
@ -285,7 +286,8 @@ class ConfigManager {
// setter function
// @notes User must also set the seed to be able to get same augmentations
// @notes Fast recovery can cause slightly different random augmentations than original run
// (System default = true)
// @param fast_recovery - Set whether MD pipeline recovers fast in failover reset
void set_fast_recovery(const bool fast_recovery) { fast_recovery_ = fast_recovery; }
@ -301,6 +303,22 @@ class ConfigManager {
// @return - Flag to indicate whether the debug mode is on
bool get_debug_mode() const { return debug_mode_flag_; }
// setter function
// @param error_samples_mode - Set the method in which erroneous samples should be processed
// (System default = ErrorSamplesMode::kReturn)
// @notes For replacement of erroneous samples, MD will select a deterministic but "random" sample.
void set_error_samples_mode(const ErrorSamplesMode error_samples_mode) { error_samples_mode_ = error_samples_mode; }
// getter function
// @return - The method in which erroneous samples should be processed in a dataset pipeline
// @notes This method is used by the external configuration API, which returns an integer type
int32_t get_error_samples_mode() const { return static_cast<int>(error_samples_mode_); }
// getter function
// @return - The method in which erroneous samples should be processed in a dataset pipeline
// @notes This method is used for internal processing, using enum type
ErrorSamplesMode error_samples_mode() const { return error_samples_mode_; }
private:
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
@ -337,6 +355,7 @@ class ConfigManager {
bool dynamic_shape_{false};
bool fast_recovery_{true}; // Used for failover scenario to recover quickly or produce same augmentations
bool debug_mode_flag_{false}; // Indicator for debug mode
ErrorSamplesMode error_samples_mode_{ErrorSamplesMode::kReturn}; // The method to process erroneous samples
};
} // namespace dataset
} // namespace mindspore

View File

@ -41,7 +41,9 @@ class TensorRow {
kFlagEOE = 1u << 1, // The row is an eoe end-of-epoch msg
kFlagWait = 1u << 2, // The row is an control signal for workers to suspend operations
kFlagQuit = 1u << 3, // The row is a control signal for workers to quit
kFlagSkip = 1u << 4, // The row is a control signal for workers to skip this row
kFlagError = 1u << 5 // The row is an error row (needs to be replaced with another row or skipped, as per
// ErrorSamplesMode config)
};
// Type definitions
@ -235,6 +237,10 @@ class TensorRow {
return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagSkip));
}
bool error() const {
return static_cast<bool>(static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagError));
}
const TensorRowFlags Flags() { return tensor_row_flag_; }
explicit TensorRow(TensorRowFlags flag);

View File

@ -24,6 +24,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
@ -256,7 +257,23 @@ Status MapOp::WorkerCompute(const TensorRow &in_row, TensorRow *out_row,
for (size_t i = 0; i < job_list.size(); i++) {
RETURN_IF_INTERRUPTED();
// Execute MapWorkerJob.
Status rc = job_list[i]->Run(job_input_table, &result_table);
if (rc.IsError()) {
if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
MS_LOG(WARNING)
<< "Detected an erroneous sample in MindData Map operation, and will replace with a healthy sample: " +
rc.GetErrDescription();
*out_row = TensorRow(TensorRow::kFlagError);
return Status::OK();
} else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
MS_LOG(WARNING) << "Detected an erroneous sample in MindData Map operation, and will skip this sample: " +
rc.GetErrDescription();
*out_row = TensorRow(TensorRow::kFlagError);
return Status::OK();
} else {
return rc;
}
}
// Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
if (i + 1 < job_list.size()) {
job_input_table = std::move(result_table);
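From Python, the user-visible effect of this change can be sketched as follows (a minimal illustration; the flaky PyFunc and the four-row generator are our own, not part of this commit):

import mindspore.dataset as ds

def gen():
    for i in range(4):
        yield i

def flaky(x):
    # Simulate one erroneous sample out of four.
    if x == 2:
        raise ZeroDivisionError
    return x

ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.SKIP)
data = ds.GeneratorDataset(gen, ["data"])
data = data.map(operations=[flaky], input_columns=["data"])

# With SKIP the failing row is dropped (3 rows out); with REPLACE a cached
# healthy row is substituted (4 rows out); with RETURN iteration raises.
rows = [d["data"].item(0) for d in data.create_dict_iterator(num_epochs=1, output_numpy=True)]
assert rows == [0, 1, 3]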

View File

@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <string>
@ -30,6 +31,7 @@
namespace mindspore {
namespace dataset {
constexpr int64_t kCachedRowsSize = 16;
class ExecutionTree;
@ -162,16 +164,155 @@ class ParallelOp : public DatasetOp {
return Status::OK();
}
class RowHandlingStrategy {
public:
explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}
virtual Status HandleHealthyRow(const TensorRow &row) {
++this->op_->ep_step_;
++this->op_->total_step_;
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
return this->op_->out_connector_->Add(std::move(row));
}
virtual Status HandleErrorRow(const TensorRow &row) = 0;
virtual Status HandleEOE(const TensorRow &row) {
this->op_->current_repeats_++;
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
this->op_->current_epochs_++;
RETURN_IF_NOT_OK(this->op_->callback_manager_.EpochEnd(
CallbackParam(this->op_->current_epochs_, this->op_->ep_step_, this->op_->total_step_)));
this->op_->ep_step_ = 0;
}
return op_->out_connector_->Add(std::move(row));
}
virtual Status HandleEOF(const TensorRow &row) {
RETURN_IF_NOT_OK(this->op_->callback_manager_.End(
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
return op_->out_connector_->Add(std::move(row));
}
protected:
ParallelOp *op_;
};
class ErrorStrategy : public RowHandlingStrategy {
public:
using RowHandlingStrategy::RowHandlingStrategy;
Status HandleErrorRow(const TensorRow &row) override {
return Status(StatusCode::kMDUnexpectedError,
"[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
}
};
class SkipStrategy : public RowHandlingStrategy {
public:
using RowHandlingStrategy::RowHandlingStrategy;
Status HandleErrorRow(const TensorRow &row) override { return Status::OK(); }
};
class ReplaceStrategy : public RowHandlingStrategy {
public:
using RowHandlingStrategy::RowHandlingStrategy;
Status HandleHealthyRow(const TensorRow &row) override {
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
"[Internal Error] Number of cached rows is beyond the number set.");
if (backup_index_ < kCachedRowsSize - 1) { // cache has used row(s)
if (IsCacheFull()) {
// remove the last element from cache (a used row)
PopFromCache();
}
AddToCache(row);
} else { // cache is full of unused rows
if (missing_errors_ > 0) {
// send a cached row to next op and cache the current row
RETURN_IF_NOT_OK(AddFromCache());
missing_errors_--;
AddToCache(row);
}
}
// send the healthy row to next op
++this->op_->ep_step_;
++this->op_->total_step_;
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
return this->op_->out_connector_->Add(std::move(row));
}
Status HandleErrorRow(const TensorRow &row) override {
CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
"[Internal Error] Number of cached rows is beyond the number set.");
// cache is not full of unused rows
if (backup_index_ != kCachedRowsSize - 1) {
missing_errors_++;
return Status::OK();
}
// cache is full of unused rows and we have an error row
return AddFromCache();
}
Status HandleEOE(const TensorRow &row) override {
CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
"All data is garbage and cannot be replaced.");
// send outstanding rows first and then send eoe
while (missing_errors_ > 0) {
RETURN_IF_NOT_OK(AddFromCache());
missing_errors_--;
}
return RowHandlingStrategy::HandleEOE(row);
}
private:
Status AddFromCache() {
CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
// Note: If backup_index_ is negative (error samples at the end of the data),
// the modulo division with size_t will result in 0, and backup_rows[0] is accessed.
auto cached_row = backup_rows[backup_index_ % backup_rows.size()];
backup_index_--;
++this->op_->ep_step_;
++this->op_->total_step_;
RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(
CallbackParam(this->op_->current_epochs_ + 1, this->op_->ep_step_, this->op_->total_step_)));
return this->op_->out_connector_->Add(std::move(cached_row));
}
Status AddToCache(TensorRow row) {
CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() < kCachedRowsSize,
"[Internal Error] Inserting another row to cache while cache is already full.");
CHECK_FAIL_RETURN_UNEXPECTED(
backup_index_ < kCachedRowsSize - 1,
"[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
backup_rows.push_front(row);
backup_index_++;
return Status::OK();
}
void PopFromCache() { backup_rows.pop_back(); }
bool IsCacheFull() const { return backup_rows.size() == kCachedRowsSize; }
bool IsCacheEmpty() const { return backup_rows.size() == 0; }
std::deque<TensorRow> backup_rows{}; // will hold a copy of some healthy rows collected (NOT error, skip, eoe, eof)
int32_t backup_index_{-1}; // index of the backup we should pick next time (can be negative if we run out of
// unused cached rows)
int32_t missing_errors_{0}; // the number of unaddressed error rows (that we need to send a replacement to output)
};
virtual Status Collector() {
TaskManager::FindMe()->Post();
// num_rows received, including eoe
int64_t num_rows = 0;
current_repeats_ = 0;
current_epochs_ = 0;
SetStrategy();
// num_step of current epoch and the total
ep_step_ = 0, total_step_ = 0;
TensorRow row;
do {
RETURN_IF_NOT_OK(worker_out_queues_[static_cast<const int>(num_rows++ % num_workers_)]->PopFront(&row));
if (row.wait()) {
// When collector receives the signal from worker thread, it increments an atomic int
// If num_worker signals are received, wakes up the main thread
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
@ -179,23 +320,16 @@ class ParallelOp : public DatasetOp {
}
continue;
} else if (row.eoe()) {
RETURN_IF_NOT_OK(strategy_->HandleEOE(row));
} else if (row.eof()) {
RETURN_IF_NOT_OK(strategy_->HandleEOF(row));
} else if (row.skip()) {
continue;
} else if (row.error()) {
RETURN_IF_NOT_OK(strategy_->HandleErrorRow(row));
} else if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
RETURN_IF_NOT_OK(strategy_->HandleHealthyRow(row));
}
} while (!row.eof());
return Status::OK();
}
@ -223,6 +357,17 @@ class ParallelOp : public DatasetOp {
public:
int32_t NumWorkers() override { return num_workers_; }
private:
void SetStrategy() {
if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
strategy_ = std::make_shared<SkipStrategy>(this);
} else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
strategy_ = std::make_shared<ReplaceStrategy>(this);
} else {
strategy_ = std::make_shared<ErrorStrategy>(this);
}
}
protected:
std::atomic_int next_worker_id_;
@ -234,6 +379,13 @@ class ParallelOp : public DatasetOp {
QueueList<T> worker_in_queues_;
/// queues to hold the output from workers
QueueList<S> worker_out_queues_;
private:
std::shared_ptr<RowHandlingStrategy> strategy_;
int32_t ep_step_{0};
int32_t total_step_{0};
int32_t current_epochs_{0};
int32_t current_repeats_{0};
};
} // namespace dataset
} // namespace mindspore
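To make ReplaceStrategy's cache bookkeeping above easier to follow, here is a loose Python model of its accounting (names and structure are ours; callback and connector plumbing are omitted, and eviction on the full-cache path is simplified with deque(maxlen=...)):

from collections import deque

CACHE_SIZE = 16  # mirrors kCachedRowsSize

class ReplaceModel:
    """Illustrative model of ReplaceStrategy's row accounting, not the real implementation."""

    def __init__(self):
        self.cache = deque(maxlen=CACHE_SIZE)  # recent healthy rows, newest at index 0
        self.backup_index = -1  # next cached row to hand out; goes negative when exhausted
        self.missing = 0        # error rows still waiting for a replacement
        self.sent = []          # rows pushed downstream

    def _send_replacement(self):
        assert self.cache, "all data is garbage and cannot be replaced"
        # Modulo keeps a negative index valid by reusing an already-sent row.
        self.sent.append(self.cache[self.backup_index % len(self.cache)])
        self.backup_index -= 1

    def healthy(self, row):
        if self.backup_index < CACHE_SIZE - 1:  # cache still holds used slot(s)
            self.cache.appendleft(row)          # maxlen evicts the oldest used row
            self.backup_index += 1
        elif self.missing > 0:                  # cache is full of unused rows
            self._send_replacement()            # pay down one outstanding error row
            self.missing -= 1
            self.cache.appendleft(row)
            self.backup_index += 1
        self.sent.append(row)                   # the healthy row itself always goes out

    def error(self, row):
        if self.backup_index != CACHE_SIZE - 1:
            self.missing += 1                   # replace later, once fresh rows are cached
        else:
            self._send_replacement()            # cache full of unused rows: replace now

    def eoe(self):
        while self.missing > 0:                 # flush outstanding replacements before eoe
            self._send_replacement()
            self.missing -= 1

This reproduces, for example, run_replace_test(raise_first, 10, ...) in the tests below: the first row errors out, rows 1-9 pass through, and at eoe the outstanding error is replaced with a reused copy of row 1, giving list(range(1, 10)) + [1].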

View File

@ -291,6 +291,13 @@ enum class DATASET_API ResampleMethod {
kKaiserWindow = 1, ///< Resample audio by Kaiser window
};
/// \brief Possible configuration methods for processing error samples.
enum class DATASET_API ErrorSamplesMode {
kReturn = 0, ///< Erroneous sample results in error raised and returned
kReplace = 1, ///< Erroneous sample is replaced with an internally determined sample
kSkip = 2 ///< Erroneous sample is skipped
};
/// \brief Convenience function to check bitmask for a 32bit int
/// \param[in] bits a 32bit int to be tested
/// \param[in] bitMask a 32bit int representing bit mask

View File

@ -23,14 +23,14 @@ Common imported modules in corresponding API examples are as follows:
import mindspore.dataset as ds
"""
from __future__ import absolute_import
from enum import IntEnum
import os
import platform
import random
import numpy
import mindspore._c_dataengine as cde
from mindspore import log as logger
from mindspore.dataset.core.validator_helpers import replace_none
from mindspore.dataset.core.validator_helpers import replace_none, type_check
__all__ = ['set_sending_batches', 'load', '_init_device_info',
'set_seed', 'get_seed',
@ -46,8 +46,9 @@ __all__ = ['set_sending_batches', 'load', '_init_device_info',
'set_auto_offload', 'get_auto_offload',
'set_enable_watchdog', 'get_enable_watchdog',
'set_fast_recovery', 'get_fast_recovery',
'set_debug_mode', 'get_debug_mode',
'set_error_samples_mode', 'get_error_samples_mode', 'ErrorSamplesMode',
'set_multiprocessing_timeout_interval', 'get_multiprocessing_timeout_interval']
INT32_MAX = 2147483647
UINT32_MAX = 4294967295
@ -809,7 +810,7 @@ def set_fast_recovery(fast_recovery):
(yet with slightly different random augmentations).
Args:
fast_recovery (bool): Whether the dataset pipeline recovers in fast mode. System default: True.
Raises:
TypeError: If `fast_recovery` is not a boolean data type.
@ -873,3 +874,82 @@ def get_debug_mode():
>>> debug_mode = ds.config.get_debug_mode()
"""
return _config.get_debug_mode()
class ErrorSamplesMode(IntEnum):
"""
An enumeration for `error_samples_mode`.
Possible enumeration values are: ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP.
- ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.
- ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.
- ErrorSamplesMode.SKIP: means erroneous sample is skipped.
"""
RETURN = 0
REPLACE = 1
SKIP = 2
# Convert ErrorSamplesMode from Python enum format to CDE enum format
_PYTHON_TO_CDE_ERROR_SAMPLES_MODE = {
ErrorSamplesMode.RETURN: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_RETURN,
ErrorSamplesMode.REPLACE: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_REPLACE,
ErrorSamplesMode.SKIP: cde.ErrorSamplesMode.DE_ERROR_SAMPLES_MODE_SKIP
}
# Convert ErrorSamplesMode from CDE int format to Python enum format
_CDE_TO_PYTHON_ERROR_SAMPLES_MODE = {
0: ErrorSamplesMode.RETURN,
1: ErrorSamplesMode.REPLACE,
2: ErrorSamplesMode.SKIP
}
def set_error_samples_mode(error_samples_mode):
"""
Set the method in which erroneous samples should be processed in a dataset pipeline.
Note:
1. This error samples feature is only applicable to the Map operation in a dataset pipeline.
2. For replacement mode, a cache of internally determined samples will be used.
3. If skip mode is used in a distributed setting, take care to manually ensure that the
number of valid samples is the same for each shard (otherwise one may encounter hangs).
One technique is to manually concat a dataset of all valid samples plus a
take operation for the number of skipped erroneous samples.
Args:
error_samples_mode (ErrorSamplesMode): The method in which erroneous samples should be processed in a dataset
pipeline. It can be any of [ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP].
System default: ErrorSamplesMode.RETURN.
- ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.
- ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.
- ErrorSamplesMode.SKIP: means erroneous sample is skipped.
Raises:
TypeError: If `error_samples_mode` is not of type ErrorSamplesMode.
Examples:
>>> ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.SKIP)
"""
type_check(error_samples_mode, (ErrorSamplesMode,), "error_samples_mode")
_config.set_error_samples_mode(_PYTHON_TO_CDE_ERROR_SAMPLES_MODE.get(error_samples_mode))
def get_error_samples_mode():
"""
Get the current configuration for the method of processing erroneous samples in a dataset pipeline.
Returns:
ErrorSamplesMode, the method in which erroneous samples should be processed in a dataset pipeline.
- ErrorSamplesMode.RETURN: means erroneous sample results in error raised and returned.
- ErrorSamplesMode.REPLACE: means erroneous sample is replaced with an internally determined sample.
- ErrorSamplesMode.SKIP: means erroneous sample is skipped.
Examples:
>>> error_samples_mode = ds.config.get_error_samples_mode()
"""
return _CDE_TO_PYTHON_ERROR_SAMPLES_MODE.get(_config.get_error_samples_mode())

View File

@ -2764,7 +2764,7 @@ def _worker_loop(operations, pipe, seed=get_seed()):
pipe.worker_send(output_tensors)
except Exception:
pipe.worker_send(ExceptionHandler(where="in map(or batch) worker and execute Python function"))
# Do not return
def worker_target(operations, seed=get_seed()):

View File

@ -6,5 +6,7 @@
"opConnectorSize": 16,
"seed": 5489,
"monitorSamplingInterval": 15,
"debug_mode_flag": true
"fast_recovery": true,
"debug_mode_flag": true,
"error_samples_mode": 1
}
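Loading this fixture through the config API (as test_basic below does) should pick up the new keys; roughly:

import mindspore.dataset as ds

ds.config.load("../data/dataset/declient.cfg")  # the file shown above
assert ds.config.get_fast_recovery()
assert ds.config.get_error_samples_mode() == ds.config.ErrorSamplesMode.REPLACE  # json value 1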

View File

@ -0,0 +1,16 @@
{
"datasetType": "IMAGENET",
"numRows": 3,
"columns": {
"image": {
"type": "uint8",
"rank": 1,
"t_impl": "cvmat"
},
"label" : {
"type": "uint32",
"rank": 0,
"t_impl" : "flex"
}
}
}

View File

@ -25,6 +25,7 @@ import mindspore.dataset as ds
import mindspore.dataset.engine.iterators as it
import mindspore.dataset.transforms
import mindspore.dataset.vision as vision
import mindspore.dataset.core.config as config
from mindspore import log as logger
from util import dataset_equal
@ -42,6 +43,7 @@ def config_error_func(config_interface, input_args, err_type, except_err_msg):
assert except_err_msg in err_msg
@pytest.mark.forked
def test_basic():
"""
Feature: Config
@ -55,6 +57,7 @@ def test_basic():
monitor_sampling_interval_original = ds.config.get_monitor_sampling_interval()
fast_recovery_original = ds.config.get_fast_recovery()
debug_mode = ds.config.get_debug_mode()
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.load('../data/dataset/declient.cfg')
@ -65,6 +68,7 @@ def test_basic():
assert ds.config.get_monitor_sampling_interval() == 15
assert ds.config.get_fast_recovery()
assert ds.config.get_debug_mode()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.REPLACE
ds.config.set_num_parallel_workers(2)
# ds.config.set_worker_connector_size(3)
@ -73,6 +77,7 @@ def test_basic():
ds.config.set_monitor_sampling_interval(45)
ds.config.set_fast_recovery(False)
ds.config.set_debug_mode(False)
ds.config.set_error_samples_mode(config.ErrorSamplesMode.RETURN)
assert ds.config.get_num_parallel_workers() == 2
# assert ds.config.get_worker_connector_size() == 3
@ -81,6 +86,13 @@ def test_basic():
assert ds.config.get_monitor_sampling_interval() == 45
assert not ds.config.get_fast_recovery()
assert not ds.config.get_debug_mode()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.RETURN
ds.config.set_fast_recovery(True)
ds.config.set_error_samples_mode(config.ErrorSamplesMode.SKIP)
assert ds.config.get_fast_recovery()
assert ds.config.get_error_samples_mode() == config.ErrorSamplesMode.SKIP
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
@ -89,6 +101,7 @@ def test_basic():
ds.config.set_monitor_sampling_interval(monitor_sampling_interval_original)
ds.config.set_fast_recovery(fast_recovery_original)
ds.config.set_debug_mode(debug_mode)
ds.config.set_error_samples_mode(error_samples_mode_original)
def test_get_seed():
@ -521,6 +534,7 @@ def test_fast_recovery():
assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value)
@pytest.mark.forked
def test_debug_mode():
"""
Feature: Test the debug mode setter/getter function
@ -548,6 +562,40 @@ def test_debug_mode():
ds.config.set_debug_mode(origin_debug_mode_flag)
def test_error_samples_mode():
"""
Feature: Test the set_error_samples_mode function
Description: The function only accepts ErrorSamplesMode enum values as input and raises an error otherwise
Expectation: For error input, error is raised as expected
"""
# set_error_samples_mode will raise TypeError if input is boolean
config_error_func(config.set_error_samples_mode, False, TypeError,
"is not of type [<enum 'ErrorSamplesMode'>]")
# set_error_samples_mode will raise TypeError if input is int
config_error_func(config.set_error_samples_mode, 1, TypeError,
"is not of type [<enum 'ErrorSamplesMode'>]")
# set_error_samples_mode will raise TypeError if input is a string
config_error_func(config.set_error_samples_mode, "Zero", TypeError,
"is not of type [<enum 'ErrorSamplesMode'>]")
# set_error_samples_mode will raise TypeError if input is a tuple
config_error_func(config.set_error_samples_mode, (1,), TypeError,
"is not of type [<enum 'ErrorSamplesMode'>]")
# set_error_samples_mode will raise TypeError if input is None
config_error_func(config.set_error_samples_mode, None, TypeError,
"is not of type [<enum 'ErrorSamplesMode'>]")
# set_error_samples_mode will raise TypeError if no input is provided
with pytest.raises(TypeError) as error_info:
config.set_error_samples_mode()
assert "set_error_samples_mode() missing 1 required positional argument: 'error_samples_mode'" in \
str(error_info.value)
# set_error_samples_mode will raise TypeError if too many parameters are provided
with pytest.raises(TypeError) as error_info:
config.set_error_samples_mode(config.ErrorSamplesMode.REPLACE, 10)
assert "set_error_samples_mode() takes 1 positional argument but 2 were given" in str(error_info.value)
if __name__ == '__main__':
test_basic()
test_get_seed()
@ -565,3 +613,4 @@ if __name__ == '__main__':
test_config_bool_type_error()
test_fast_recovery()
test_debug_mode()
test_error_samples_mode()

View File

@ -1022,20 +1022,19 @@ def test_imagefolder_error_sample_sourceop():
Expectation: The dataset is processed as expected.
"""
def test_config(my_seed, my_error_sample_data_file, my_total_samples):
# Set configuration since default Random Sampler is used
original_seed = config_get_set_seed(my_seed)
# For ImageFolderDataset, use decode=False
data1 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, decode=False)
count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_total_samples
# For ImageFolderDataset, use decode=True
data2 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, decode=True)
with pytest.raises(RuntimeError) as error_info:
for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
@ -1045,11 +1044,11 @@ def test_imagefolder_error_sample_sourceop():
ds.config.set_seed(original_seed)
# Test empty sample
test_config(2, "../data/dataset/testImageNetError/Sample1_empty/train")
test_config(2, "../data/dataset/testImageNetError/Sample1_empty/train", 3)
# Test corrupt sample
test_config(3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train")
test_config(3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 6)
# Test text sample, instead of image sample
test_config(1, "../data/dataset/testImageNetError/Sample3_text/train")
test_config(1, "../data/dataset/testImageNetError/Sample3_text/train", 3)
def test_imagefolder_error_sample_mapop():
@ -1061,12 +1060,11 @@ def test_imagefolder_error_sample_mapop():
"""
def test_config(my_seed, my_error_sample_data_file):
# Set configuration since default Random Sampler is used
original_seed = config_get_set_seed(my_seed)
# For ImageFolderDataset, use decode default (of False)
data3 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None)
# Add map op to the pipeline. Use C++ implemented ops
data3 = data3.map(operations=[vision.Decode(),
vision.HorizontalFlip()],
@ -1077,7 +1075,7 @@ def test_imagefolder_error_sample_mapop():
assert "map operation: [Decode] failed" in str(error_info.value)
# For ImageFolderDataset, use decode default (of False)
data4 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None)
# Add map op to the pipeline. Use Python implemented ops
data4 = data4.map(operations=[vision.Decode(to_pil=True),
vision.RandomHorizontalFlip(0.7)],

View File

@ -0,0 +1,675 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Map operation's handling of rows with errors
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms as data_trans
import mindspore.dataset.vision as vision
from mindspore import log as logger
from mindspore.dataset.core.config import ErrorSamplesMode
from util import config_get_set_seed
# Need to run all these tests in separate processes since we are modifying a config flag
pytestmark = pytest.mark.forked
# Set global variable
TOTAL_SIZE = 100
def my_generator(ds_size):
def generator_func():
for i in range(ds_size):
yield i
return generator_func
def raise_none(x):
return x
def raise_all(x):
raise ZeroDivisionError
def raise_first(x):
if x == 0:
raise ZeroDivisionError
return x
def raise_first_10(x):
if x < 10:
raise ZeroDivisionError
return x
def raise_first_100(x):
if x < 100:
raise ZeroDivisionError
return x
def raise_first_101(x):
if x < 101:
raise ZeroDivisionError
return x
def raise_first_n(x):
if x < TOTAL_SIZE // 2 - 2:
raise ZeroDivisionError
return x
def raise_first_m(x):
if x < TOTAL_SIZE // 2 + 2:
raise ZeroDivisionError
return x
def raise_all_but_last(x):
if x < TOTAL_SIZE - 1:
raise ZeroDivisionError
return x
def raise_all_but_first(x):
if x > 0:
raise ZeroDivisionError
return x
def raise_last_n(x):
if x > TOTAL_SIZE // 2 + 2:
raise ZeroDivisionError
return x
def raise_last_m(x):
if x > TOTAL_SIZE // 2 - 2:
raise ZeroDivisionError
return x
def raise_all_odds(x):
if x % 2 != 0:
raise ZeroDivisionError
return x
def raise_all_3_remainders(x):
if x % 3 != 0:
raise ZeroDivisionError
return x
def run_replace_test(transforms, dataset_size, num_parallel_workers, python_multiprocessing, expected=None, epochs=1):
""" Function to run test replace error samples mode based on input configuration. """
data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
data1 = data1.map(operations=transforms,
num_parallel_workers=num_parallel_workers,
python_multiprocessing=python_multiprocessing)
global TOTAL_SIZE
TOTAL_SIZE = dataset_size
itr = data1.create_dict_iterator(num_epochs=epochs, output_numpy=True)
count = 0
result = []
for _ in range(epochs):
for _, data in enumerate(itr):
count += 1
if expected is not None:
result.append(data["data"].item(0))
assert count == dataset_size * epochs
if expected is not None:
assert result == expected
def run_skip_test(transforms, dataset_size, num_parallel_workers, python_multiprocessing, expected=None, epochs=1):
""" Function to run test skip error samples mode based on input configuration. """
data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
data1 = data1.map(operations=transforms,
num_parallel_workers=num_parallel_workers,
python_multiprocessing=python_multiprocessing)
global TOTAL_SIZE
TOTAL_SIZE = dataset_size
itr = data1.create_dict_iterator(num_epochs=epochs, output_numpy=True)
count = 0
result = []
for _ in range(epochs):
for _, data in enumerate(itr):
count += 1
if expected is not None:
result.append(data["data"].item(0))
if expected is not None:
assert count == len(expected) * epochs
assert result == expected
def test_map_replace_errors_failure():
"""
Feature: Process Error Samples
Description: Simple replace tests of data pipeline with error rows in Map operation
Expectation: Exceptions are raised due to numerous error samples
"""
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
# failure cases:
with pytest.raises(RuntimeError) as error_info:
run_replace_test(raise_all, 1, 1, False)
assert "All data is garbage" in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
run_replace_test(raise_all, 10, 1, False)
assert "All data is garbage" in str(error_info.value)
ds.config.set_error_samples_mode(error_samples_mode_original)
def test_map_replace_errors_success1():
"""
Feature: Process Error Samples
Description: Simple replace tests of data pipeline with various numbers of error rows at different indexes
Expectation: Data pipeline replaces error rows successfully
"""
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
# no error rows
run_replace_test(raise_none, 10, 1, False, list(range(10)))
run_replace_test(raise_none, 100, 1, False, list(range(100)))
run_replace_test(raise_none, 1000, 1, False, list(range(1000)))
# 1 error row in the beginning of dataset
run_replace_test(raise_first, 2, 1, False, [1, 1])
run_replace_test(raise_first, 3, 1, False, [1, 2, 1])
run_replace_test(raise_first, 10, 1, False, list(range(1, 10)) + [1])
run_replace_test(raise_first, 16, 1, False, list(range(1, 16)) + [1])
run_replace_test(raise_first, 17, 1, False, list(range(1, 17)) + [1])
run_replace_test(raise_first, 20, 1, False, list(range(1, 17)) + [1] + list(range(17, 20)))
run_replace_test(raise_first, 100, 1, False, list(range(1, 17)) + [1] + list(range(17, 100)))
# multiple error rows in beginning of dataset
run_replace_test(raise_first_10, 11, 1, False, [10] * 11)
run_replace_test(raise_first_10, 12, 1, False, [10, 11] * 6)
run_replace_test(raise_first_10, 20, 1, False, list(range(10, 20)) * 2)
run_replace_test(raise_first_10, 30, 1, False,
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25] +
[10, 26, 27, 11, 28, 29, 12, 13, 14, 15, 16, 17, 18, 19])
run_replace_test(raise_first_100, 1000, 1, False)
run_replace_test(raise_first_n, 20, 1, False, list(range(8, 20)) + list(range(8, 16))) # ~first half (n < half)
run_replace_test(raise_first_n, 40, 1, False, [18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33] +
[18, 34, 35, 19, 36, 37, 20, 38, 39] +
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 35,
37]) # ~first half (n < half)
run_replace_test(raise_first_n, 100, 1, False) # ~first half (n < half)
run_replace_test(raise_first_m, 100, 1, False) # ~first half (m > half)
run_replace_test(raise_all_but_last, 2, 1, False, [1, 1])
run_replace_test(raise_all_but_last, 3, 1, False, [2, 2, 2])
run_replace_test(raise_all_but_last, 4, 1, False, [3] * 4)
run_replace_test(raise_all_but_last, 16, 1, False, [15] * 16)
run_replace_test(raise_all_but_last, 100, 1, False, [99] * 100)
run_replace_test(raise_all_but_first, 10, 1, False, [0] * 10)
run_replace_test(raise_all_but_first, 100, 1, False, [0] * 100)
# error rows in the end of dataset
run_replace_test(raise_last_n, 10, 1, False, list(range(0, 8)) + [0, 1])
run_replace_test(raise_last_n, 20, 1, False, list(range(0, 13)) + list(range(0, 7)))
run_replace_test(raise_last_n, 40, 1, False, list(range(0, 23)) + list(range(0, 16)) + [0])
run_replace_test(raise_last_n, 100, 1, False)
run_replace_test(raise_last_m, 100, 1, False)
# error rows in different places
run_replace_test(raise_all_odds, 10, 1, False, [0, 2, 4, 6, 8] * 2)
run_replace_test(raise_all_odds, 40, 1, False, [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] +
[0, 32, 2, 34, 4, 36, 6, 38, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38])
run_replace_test(raise_all_odds, 100, 1, False)
run_replace_test(raise_all_3_remainders, 12, 1, False, [0, 3, 6, 9] * 3)
run_replace_test(raise_all_3_remainders, 100, 1, False)
ds.config.set_error_samples_mode(error_samples_mode_original)
@pytest.mark.parametrize("my_mp", (False, True))
def test_map_replace_errors_success2(my_mp):
"""
Feature: Process Error Samples
Description: Simple replace tests of data pipeline with error rows in different settings
Expectation: Data pipeline replaces error rows successfully
"""
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
# Check if python_multiprocessing is to be enabled
if my_mp:
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# multiple pyfuncs
run_replace_test([raise_last_n, raise_last_n, raise_first_m, raise_first_n], 100, 1, my_mp)
run_replace_test([raise_last_n, raise_last_n, raise_first_n], 100, 2, my_mp) # n<50
# parallel workers
run_replace_test(raise_first, 100, 2, my_mp)
run_replace_test(raise_last_n, 100, 2, my_mp)
run_replace_test(raise_first_n, 100, 4, my_mp)
run_replace_test(raise_all_odds, 100, 2, my_mp)
run_replace_test(raise_all_odds, 100, 3, my_mp)
run_replace_test(raise_all_3_remainders, 100, 5, my_mp)
# multiple epochs
run_replace_test(raise_last_m, 100, 1, my_mp, epochs=3)
run_replace_test(raise_all_odds, 100, 3, my_mp, epochs=3)
ds.config.set_error_samples_mode(error_samples_mode_original)
if my_mp:
# Restore configuration for shared memory
ds.config.set_enable_shared_mem(mem_original)
@pytest.mark.parametrize("my_num_workers, my_mp",
[(1, False), (4, False), (3, True)])
def test_map_replace_errors_success3(my_num_workers, my_mp):
"""
Feature: Process Error Samples
Description: Simple replace tests of data pipeline with error rows in different pipelines
Expectation: Data pipeline replaces error rows successfully
"""
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.REPLACE)
# Check if python_multiprocessing is to be enabled
if my_mp:
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
dataset_size = 100
global TOTAL_SIZE
TOTAL_SIZE = dataset_size
# multiple maps
transforms = [raise_all_but_last]
data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)
data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)
count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == dataset_size
# repeat op
transforms = [raise_all_but_first]
data1 = ds.GeneratorDataset(my_generator(dataset_size), ["data"])
data1 = data1.map(operations=transforms, num_parallel_workers=my_num_workers, python_multiprocessing=my_mp)
data1 = data1.repeat(3)
count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == dataset_size * 3
ds.config.set_error_samples_mode(error_samples_mode_original)
if my_mp:
# Restore configuration for shared memory
ds.config.set_enable_shared_mem(mem_original)
@pytest.mark.parametrize("my_num_workers, my_mp",
[(1, False), (3, False), (2, True)])
def test_map_skip_errors_success1(my_num_workers, my_mp):
"""
Feature: Process Error Samples
Description: Simple skip tests of data pipeline with various numbers of error rows at different indexes
Expectation: Data pipeline skips error rows successfully
"""
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(ds.config.ErrorSamplesMode.SKIP)
# Check if python_multiprocessing is to be enabled
if my_mp:
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# no error rows
run_skip_test(raise_none, 10, my_num_workers, my_mp, list(range(10)))
run_skip_test(raise_none, 100, my_num_workers, my_mp, list(range(100)))
run_skip_test(raise_none, 1000, my_num_workers, my_mp, list(range(1000)))
# 1 error row in the beginning of dataset
run_skip_test(raise_first, 2, my_num_workers, my_mp, [1])
run_skip_test(raise_first, 3, my_num_workers, my_mp, [1, 2])
run_skip_test(raise_first, 10, my_num_workers, my_mp, list(range(1, 10)))
run_skip_test(raise_first, 16, my_num_workers, my_mp, list(range(1, 16)))
run_skip_test(raise_first, 17, my_num_workers, my_mp, list(range(1, 17)))
run_skip_test(raise_first, 20, my_num_workers, my_mp, list(range(1, 20)))
run_skip_test(raise_first, 100, my_num_workers, my_mp, list(range(1, 100)))
# multiple error rows in beginning of dataset
run_skip_test(raise_first_10, 11, my_num_workers, my_mp, [10])
run_skip_test(raise_first_10, 12, my_num_workers, my_mp, [10, 11])
run_skip_test(raise_first_10, 20, my_num_workers, my_mp, list(range(10, 20)))
run_skip_test(raise_first_10, 30, my_num_workers, my_mp, list(range(10, 30)))
run_skip_test(raise_first_100, 250, my_num_workers, my_mp, list(range(100, 250)))
run_skip_test(raise_first_100, 1000, my_num_workers, my_mp, list(range(100, 1000)))
run_skip_test(raise_first_n, 20, my_num_workers, my_mp, list(range(8, 20))) # ~first half (n < half)
run_skip_test(raise_first_n, 40, my_num_workers, my_mp, list(range(18, 40))) # ~first half (n < half)
run_skip_test(raise_first_n, 100, my_num_workers, my_mp, list(range(48, 100))) # ~first half (n < half)
run_skip_test(raise_first_m, 100, my_num_workers, my_mp, list(range(52, 100))) # ~first half (m > half)
run_skip_test(raise_all_but_last, 2, my_num_workers, my_mp, [1])
run_skip_test(raise_all_but_last, 3, my_num_workers, my_mp, [2])
run_skip_test(raise_all_but_last, 4, my_num_workers, my_mp, [3])
run_skip_test(raise_all_but_last, 16, my_num_workers, my_mp, [15])
run_skip_test(raise_all_but_last, 100, my_num_workers, my_mp, [99])
run_skip_test(raise_all_but_first, 10, my_num_workers, my_mp, [0])
run_skip_test(raise_all_but_first, 100, my_num_workers, my_mp, [0])
# error rows in the end of dataset
run_skip_test(raise_last_n, 10, my_num_workers, my_mp, list(range(0, 8)))
run_skip_test(raise_last_n, 20, my_num_workers, my_mp, list(range(0, 13)))
run_skip_test(raise_last_n, 40, my_num_workers, my_mp, list(range(0, 23)))
run_skip_test(raise_last_n, 100, my_num_workers, my_mp, list(range(0, 53)))
run_skip_test(raise_last_m, 100, my_num_workers, my_mp, list(range(0, 49)))
# error rows in different places
run_skip_test(raise_all_odds, 10, my_num_workers, my_mp, [0, 2, 4, 6, 8])
run_skip_test(raise_all_odds, 40, my_num_workers, my_mp,
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38])
run_skip_test(raise_all_odds, 100, 1, False)
run_skip_test(raise_all_3_remainders, 12, my_num_workers, my_mp, [0, 3, 6, 9])
run_skip_test(raise_all_3_remainders, 100, 1, False)
# error rows in entire dataset
run_skip_test(raise_all, 1, my_num_workers, my_mp, [])
run_skip_test(raise_all, 3, my_num_workers, my_mp, [])
run_skip_test(raise_all, 10, my_num_workers, my_mp, [])
run_skip_test(raise_all, 100, my_num_workers, my_mp, [])
ds.config.set_error_samples_mode(error_samples_mode_original)
if my_mp:
# Restore configuration for shared memory
ds.config.set_enable_shared_mem(mem_original)
@pytest.mark.parametrize("my_error_samples_mode",
(ErrorSamplesMode.RETURN, ErrorSamplesMode.REPLACE, ErrorSamplesMode.SKIP))
def test_map_error_samples_imagefolder1_basic(my_error_samples_mode):
"""
Feature: Process Error Samples
Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with map op and sample errors.
Expectation: The dataset is processed as expected.
"""
def test_config(my_error_samples_mode, my_error_sample_data_file, my_num_classes, my_total_samples,
my_unskipped_samples):
# For ImageFolderDataset:
# - use num_samples=None to read all samples
# - use num_parallel_workers=1
# - use shuffle=False which will result in sequential order of samples
# - use decode default of False
data3 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, num_parallel_workers=1,
shuffle=False)
# Use multiple map ops in pipeline.
data3 = data3.map(operations=[data_trans.OneHot(my_num_classes)],
input_columns=["label"],
num_parallel_workers=1)
# 2nd map op in the pipeline is the Decode op, which uses the C++ implementation
data3 = data3.map(operations=[vision.Decode()], input_columns=["image"], num_parallel_workers=1)
data3 = data3.map(operations=[vision.Crop((0, 0), 32)], input_columns=["image"], num_parallel_workers=1)
if my_error_samples_mode == ErrorSamplesMode.REPLACE:
# Error samples are to be replaced
count = 0
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_total_samples
elif my_error_samples_mode == ErrorSamplesMode.SKIP:
# Error samples are to be skipped
count = 0
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_unskipped_samples
else:
with pytest.raises(RuntimeError) as error_info:
for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "map operation: [Decode] failed" in str(error_info.value)
data4 = ds.ImageFolderDataset(my_error_sample_data_file, num_samples=None, num_parallel_workers=1,
shuffle=False)
# Use multiple map ops in pipeline.
data4 = data4.map(operations=[data_trans.OneHot(my_num_classes)],
input_columns=["label"],
num_parallel_workers=1)
# 2nd map op in the pipeline is the Decode op, which uses the Python implementation
data4 = data4.map(operations=[vision.Decode(to_pil=True),
vision.RandomHorizontalFlip(0.7)],
input_columns=["image"], num_parallel_workers=1)
# Note: ToPIL op added so that Python implementation of RandomVerticalFlip is selected
data4 = data4.map(operations=[vision.ToPIL(),
vision.RandomVerticalFlip(0.6)],
input_columns=["image"], num_parallel_workers=1)
if my_error_samples_mode == ErrorSamplesMode.REPLACE:
# Error samples are to be replaced
count = 0
for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_total_samples
elif my_error_samples_mode == ErrorSamplesMode.SKIP:
# Error samples are to be skipped
count = 0
for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_unskipped_samples
else:
with pytest.raises(RuntimeError) as error_info:
for _ in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "map operation: [PyFunc] failed" in str(error_info.value)
# Set configuration for error_samples_mode
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(my_error_samples_mode)
# Test empty sample (which is first error sample when samples are read sequentially)
test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample1_empty/train", 1, 3, 2)
# Test corrupt sample (which is a middle error sample when samples are read sequentially)
test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 3, 6, 5)
# Test text sample, instead of image sample (which is a final error sample when samples are read sequentially)
test_config(my_error_samples_mode, "../data/dataset/testImageNetError/Sample3_text/train", 1, 3, 2)
# Restore configuration for error_samples_mode
ds.config.set_error_samples_mode(error_samples_mode_original)
@pytest.mark.parametrize("my_error_samples_mode, my_num_workers, my_mp",
[(ErrorSamplesMode.RETURN, 3, False), (ErrorSamplesMode.RETURN, 4, True),
(ErrorSamplesMode.REPLACE, 2, False), (ErrorSamplesMode.REPLACE, 3, True),
(ErrorSamplesMode.SKIP, 3, False), (ErrorSamplesMode.SKIP, 2, True)])
def test_map_error_samples_imagefolder2_parallel(my_error_samples_mode, my_num_workers, my_mp):
"""
Feature: Process Error Samples
Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with map op
with num_parallel_workers and python_multiprocessing set, plus sample errors.
Expectation: The dataset is processed as expected.
"""
def test_config(my_error_samples_mode, my_seed, my_error_sample_data_file, my_total_samples, my_unskipped_samples):
# Set configuration since default Random Sampler is used
original_seed = config_get_set_seed(my_seed)
# Create dataset pipeline which includes at least one sample error
my_sampler = ds.RandomSampler(replacement=False, num_samples=None)
data1 = ds.ImageFolderDataset(my_error_sample_data_file, sampler=my_sampler,
num_parallel_workers=my_num_workers)
# Add map op to the pipeline which will encounter error samples. Use Python implemented ops
# Note: Decode is not the first op in the list of operations
data1 = data1.map(operations=[(lambda x: x),
vision.Decode(to_pil=True),
(lambda y: y),
vision.RandomVerticalFlip(0.8)],
input_columns=["image"],
num_parallel_workers=my_num_workers,
python_multiprocessing=my_mp)
if my_error_samples_mode == ErrorSamplesMode.REPLACE:
# Error samples are to be replaced
count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_total_samples
elif my_error_samples_mode == ErrorSamplesMode.SKIP:
# Error samples are to be skipped
count = 0
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == my_unskipped_samples
else:
with pytest.raises(RuntimeError) as error_info:
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "map operation: [PyFunc] failed" in str(error_info.value)
# Restore configuration
ds.config.set_seed(original_seed)
# Set configuration for error_samples_mode
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(my_error_samples_mode)
# Check if python_multiprocessing is to be enabled
if my_mp:
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# Test empty sample
test_config(my_error_samples_mode, 2, "../data/dataset/testImageNetError/Sample1_empty/train", 3, 2)
# Test corrupt sample
test_config(my_error_samples_mode, 3, "../data/dataset/testImageNetError/Sample2_corrupt1mid/train", 6, 5)
# Test text sample, instead of image sample
test_config(my_error_samples_mode, 1, "../data/dataset/testImageNetError/Sample3_text/train", 3, 2)
# Restore configuration for error_samples_mode
ds.config.set_error_samples_mode(error_samples_mode_original)
if my_mp:
# Restore configuration for shared memory
ds.config.set_enable_shared_mem(mem_original)
@pytest.mark.parametrize("my_error_samples_mode, my_num_workers, my_mp",
[(ErrorSamplesMode.RETURN, 4, False), (ErrorSamplesMode.RETURN, 3, True),
(ErrorSamplesMode.REPLACE, 3, False), (ErrorSamplesMode.REPLACE, 4, True),
(ErrorSamplesMode.SKIP, 5, False), (ErrorSamplesMode.SKIP, 3, True)])
def test_map_error_samples_imagefolder3(my_error_samples_mode, my_num_workers, my_mp):
"""
Feature: Process Error Samples
Description: Invoke set_error_samples_mode and test ImageFolderDataset pipeline with map op
with num_parallel_workers and python_multiprocessing set, plus multiple sample errors.
Expectation: The dataset is processed as expected.
"""
def test_config(my_error_samples_mode, my_seed, my_take, my_repeat, my_total_unskipped_rows):
# Set configuration since default Random Sampler is used
original_seed = config_get_set_seed(my_seed)
# Create dataset pipelines with multiple error samples
data1 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample1_empty/train",
num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
data2 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample2_corrupt1mid/train",
num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
data3 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample3_text/train",
num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
data4 = ds.ImageFolderDataset("../data/dataset/testImageNetError/Sample4_corruptall/train",
num_samples=None, shuffle=True, num_parallel_workers=my_num_workers)
# Concat the multiple datasets together
datafinal = data4 + data1 + data2 + data3
datafinal = datafinal.take(my_take).repeat(my_repeat)
total_rows = my_take * my_repeat
# Add map op to the pipeline. Use Python implemented ops
datafinal = datafinal.map(operations=[vision.Decode(to_pil=True),
vision.RandomHorizontalFlip(0.9)],
input_columns=["image"],
num_parallel_workers=my_num_workers,
python_multiprocessing=my_mp)
# Apply dataset ops
datafinal = datafinal.shuffle(buffer_size=100)
if my_error_samples_mode == ErrorSamplesMode.REPLACE:
# Error samples are to be replaced
count = 0
for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == total_rows
elif my_error_samples_mode == ErrorSamplesMode.SKIP:
# Error samples are to be skipped
count = 0
for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
logger.info("Number of data in datafinal: {}".format(count))
assert count == my_total_unskipped_rows
else:
with pytest.raises(RuntimeError) as error_info:
for _ in datafinal.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "map operation: [PyFunc] failed" in str(error_info.value)
# Restore configuration
ds.config.set_seed(original_seed)
# Set configuration for error_samples_mode
error_samples_mode_original = ds.config.get_error_samples_mode()
ds.config.set_error_samples_mode(my_error_samples_mode)
# Check if python_multiprocessing is to be enabled
if my_mp:
# Reduce memory required by disabling the shared memory optimization
mem_original = ds.config.get_enable_shared_mem()
ds.config.set_enable_shared_mem(False)
# Test different scenarios
test_config(my_error_samples_mode, 5001, 15, 1, 9)
test_config(my_error_samples_mode, 5002, 15, 2, 18)
test_config(my_error_samples_mode, 5003, 10, 5, 27)
test_config(my_error_samples_mode, 5004, 12, 4, 28)
# Restore configuration for error_samples_mode
ds.config.set_error_samples_mode(error_samples_mode_original)
if my_mp:
# Restore configuration for shared memory
ds.config.set_enable_shared_mem(mem_original)
if __name__ == '__main__':
test_map_replace_errors_failure()
test_map_replace_errors_success1()
test_map_replace_errors_success2(True)
test_map_replace_errors_success3(3, False)
test_map_skip_errors_success1(3, True)
test_map_error_samples_imagefolder1_basic(ErrorSamplesMode.REPLACE)
test_map_error_samples_imagefolder2_parallel(ErrorSamplesMode.REPLACE, 4, True)
test_map_error_samples_imagefolder3(ErrorSamplesMode.SKIP, 3, True)