!4776 Introduce 2 extra ctrl flags to DataBuffer in dataset, address remaining cmts to PR4632

Merge pull request !4776 from ZiruiWu/map_callback_follow_up
This commit is contained in:
mindspore-ci-bot 2020-08-22 03:31:31 +08:00 committed by Gitee
commit 9d7250c483
15 changed files with 234 additions and 81 deletions

View File

@ -45,6 +45,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_seed", &ConfigManager::seed)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("get_callback_timeout", &ConfigManager::callback_timeout)
.def("set_callback_timeout", &ConfigManager::set_callback_timeout)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));

View File

@ -50,7 +50,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) {
// return Status::OK() if no begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
@ -69,7 +69,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
// return Status::OK() if no epoch_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
@ -89,7 +89,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
// return Status::OK() if no step_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
@ -108,7 +108,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) {
// return Status::OK() if no end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
@ -127,7 +127,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
// return Status::OK() if no epoch_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {
@ -147,7 +147,7 @@ Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
// return Status::OK() if no step_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
RETURN_IF_NOT_OK(op_->WaitForWorkers());
// Now do the actual callback
for (size_t ind : callback_inds) {

View File

@ -32,7 +32,7 @@ class DatasetOp;
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
class CallbackManager {
public:
/// CallbackManager default constructor. Init needs to be called before using the created instance.
/// \brief CallbackManager default constructor. Init needs to be called before using the created instance.
CallbackManager() : enabled_(false) {}
/// \brief

View File

@ -88,5 +88,8 @@ uint32_t ConfigManager::seed() const { return seed_; }
void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
} // namespace dataset
} // namespace mindspore

View File

@ -116,9 +116,17 @@ class ConfigManager {
void set_monitor_sampling_interval(uint32_t interval);
// getter function
// @return The iterval of monitor sampling
// @return The interval of monitor sampling
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
// setter function
// @param timeout - The setting to apply to the config
void set_callback_timeout(uint32_t timeout);
// getter function
// @return The timeout DSWaitedCallback would wait for before raising an error
int32_t callback_timeout() const { return callback_timout_; }
private:
int32_t rows_per_buffer_{kCfgRowsPerBuffer};
int32_t num_parallel_workers_{kCfgParallelWorkers};
@ -126,8 +134,9 @@ class ConfigManager {
int32_t op_connector_size_{kCfgOpConnectorSize};
uint32_t seed_{kCfgDefaultSeed};
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
uint32_t callback_timout_{kCfgCallbackTimeout};
// Private helper function that taks a nlohmann json format and populates the settings
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
Status FromJson(const nlohmann::json &j);
};

View File

@ -68,6 +68,7 @@ constexpr uint32_t kCfgWorkerConnectorSize = 16;
constexpr uint32_t kCfgOpConnectorSize = 16;
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
constexpr uint32_t kCfgMonitorSamplingInterval = 10;
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;

View File

@ -37,8 +37,10 @@ class DataBuffer {
// Buffer flags
enum BufferFlags : uint32_t {
kDeBFlagNone = 0,
kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg
kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg
kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg
kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg
kDeBFlagWait = 1u << 2, // The buffer is an control signal for workers to suspend operations
kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit
};
// Name: Constructor #1
@ -64,6 +66,10 @@ class DataBuffer {
bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
bool wait() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagWait)); }
bool quit() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagQuit)); }
// Simple getter funcs
int32_t id() const { return buffer_id_; }

View File

@ -363,10 +363,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp
/// The expected behavior is this, when this function is invoked, this function will block until all the workers
/// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
/// They would automatically wait on the QueueList when they are done. Hence, for now, a Unpause() function is not
/// needed. Only parallelOp needs to override this function.
/// They would automatically wait on the QueueList when they are done.
/// \return Status
virtual Status PauseFromMaster() { return Status::OK(); }
virtual Status WaitForWorkers() { return Status::OK(); }
protected:
/// \brief Removes a parent operator from this operator

View File

@ -166,7 +166,7 @@ Status MapOp::operator()() {
// init callback
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
Status rc = local_queues_.Register(tree_->AllTasks());
RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
if (rc.IsError()) {
TaskManager::FindMe()->Post();
return rc;
@ -205,23 +205,29 @@ Status MapOp::operator()() {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// send the eoe buffer to worker
// reset epoch_step when a new epoch is about to start
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
ep_step = 0;
}
// Propagate the eoe buffer to worker
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
UpdateRepeatAndEpochCounter();
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
// RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
// handle eof logic
// End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1
// RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
// Handle eof logic, this code might never be reached if epoch_ctrl = -1.
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
// Quit all workers, this code might never be reached if EpochCtrl is -1.
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
auto quit = std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagQuit));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(quit)));
}
return Status::OK();
}
@ -242,26 +248,27 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
// Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
// rather than use the base-class defaults.
while (true) {
// handle the pause logic. Pause is triggered when an buffer id of -1 with no special flag and no row is received
if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
// when worker receives the signal from master thread, it increments a atomic int
// the last guy who increments the counter, wakes up master thread
if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
// this will block the worker until master thread gives it a new work
// Handle special logic where buffer carries a ctrl flag.
if (in_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) {
if (in_buffer->wait()) {
// When worker receives the signal from master thread, it increments a atomic int
// The last guy who increments the counter, wakes up master thread
if (++num_workers_paused_ == num_workers_) {
wait_for_workers_post_.Set();
}
// This will block the worker until master thread gives it a new work
} else if (in_buffer->eoe()) {
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK(EoeReceived(worker_id));
} else if (in_buffer->eof()) {
// Calling base class EofReceived to forward eof buffer.
RETURN_IF_NOT_OK(EofReceived(worker_id));
} else if (in_buffer->quit()) {
break;
}
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
continue;
} else if (in_buffer->eoe()) {
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK(EoeReceived(worker_id));
// Fetch next data buffer and map job list
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
continue;
} else if (in_buffer->eof()) {
// Calling base class EofReceived to forward eof buffer.
RETURN_IF_NOT_OK(EofReceived(worker_id));
break;
}
CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
@ -299,9 +306,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
// Variable to keep the result after executing the job.
std::vector<TensorRow> result_table;
// Executing the list of jobs
// Executing the list of jobs.
for (size_t i = 0; i < job_list.size(); i++) {
// Execute MapJob.
// Execute MapWorkerJob.
RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
// Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
if (i + 1 < job_list.size()) {
@ -311,8 +318,7 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
// Sanity check a row in result_table
if (!result_table.empty() && out_columns_.size() != result_table[0].size()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Result of a tensorOp doesn't match output column names");
RETURN_STATUS_UNEXPECTED("Result of a tensorOp doesn't match output column names");
}
// Merging the data processed by job (result_table) with the data that are not used.
@ -386,7 +392,7 @@ Status MapOp::InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_
// columns from child are correct
RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map));
// initialize keep_input_columns, true means to keep the column.
// Initialize keep_input_columns, true means to keep the column.
keep_input_columns_.resize(col_name_id_map->size(), true);
for (const auto &col_name : in_columns_) {
int32_t missed = (*col_name_id_map)[col_name];
@ -449,18 +455,18 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_base<MapOp>(), modified);
}
Status MapOp::PauseFromMaster() {
Status MapOp::WaitForWorkers() {
// reset num_paused workers to 0
num_workers_paused_ = 0;
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
// a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause.
RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagWait))));
}
// wait until all workers are done processing their work in local_queue_
RETURN_IF_NOT_OK(master_pause_wp_.Wait());
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
// clear the WaitPost for the next Wait()
master_pause_wp_.Clear();
wait_for_workers_post_.Clear();
return Status::OK();
}
} // namespace dataset

View File

@ -228,10 +228,10 @@ class MapOp : public ParallelOp {
// Indices of the columns to process.
std::vector<size_t> to_process_indices_;
// wait post used to perform the pausing logic in MapOp
WaitPost master_pause_wp_;
// Wait post used to perform the pausing logic in MapOp
WaitPost wait_for_workers_post_;
// count number of workers that have signaled master
// Count number of workers that have signaled master
std::atomic_int num_workers_paused_;
// Private function for worker/thread to loop continuously. It comprises the main
@ -272,7 +272,7 @@ class MapOp : public ParallelOp {
// Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker
// who does the increment wakes up the master.
// @return - Status
Status PauseFromMaster() override;
Status WaitForWorkers() override;
};
} // namespace dataset
} // namespace mindspore

View File

@ -34,7 +34,7 @@ class Semaphore {
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
/// \return Error code. Can get interrupt.
Status P();
/// \brief Increment the internal counter. Wakeup on of the waiters if any.
/// \brief Increment the internal counter. Wake up on of the waiters if any.
void V();
/// \brief Peek the internal value
/// \return The internal value

View File

@ -18,6 +18,7 @@ Python callback class
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
import mindspore.dataset as ds
from .validators import check_callback
@ -170,7 +171,6 @@ class WaitedDSCallback(Callback, DSCallback):
"""
self.epoch_run_context = run_context
self.epoch_event.set()
self.epoch_event.clear()
def ds_epoch_begin(self, ds_run_context):
"""
@ -180,10 +180,12 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline.
"""
if ds_run_context.cur_epoch_num > 1:
if self.epoch_run_context is None:
self.epoch_event.wait()
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
self.epoch_event.clear()
if not success:
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
# by the time this thread wakes up, self.epoch_run_context is already available
self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
self.epoch_run_context = None
def step_end(self, run_context):
"""
@ -194,7 +196,6 @@ class WaitedDSCallback(Callback, DSCallback):
"""
self.step_run_context = run_context
self.step_event.set()
self.step_event.clear()
def ds_step_begin(self, ds_run_context):
"""
@ -204,10 +205,12 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline.
"""
if ds_run_context.cur_step_num > self.step_size:
if self.step_run_context is None:
self.step_event.wait()
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
self.step_event.clear()
if not success:
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
# by the time this thread wakes up, self.epoch_run_context is already available
self.sync_step_begin(self.step_run_context, ds_run_context)
self.step_run_context = None
def create_runtime_obj(self):
"""

View File

@ -157,6 +157,38 @@ def get_monitor_sampling_interval():
return _config.get_monitor_sampling_interval()
def set_callback_timeout(timeout):
"""
Set the default timeout (in seconds) for DSWaitedCallback.
In case of a deadlock, the wait function will exit after the timeout period.
Args:
timeout (int): timeout(s) to be used to end teh wait in DSWaitedCallback in case of a deadlock.
Raises:
ValueError: If timeout is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> # sets the new timout value.
>>> ds.config.set_callback_timeout(100)
"""
if timeout <= 0 or timeout > INT32_MAX:
raise ValueError("timeout given is not within the required range.")
_config.set_callback_timeout(timeout)
def get_callback_timeout():
"""
Get the default timeout for DSWaitedCallback.
In case of a deadlock, the wait function will exit after the timeout period.
Returns:
Int, the duration in seconds
"""
return _config.get_callback_timeout()
def __str__():
"""
String representation of the configurations.

View File

@ -57,7 +57,7 @@ class TestCallback : public DSCallback {
begin_(true),
epoch_begin_(true),
step_begin_(true),
end_(true),
end_(false),
epoch_end_(true),
step_end_(true) {
all_names_.reserve(32);
@ -145,7 +145,6 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false; // don't do the end for now due to a timing issue
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
@ -193,7 +192,6 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false; // don't do the end for now due to a timing issue
// config leaf_op, use random_data to avoid I/O
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
@ -247,7 +245,6 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
Status rc;
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
std::shared_ptr<DSCallback> cb1 = tst_cb;
tst_cb->end_ = false;
// turn off the epochs
tst_cb->epoch_begin_ = false;
tst_cb->epoch_end_ = false;

View File

@ -29,7 +29,7 @@ import mindspore.nn as nn
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class MyDSCallback(DSCallback):
class BaseCallback(DSCallback):
def __init__(self, step_size=1, events=None, cb_id=0):
super().__init__(step_size)
self.events = events
@ -49,25 +49,36 @@ class MyDSCallback(DSCallback):
else:
self.events.append((event, [self.cb_id]))
class Begin(BaseCallback):
def ds_begin(self, ds_run_context):
self.append("begin", ds_run_context)
def ds_end(self, ds_run_context):
self.append("end", ds_run_context)
class EpochBegin(BaseCallback):
def ds_epoch_begin(self, ds_run_context):
self.append("epoch_begin", ds_run_context)
class EpochEnd(BaseCallback):
def ds_epoch_end(self, ds_run_context):
self.append("epoch_end", ds_run_context)
class StepBegin(BaseCallback):
def ds_step_begin(self, ds_run_context):
self.append("step_begin", ds_run_context)
class StepEnd(BaseCallback):
def ds_step_end(self, ds_run_context):
self.append("step_end", ds_run_context)
class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
pass
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
events = []
cb_id = list(range(map_num))
@ -98,7 +109,12 @@ def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
data = data.map(operations=(lambda x: x), callbacks=my_cb)
if repeat != 1:
data = data.repeat(repeat)
if repeat % 2 == 0 and repeat != 2:
data = data.repeat(2)
data = data.map(operations=(lambda x: x))
data = data.repeat(repeat // 2)
else:
data = data.repeat(repeat)
itr = data.create_tuple_iterator(num_epochs=epochs)
for _ in range(epochs):
for _ in itr:
@ -201,11 +217,10 @@ def test_callbacks_all_2cbs():
build_test_case_2cbs(4, 4)
def test_callbacks_2maps():
def skip_test_callbacks_2maps():
logger.info("test_callbacks_2maps")
# This test case is skipped because in rare cases (25 out 1000) it might fail
build_test_case_2maps(5, 10)
build_test_case_2maps(6, 9)
@ -243,8 +258,8 @@ class Net(nn.Cell):
return x
def test_train_non_sink():
logger.info("test_train_non_sink")
def test_callbacks_non_sink():
logger.info("test_callbacks_non_sink")
events = []
my_cb1 = MyWaitedCallback(events, 1)
@ -267,8 +282,8 @@ def test_train_non_sink():
assert events == expected_synced_events
def test_train_batch_size2():
logger.info("test_train_batch_size2")
def test_callbacks_non_sink_batch_size2():
logger.info("test_callbacks_non_sink_batch_size2")
events = []
my_cb1 = MyWaitedCallback(events, 2)
@ -291,6 +306,27 @@ def test_train_batch_size2():
assert events == expected_synced_events
def test_callbacks_non_sink_mismatch_size():
logger.info("test_callbacks_non_sink_mismatch_size")
default_timeout = ds.config.get_callback_timeout()
ds.config.set_callback_timeout(1)
events = []
my_cb1 = MyWaitedCallback(events, 2)
my_cb2 = MyMSCallback(events)
arr = [1, 2, 3, 4]
data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=my_cb1)
data = data.batch(3)
net = Net()
model = Model(net)
with pytest.raises(Exception) as err:
model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
ds.config.set_callback_timeout(default_timeout)
def test_callbacks_validations():
logger.info("test_callbacks_validations")
@ -318,7 +354,7 @@ def test_callbacks_validations():
assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
def test_callback_sink_simulation():
def test_callbacks_sink_simulation():
logger.info("test_callback_sink_simulation")
events = []
@ -353,13 +389,72 @@ def test_callbacks_repeat():
build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4)
build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8)
build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16)
def test_callbacks_exceptions():
logger.info("test_callbacks_exceptions")
class BadCB(DSCallback):
def ds_begin(self, ds_run_context):
raise RuntimeError("Bad begin")
with pytest.raises(Exception) as err:
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=BadCB())
for _ in data:
pass
assert "RuntimeError: Bad begin" in str(err.value)
def test_callbacks_one_cb():
logger.info("test_callbacks_one_cb")
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
events1 = []
events2 = []
events3 = []
my_begin = Begin(events=events1, cb_id=1)
my_epoch_begin = EpochBegin(events=events2, cb_id=2)
my_epoch_end = EpochEnd(events=events3, cb_id=3)
my_step_begin = StepBegin(events=events3, cb_id=3)
my_step_end = StepEnd(events=events2, cb_id=2)
data = data.map(operations=(lambda x: x), callbacks=my_begin)
data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end])
data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin])
itr = data.create_tuple_iterator()
for _ in range(2):
for _ in itr:
pass
expected_events1 = [('begin_0_0_0', [1])]
expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]),
('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]),
('step_end_2_4_8', [2])]
expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]),
('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
('epoch_end_2_4_8', [3])]
assert events1 == expected_events1
assert events2 == expected_events2
assert events3 == expected_events3
if __name__ == '__main__':
test_callbacks_all_methods()
skip_test_callbacks_2maps()
test_callbacks_all_2cbs()
test_callbacks_2maps()
test_callbacks_all_methods()
test_callbacks_exceptions()
test_callbacks_repeat()
test_callbacks_sink_simulation()
test_callbacks_validations()
test_callbacks_var_step_size()
test_train_batch_size2()
test_callback_sink_simulation()
test_callbacks_repeat()
test_callbacks_non_sink_batch_size2()
test_callbacks_non_sink()
test_callbacks_one_cb()
test_callbacks_non_sink_mismatch_size()