From c59349e096b8c2ce9b99a6a5bdffa2755a7b620e Mon Sep 17 00:00:00 2001 From: ms_yan Date: Fri, 19 Mar 2021 22:17:22 +0800 Subject: [PATCH] avoid double free for tdt channel --- .../ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc | 11 ++++++++--- .../ccsrc/minddata/dataset/engine/tdt/tdt_handle.h | 2 ++ .../ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc | 7 +++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc index 21f250073d7..0860e97584d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc @@ -26,14 +26,19 @@ void TdtHandle::AddHandle(acltdtChannelHandle *handle) { } bool TdtHandle::DestroyHandle() { - for (auto handle : acl_handle) { + bool destroy_all = true; + for (auto &handle : acl_handle) { if (handle != nullptr) { if (acltdtDestroyChannel(handle) != ACL_SUCCESS) { - return false; + destroy_all = false; + } else { + handle = nullptr; } } } - return true; + return destroy_all; } + +std::vector TdtHandle::GetHandle() { return acl_handle; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h index 3c0cfdf839c..5cabf8b0ec2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h @@ -28,6 +28,8 @@ class TdtHandle { static bool DestroyHandle(); + static std::vector GetHandle(); + private: TdtHandle() {} diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 8db1082b7ee..940a185de2b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -33,8 +33,11 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) { } TdtPlugin::~TdtPlugin() { - if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; + std::vector total_handle = TdtHandle::GetHandle(); + if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) { + if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { + MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; + } } }