forked from mindspore-Ecosystem/mindspore
!6223 Fix Memory leak in callback_manager
Merge pull request !6223 from ZiruiWu/fix_memory_leak_in_callback
This commit is contained in:
commit
b4d527e198
|
@ -26,7 +26,7 @@ void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> call
|
|||
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
|
||||
}
|
||||
|
||||
Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
|
||||
Status CallbackManager::Init(DatasetOp *op) {
|
||||
RETURN_UNEXPECTED_IF_NULL(op);
|
||||
op_ = op;
|
||||
// turn the flag on if callback is set
|
||||
|
@ -42,6 +42,7 @@ Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
|
|||
|
||||
Status CallbackManager::Begin(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
@ -61,6 +62,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) {
|
|||
|
||||
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
@ -80,6 +82,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
|
|||
|
||||
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
@ -100,6 +103,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
|
|||
|
||||
Status CallbackManager::End(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
@ -119,6 +123,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) {
|
|||
|
||||
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
@ -138,6 +143,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
|
|||
|
||||
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
|
||||
RETURN_OK_IF_TRUE(!enabled_);
|
||||
RETURN_UNEXPECTED_IF_NULL(op_);
|
||||
std::vector<size_t> callback_inds;
|
||||
// go through all callback functions to see if each function is needed
|
||||
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||
|
|
|
@ -44,7 +44,7 @@ class CallbackManager {
|
|||
/// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true
|
||||
/// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads
|
||||
/// \return Status
|
||||
Status Init(std::shared_ptr<DatasetOp> op);
|
||||
Status Init(DatasetOp *op);
|
||||
|
||||
/// \brief callback function called at the start of the first row
|
||||
/// \return Status
|
||||
|
@ -70,11 +70,9 @@ class CallbackManager {
|
|||
/// \return Status
|
||||
Status StepEnd(const CallbackParam &);
|
||||
|
||||
bool HasCallback() { return !callbacks_.empty(); }
|
||||
|
||||
private:
|
||||
bool enabled_; // flag to enable callback, if false, all functions would return immediately
|
||||
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager
|
||||
bool enabled_; // flag to enable callback, if false, all functions would return immediately
|
||||
DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -164,9 +164,7 @@ Status MapOp::operator()() {
|
|||
// Create and register the local queues.
|
||||
local_queues_.Init(num_workers_, oc_queue_size_);
|
||||
// init callback
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
|
||||
}
|
||||
RETURN_IF_NOT_OK(callback_manager_.Init(this));
|
||||
Status rc = local_queues_.Register(tree_->AllTasks());
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
if (rc.IsError()) {
|
||||
|
@ -181,26 +179,23 @@ Status MapOp::operator()() {
|
|||
RETURN_IF_NOT_OK(rc);
|
||||
// num_buffers received, including eoe, num_epoch, num_step of current epoch
|
||||
int64_t num_buf = 0, ep_step = 0, total_step = 0;
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
|
||||
|
||||
std::unique_ptr<DataBuffer> buff;
|
||||
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
|
||||
while (!buff->eof()) {
|
||||
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
while (!buff->eoe()) {
|
||||
ep_step++;
|
||||
total_step++;
|
||||
// Create an empty map worker job to be populated by a databuffer and map jobs
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
|
||||
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
|
||||
|
||||
// Populate map worker job for a worker to execute
|
||||
|
@ -208,18 +203,16 @@ Status MapOp::operator()() {
|
|||
|
||||
// Push map worker job to the corresponding worker's queue
|
||||
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
|
||||
}
|
||||
|
||||
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
|
||||
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
|
||||
if (callback_manager_.HasCallback()) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
}
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
|
||||
ep_step = 0;
|
||||
}
|
||||
// Propagate the eoe buffer to worker
|
||||
|
|
Loading…
Reference in New Issue