diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc index 5ab0011561e..234d920fb7c 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.cc @@ -26,7 +26,7 @@ void CallbackManager::AddCallbacks(std::vector> call callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end()); } -Status CallbackManager::Init(std::shared_ptr op) { +Status CallbackManager::Init(DatasetOp *op) { RETURN_UNEXPECTED_IF_NULL(op); op_ = op; // turn the flag on if callback is set @@ -42,6 +42,7 @@ Status CallbackManager::Init(std::shared_ptr op) { Status CallbackManager::Begin(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { @@ -61,6 +62,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) { Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { @@ -80,6 +82,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { Status CallbackManager::StepBegin(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { @@ -100,6 +103,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) { Status CallbackManager::End(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { @@ -119,6 +123,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) { Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { @@ -138,6 +143,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { Status CallbackManager::StepEnd(const CallbackParam &cb_param) { RETURN_OK_IF_TRUE(!enabled_); + RETURN_UNEXPECTED_IF_NULL(op_); std::vector callback_inds; // go through all callback functions to see if each function is needed for (size_t ind = 0; ind < callbacks_.size(); ind++) { diff --git a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h index c0e65126a0c..0fedb97115c 100644 --- a/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h +++ b/mindspore/ccsrc/minddata/dataset/callback/callback_manager.h @@ -44,7 +44,7 @@ class CallbackManager { /// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true /// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads /// \return Status - Status Init(std::shared_ptr op); + Status Init(DatasetOp *op); /// \brief callback function called at the start of the first row /// \return Status @@ -70,11 +70,9 @@ class CallbackManager { /// \return Status Status StepEnd(const CallbackParam &); - bool HasCallback() { return !callbacks_.empty(); } - private: - bool enabled_; // flag to enable callback, if false, all functions would return immediately - std::shared_ptr op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager + bool enabled_; // flag to enable callback, if false, all functions would return immediately + DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership std::vector> callbacks_; // list of callbacks the DatasetOp needs to call }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc index eff69d76098..08f6bb93855 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc @@ -164,9 +164,7 @@ Status MapOp::operator()() { // Create and register the local queues. local_queues_.Init(num_workers_, oc_queue_size_); // init callback - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); - } + RETURN_IF_NOT_OK(callback_manager_.Init(this)); Status rc = local_queues_.Register(tree_->AllTasks()); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); if (rc.IsError()) { @@ -181,26 +179,23 @@ Status MapOp::operator()() { RETURN_IF_NOT_OK(rc); // num_buffers received, including eoe, num_epoch, num_step of current epoch int64_t num_buf = 0, ep_step = 0, total_step = 0; - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); - } + + RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step))); std::unique_ptr buff; RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); while (!buff->eof()) { if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) { - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); - } + RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); } while (!buff->eoe()) { ep_step++; total_step++; // Create an empty map worker job to be populated by a databuffer and map jobs - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); - } + + RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + std::unique_ptr worker_job = std::make_unique(std::move(buff)); // Populate map worker job for a worker to execute @@ -208,18 +203,16 @@ Status MapOp::operator()() { // Push map worker job to the corresponding worker's queue RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); - } + + RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); } // check whether this is the end of a real epoch (not all eoe signals end of epoch) if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { - if (callback_manager_.HasCallback()) { - RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); - } + RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); + ep_step = 0; } // Propagate the eoe buffer to worker