From 74c1e6da605f3cb26af2c3cd82f93a806e36b444 Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Wed, 19 Aug 2020 15:41:22 -0400 Subject: [PATCH] introducing pause and quit flags to DataBuffer fix review cmts fix ci fix Ci fixci address ci ci - add timeout - add more test cases fix CI address review cmts --- .../python/bindings/dataset/core/bindings.cc | 2 + .../dataset/callback/callback_manager.cc | 12 +- .../dataset/callback/callback_manager.h | 2 +- .../minddata/dataset/core/config_manager.cc | 3 + .../minddata/dataset/core/config_manager.h | 13 +- .../ccsrc/minddata/dataset/core/constants.h | 1 + .../minddata/dataset/engine/data_buffer.h | 10 +- .../dataset/engine/datasetops/dataset_op.h | 5 +- .../engine/datasetops/map_op/map_op.cc | 72 +++++----- .../dataset/engine/datasetops/map_op/map_op.h | 8 +- .../ccsrc/minddata/dataset/util/semaphore.h | 2 +- mindspore/dataset/callback/ds_callback.py | 19 +-- mindspore/dataset/core/config.py | 32 +++++ tests/ut/cpp/dataset/callback_test.cc | 5 +- tests/ut/python/dataset/test_callbacks.py | 129 +++++++++++++++--- 15 files changed, 234 insertions(+), 81 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc index 860dd3b3ba5..5c8c9b75496 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc @@ -45,6 +45,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) { .def("get_op_connector_size", &ConfigManager::op_connector_size) .def("get_seed", &ConfigManager::seed) .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) + .def("get_callback_timeout", &ConfigManager::callback_timeout) + .def("set_callback_timeout", &ConfigManager::set_callback_timeout) .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); })); diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc index 948ae8d4520..5ab0011561e 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc @@ -50,7 +50,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) { // return Status::OK() if no begin is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual callback for (size_t ind : callback_inds) { @@ -69,7 +69,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { // return Status::OK() if no epoch_begin is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual callback for (size_t ind : callback_inds) { @@ -89,7 +89,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) { // return Status::OK() if no step_begin is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual callback for (size_t ind : callback_inds) { @@ -108,7 +108,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) { // return Status::OK() if no end is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual 
callback for (size_t ind : callback_inds) { @@ -127,7 +127,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { // return Status::OK() if no epoch_end is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual callback for (size_t ind : callback_inds) { @@ -147,7 +147,7 @@ Status CallbackManager::StepEnd(const CallbackParam &cb_param) { // return Status::OK() if no step_end is needed RETURN_OK_IF_TRUE(callback_inds.empty()); - RETURN_IF_NOT_OK(op_->PauseFromMaster()); + RETURN_IF_NOT_OK(op_->WaitForWorkers()); // Now do the actual callback for (size_t ind : callback_inds) { diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h index 73ffda0fd47..9fadb7bf1c1 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h @@ -32,7 +32,7 @@ class DatasetOp; /// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this. class CallbackManager { public: - /// CallbackManager default constructor. Init needs to be called before using the created instance. + /// \brief CallbackManager default constructor. Init needs to be called before using the created instance. CallbackManager() : enabled_(false) {} /// \brief diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc index e1fc7f29ba7..f505a6b187e 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -88,5 +88,8 @@ uint32_t ConfigManager::seed() const { return seed_; } void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; } void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; } + +void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h index 4d25c472e01..eb154b8f440 100644 --- a/mindspore/ccsrc/minddata/dataset/core/config_manager.h +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -116,9 +116,17 @@ class ConfigManager { void set_monitor_sampling_interval(uint32_t interval); // getter function - // @return The iterval of monitor sampling + // @return The interval of monitor sampling int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; } + // setter function + // @param timeout - The setting to apply to the config + void set_callback_timeout(uint32_t timeout); + + // getter function + // @return The timeout DSWaitedCallback would wait for before raising an error + int32_t callback_timeout() const { return callback_timout_; } + private: int32_t rows_per_buffer_{kCfgRowsPerBuffer}; int32_t num_parallel_workers_{kCfgParallelWorkers}; @@ -126,8 +134,9 @@ class ConfigManager { int32_t op_connector_size_{kCfgOpConnectorSize}; uint32_t seed_{kCfgDefaultSeed}; uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; + uint32_t callback_timout_{kCfgCallbackTimeout}; - // Private helper function that taks a nlohmann json format and populates the settings + // Private helper function that takes a nlohmann json format and populates the settings // @param j - The json 
nlohmann json info Status FromJson(const nlohmann::json &j); }; diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index d2ef2c14c91..8b7911e5f6e 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -68,6 +68,7 @@ constexpr uint32_t kCfgWorkerConnectorSize = 16; constexpr uint32_t kCfgOpConnectorSize = 16; constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; constexpr uint32_t kCfgMonitorSamplingInterval = 10; +constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) constexpr uint8_t kCVInvalidType = 255; diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h index 01f2b3a881b..5236446c12a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h @@ -37,8 +37,10 @@ class DataBuffer { // Buffer flags enum BufferFlags : uint32_t { kDeBFlagNone = 0, - kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg - kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg + kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg + kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg + kDeBFlagWait = 1u << 2, // The buffer is a control signal for workers to suspend operations + kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit }; // Name: Constructor #1 @@ -64,6 +66,10 @@ class DataBuffer { bool eoe() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOE)); } + bool wait() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagWait)); } + + bool quit() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagQuit)); } + // Simple getter funcs int32_t id() const { return buffer_id_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h index 1e06049c87c..f99c0603d6a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -363,10 +363,9 @@ class DatasetOp : public std::enable_shared_from_this { /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp /// The expected behavior is this, when this function is invoked, this function will block until all the workers /// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. - /// They would automatically wait on the QueueList when they are done. Hence, for now, a Unpause() function is not - /// needed. Only parallelOp needs to override this function. + /// They would automatically wait on the QueueList when they are done. 
/// \return Status - virtual Status PauseFromMaster() { return Status::OK(); } + virtual Status WaitForWorkers() { return Status::OK(); } protected: /// \brief Removes a parent operator from this operator diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index e7061b8ccc0..bc8cf22e87c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -166,7 +166,7 @@ Status MapOp::operator()() { // init callback RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); Status rc = local_queues_.Register(tree_->AllTasks()); - RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); if (rc.IsError()) { TaskManager::FindMe()->Post(); return rc; @@ -205,23 +205,29 @@ Status MapOp::operator()() { RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); } - // send the eoe buffer to worker - - // reset epoch_step when a new epoch is about to start + // check whether this is the end of a real epoch (not all eoe signals end of epoch) if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); ep_step = 0; } + // Propagate the eoe buffer to worker std::unique_ptr worker_job = std::make_unique(std::move(buff)); RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); UpdateRepeatAndEpochCounter(); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); } - // the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback - // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step))); - // handle eof logic + // End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1 + // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step))); + // Handle eof logic, this code might never be reached if epoch_ctrl = -1. std::unique_ptr worker_job = std::make_unique(std::move(buff)); RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); + + // Quit all workers, this code might never be reached if EpochCtrl is -1. + for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) { + auto quit = std::make_unique(std::make_unique(0, DataBuffer::kDeBFlagQuit)); + RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(quit))); + } + return Status::OK(); } @@ -242,26 +248,27 @@ Status MapOp::WorkerEntry(int32_t worker_id) { // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself // rather than use the base-class defaults. while (true) { - // handle the pause logic. Pause is triggered when an buffer id of -1 with no special flag and no row is received - if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) { - // when worker receives the signal from master thread, it increments a atomic int - // the last guy who increments the counter, wakes up master thread - if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set(); - // this will block the worker until master thread gives it a new work + // Handle special logic where buffer carries a ctrl flag. 
+ if (in_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) { + if (in_buffer->wait()) { + // When worker receives the signal from master thread, it increments an atomic int + // The last worker to increment the counter wakes up the master thread + if (++num_workers_paused_ == num_workers_) { + wait_for_workers_post_.Set(); + } + // This will block the worker until the master thread gives it new work + } else if (in_buffer->eoe()) { + // Calling base class EoeReceived to forward eoe buffer. + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + } else if (in_buffer->eof()) { + // Calling base class EofReceived to forward eof buffer. + RETURN_IF_NOT_OK(EofReceived(worker_id)); + } else if (in_buffer->quit()) { + break; + } RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); continue; - } else if (in_buffer->eoe()) { - // Calling base class EoeReceived to forward eoe buffer. - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - // Fetch next data buffer and map job list - RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); - continue; - } else if (in_buffer->eof()) { - // Calling base class EofReceived to forward eof buffer. - RETURN_IF_NOT_OK(EofReceived(worker_id)); - break; } - CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer."); std::unique_ptr new_tensor_table(std::make_unique()); // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. @@ -299,9 +306,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl // Variable to keep the result after executing the job. std::vector result_table; - // Executing the list of jobs + // Executing the list of jobs. for (size_t i = 0; i < job_list.size(); i++) { - // Execute MapJob. + // Execute MapWorkerJob. RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table)); // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list. if (i + 1 < job_list.size()) { @@ -311,8 +318,7 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl // Sanity check a row in result_table if (!result_table.empty() && out_columns_.size() != result_table[0].size()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Result of a tensorOp doesn't match output column names"); + RETURN_STATUS_UNEXPECTED("Result of a tensorOp doesn't match output column names"); } // Merging the data processed by job (result_table) with the data that are not used. @@ -386,7 +392,7 @@ Status MapOp::InitPrivateVariable(std::unordered_map *col_ // columns from child are correct RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); - // initialize keep_input_columns, true means to keep the column. + // Initialize keep_input_columns, true means to keep the column. keep_input_columns_.resize(col_name_id_map->size(), true); for (const auto &col_name : in_columns_) { int32_t missed = (*col_name_id_map)[col_name]; @@ -449,18 +455,18 @@ Status MapOp::Accept(NodePass *p, bool *modified) { return p->RunOnNode(shared_from_base(), modified); } -Status MapOp::PauseFromMaster() { +Status MapOp::WaitForWorkers() { // reset num_paused workers to 0 num_workers_paused_ = 0; for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) { // a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause. 
RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add( - std::make_unique(std::make_unique(-1, DataBuffer::kDeBFlagNone)))); + std::make_unique(std::make_unique(0, DataBuffer::kDeBFlagWait)))); } // wait until all workers are done processing their work in local_queue_ - RETURN_IF_NOT_OK(master_pause_wp_.Wait()); + RETURN_IF_NOT_OK(wait_for_workers_post_.Wait()); // clear the WaitPost for the next Wait() - master_pause_wp_.Clear(); + wait_for_workers_post_.Clear(); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h index a66ee55fa9f..de99a2587ea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h @@ -228,10 +228,10 @@ class MapOp : public ParallelOp { // Indices of the columns to process. std::vector to_process_indices_; - // wait post used to perform the pausing logic in MapOp - WaitPost master_pause_wp_; + // Wait post used to perform the pausing logic in MapOp + WaitPost wait_for_workers_post_; - // count number of workers that have signaled master + // Count number of workers that have signaled master std::atomic_int num_workers_paused_; // Private function for worker/thread to loop continuously. It comprises the main @@ -272,7 +272,7 @@ class MapOp : public ParallelOp { // Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker // who does the increment wakes up the master. // @return - Status - Status PauseFromMaster() override; + Status WaitForWorkers() override; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h index e54516291b3..6d604d00438 100644 --- a/mindspore/ccsrc/minddata/dataset/util/semaphore.h +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h @@ -34,7 +34,7 @@ class Semaphore { /// \brief Decrement the internal counter. Will be blocked if the value is 0. /// \return Error code. Can get interrupt. Status P(); - /// \brief Increment the internal counter. Wakeup on of the waiters if any. + /// \brief Increment the internal counter. Wake up one of the waiters if any. void V(); /// \brief Peek the internal value /// \return The internal value diff --git a/mindspore/dataset/callback/ds_callback.py b/mindspore/dataset/callback/ds_callback.py index e4b1b0140fd..52ffdaffbcb 100644 --- a/mindspore/dataset/callback/ds_callback.py +++ b/mindspore/dataset/callback/ds_callback.py @@ -18,6 +18,7 @@ Python callback class import threading from mindspore._c_dataengine import PyDSCallback from mindspore.train.callback import Callback +import mindspore.dataset as ds from .validators import check_callback @@ -170,7 +171,6 @@ class WaitedDSCallback(Callback, DSCallback): """ self.epoch_run_context = run_context self.epoch_event.set() - self.epoch_event.clear() def ds_epoch_begin(self, ds_run_context): """ @@ -180,10 +180,12 @@ class WaitedDSCallback(Callback, DSCallback): ds_run_context: Include some information of the pipeline. 
""" if ds_run_context.cur_epoch_num > 1: - if self.epoch_run_context is None: - self.epoch_event.wait() + success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) + self.epoch_event.clear() + if not success: + raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") + # by the time this thread wakes up, self.epoch_run_context is already available self.sync_epoch_begin(self.epoch_run_context, ds_run_context) - self.epoch_run_context = None def step_end(self, run_context): """ @@ -194,7 +196,6 @@ class WaitedDSCallback(Callback, DSCallback): """ self.step_run_context = run_context self.step_event.set() - self.step_event.clear() def ds_step_begin(self, ds_run_context): """ @@ -204,10 +205,12 @@ class WaitedDSCallback(Callback, DSCallback): ds_run_context: Include some information of the pipeline. """ if ds_run_context.cur_step_num > self.step_size: - if self.step_run_context is None: - self.step_event.wait() + success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) + self.step_event.clear() + if not success: + raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") + # by the time this thread wakes up, self.step_run_context is already available self.sync_step_begin(self.step_run_context, ds_run_context) - self.step_run_context = None def create_runtime_obj(self): """ diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index c863186d97b..7ebcadd302b 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -157,6 +157,38 @@ def get_monitor_sampling_interval(): return _config.get_monitor_sampling_interval() +def set_callback_timeout(timeout): + """ + Set the default timeout (in seconds) for DSWaitedCallback. + In case of a deadlock, the wait function will exit after the timeout period. + + Args: + timeout (int): timeout in seconds to be used to end the wait in DSWaitedCallback in case of a deadlock. + + Raises: + ValueError: If timeout is invalid (<= 0 or > INT32_MAX). + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the new timeout value. + >>> ds.config.set_callback_timeout(100) + """ + if timeout <= 0 or timeout > INT32_MAX: + raise ValueError("timeout given is not within the required range.") + _config.set_callback_timeout(timeout) + + +def get_callback_timeout(): + """ + Get the default timeout for DSWaitedCallback. + In case of a deadlock, the wait function will exit after the timeout period. + + Returns: + Int, the duration in seconds + """ + return _config.get_callback_timeout() + + def __str__(): """ String representation of the configurations. 
diff --git a/tests/ut/cpp/dataset/callback_test.cc b/tests/ut/cpp/dataset/callback_test.cc index 39671b38047..9509d081c46 100644 --- a/tests/ut/cpp/dataset/callback_test.cc +++ b/tests/ut/cpp/dataset/callback_test.cc @@ -57,7 +57,7 @@ class TestCallback : public DSCallback { begin_(true), epoch_begin_(true), step_begin_(true), - end_(true), + end_(false), epoch_end_(true), step_end_(true) { all_names_.reserve(32); @@ -145,7 +145,6 @@ TEST_F(MindDataTestCallback, TestBasicCallback) { Status rc; std::shared_ptr tst_cb = std::make_shared(64); std::shared_ptr cb1 = tst_cb; - tst_cb->end_ = false; // don't do the end for now due to a timing issue // config leaf_op, use random_data to avoid I/O std::unique_ptr schema = std::make_unique(); TensorShape shape({}); // empty shape is a 1-value scalar Tensor @@ -193,7 +192,6 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) { Status rc; std::shared_ptr tst_cb = std::make_shared(4); std::shared_ptr cb1 = tst_cb; - tst_cb->end_ = false; // don't do the end for now due to a timing issue // config leaf_op, use random_data to avoid I/O std::unique_ptr schema = std::make_unique(); TensorShape shape({}); // empty shape is a 1-value scalar Tensor @@ -247,7 +245,6 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { Status rc; std::shared_ptr tst_cb = std::make_shared(4); std::shared_ptr cb1 = tst_cb; - tst_cb->end_ = false; // turn off the epochs tst_cb->epoch_begin_ = false; tst_cb->epoch_end_ = false; diff --git a/tests/ut/python/dataset/test_callbacks.py b/tests/ut/python/dataset/test_callbacks.py index 2c56dd1bcbb..d75ab4fbe06 100644 --- a/tests/ut/python/dataset/test_callbacks.py +++ b/tests/ut/python/dataset/test_callbacks.py @@ -29,7 +29,7 @@ import mindspore.nn as nn context.set_context(mode=context.GRAPH_MODE, device_target="CPU") -class MyDSCallback(DSCallback): +class BaseCallback(DSCallback): def __init__(self, step_size=1, events=None, cb_id=0): super().__init__(step_size) self.events = events @@ -49,25 +49,36 @@ class MyDSCallback(DSCallback): else: self.events.append((event, [self.cb_id])) + +class Begin(BaseCallback): def ds_begin(self, ds_run_context): self.append("begin", ds_run_context) - def ds_end(self, ds_run_context): - self.append("end", ds_run_context) +class EpochBegin(BaseCallback): def ds_epoch_begin(self, ds_run_context): self.append("epoch_begin", ds_run_context) + +class EpochEnd(BaseCallback): def ds_epoch_end(self, ds_run_context): self.append("epoch_end", ds_run_context) + +class StepBegin(BaseCallback): def ds_step_begin(self, ds_run_context): self.append("step_begin", ds_run_context) + +class StepEnd(BaseCallback): def ds_step_end(self, ds_run_context): self.append("step_end", ds_run_context) +class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd): + pass + + def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1): events = [] cb_id = list(range(map_num)) @@ -98,7 +109,12 @@ def build_test_case_1cb(epochs, steps, step_size=1, repeat=1): data = data.map(operations=(lambda x: x), callbacks=my_cb) if repeat != 1: - data = data.repeat(repeat) + if repeat % 2 == 0 and repeat != 2: + data = data.repeat(2) + data = data.map(operations=(lambda x: x)) + data = data.repeat(repeat // 2) + else: + data = data.repeat(repeat) itr = data.create_tuple_iterator(num_epochs=epochs) for _ in range(epochs): for _ in itr: @@ -201,11 +217,10 @@ def test_callbacks_all_2cbs(): build_test_case_2cbs(4, 4) -def test_callbacks_2maps(): +def skip_test_callbacks_2maps(): logger.info("test_callbacks_2maps") - + # 
This test case is skipped because in rare cases (25 out of 1000) it might fail build_test_case_2maps(5, 10) - build_test_case_2maps(6, 9) @@ -243,8 +258,8 @@ class Net(nn.Cell): return x -def test_train_non_sink(): - logger.info("test_train_non_sink") +def test_callbacks_non_sink(): + logger.info("test_callbacks_non_sink") events = [] my_cb1 = MyWaitedCallback(events, 1) @@ -267,8 +282,8 @@ def test_train_non_sink(): assert events == expected_synced_events -def test_train_batch_size2(): - logger.info("test_train_batch_size2") +def test_callbacks_non_sink_batch_size2(): + logger.info("test_callbacks_non_sink_batch_size2") events = [] my_cb1 = MyWaitedCallback(events, 2) @@ -291,6 +306,27 @@ def test_train_batch_size2(): assert events == expected_synced_events +def test_callbacks_non_sink_mismatch_size(): + logger.info("test_callbacks_non_sink_mismatch_size") + default_timeout = ds.config.get_callback_timeout() + ds.config.set_callback_timeout(1) + + events = [] + my_cb1 = MyWaitedCallback(events, 2) + my_cb2 = MyMSCallback(events) + arr = [1, 2, 3, 4] + data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=my_cb1) + data = data.batch(3) + net = Net() + model = Model(net) + with pytest.raises(Exception) as err: + model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) + assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value) + + ds.config.set_callback_timeout(default_timeout) + + def test_callbacks_validations(): logger.info("test_callbacks_validations") @@ -318,7 +354,7 @@ def test_callbacks_validations(): assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value) -def test_callback_sink_simulation(): +def test_callbacks_sink_simulation(): logger.info("test_callback_sink_simulation") events = [] @@ -353,13 +389,72 @@ def test_callbacks_repeat(): build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3) build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3) + build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2) + build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4) + build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8) + build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16) + + +def test_callbacks_exceptions(): + logger.info("test_callbacks_exceptions") + + class BadCB(DSCallback): + def ds_begin(self, ds_run_context): + raise RuntimeError("Bad begin") + + with pytest.raises(Exception) as err: + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + data = data.map(operations=(lambda x: x), callbacks=BadCB()) + for _ in data: + pass + assert "RuntimeError: Bad begin" in str(err.value) + + +def test_callbacks_one_cb(): + logger.info("test_callbacks_one_cb") + + data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) + events1 = [] + events2 = [] + events3 = [] + my_begin = Begin(events=events1, cb_id=1) + my_epoch_begin = EpochBegin(events=events2, cb_id=2) + my_epoch_end = EpochEnd(events=events3, cb_id=3) + my_step_begin = StepBegin(events=events3, cb_id=3) + my_step_end = StepEnd(events=events2, cb_id=2) + + data = data.map(operations=(lambda x: x), callbacks=my_begin) + data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end]) + data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin]) + + itr = data.create_tuple_iterator() + for _ in range(2): + for _ in itr: + pass + expected_events1 = [('begin_0_0_0', [1])] + 
expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]), + ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]), + ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]), + ('step_end_2_4_8', [2])] + expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]), + ('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]), + ('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]), + ('epoch_end_2_4_8', [3])] + assert events1 == expected_events1 + assert events2 == expected_events2 + assert events3 == expected_events3 + if __name__ == '__main__': - test_callbacks_all_methods() + skip_test_callbacks_2maps() test_callbacks_all_2cbs() - test_callbacks_2maps() + test_callbacks_all_methods() + test_callbacks_exceptions() + test_callbacks_repeat() + test_callbacks_sink_simulation() test_callbacks_validations() test_callbacks_var_step_size() - test_train_batch_size2() - test_callback_sink_simulation() - test_callbacks_repeat() + test_callbacks_non_sink_batch_size2() + test_callbacks_non_sink() + test_callbacks_one_cb() + test_callbacks_non_sink_mismatch_size()
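
Reviewer notes (illustrative sketches, not part of the patch):

The new Python surface can be exercised end to end. The sketch below assumes the public import paths mindspore.dataset.config and mindspore.dataset.callback.WaitedDSCallback; the dataset values, the step_size argument, and the prints are illustrative only.

```python
# Usage sketch for the new callback timeout (assumptions noted above).
import mindspore.dataset as ds
from mindspore.dataset.callback import WaitedDSCallback

ds.config.set_callback_timeout(30)       # wait at most 30 s before a RuntimeError is raised
print(ds.config.get_callback_timeout())  # -> 30

class SyncCallback(WaitedDSCallback):
    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # Runs only after the training side has signalled epoch_end,
        # or after the pipeline raises RuntimeError once the timeout expires.
        print("pipeline epoch", ds_run_context.cur_epoch_num, "synced with training")

data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=SyncCallback(step_size=1))
# The same callback instance is also passed to Model.train(...) so both sides share the
# events, as test_callbacks_non_sink() above does.
```

The C++ handshake added to MapOp::WaitForWorkers() and MapOp::WorkerEntry() can be summarised with a small standalone analogy. This is an illustration only, not MindSpore code: the strings "WAIT" and "QUIT" stand in for the kDeBFlagWait and kDeBFlagQuit buffers, and a lock emulates the atomic counter.

```python
# Standalone analogy of the pause/quit protocol between the master and its workers.
import queue
import threading

NUM_WORKERS = 4
local_queues = [queue.Queue() for _ in range(NUM_WORKERS)]  # one work queue per worker
num_workers_paused = 0
counter_lock = threading.Lock()
all_paused = threading.Event()                              # role of wait_for_workers_post_

def worker(q):
    global num_workers_paused
    while True:
        token = q.get()                                     # FetchNextWork
        if token == "WAIT":                                 # kDeBFlagWait buffer
            with counter_lock:
                num_workers_paused += 1
                if num_workers_paused == NUM_WORKERS:
                    all_paused.set()                        # last worker wakes the master
            continue                                        # then blocks on q.get() until new work arrives
        if token == "QUIT":                                 # kDeBFlagQuit buffer
            break
        # normal data buffers would be processed here

def wait_for_workers():                                     # master side, runs before each callback
    global num_workers_paused
    num_workers_paused = 0
    for q in local_queues:
        q.put("WAIT")
    all_paused.wait()                                       # sleep until every worker has paused
    all_paused.clear()                                      # ready for the next Wait()

threads = [threading.Thread(target=worker, args=(q,)) for q in local_queues]
for t in threads:
    t.start()
wait_for_workers()                                          # safe point: callbacks can run here
for q in local_queues:                                      # mirrors the quit buffers sent after EOF
    q.put("QUIT")
for t in threads:
    t.join()
print("all workers paused once, then quit cleanly")
```

In the patch the same roles are played by local_queues_, num_workers_paused_ and wait_for_workers_post_, and the quit buffers are the ones enqueued after the EOF buffer at the end of MapOp::operator()().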