!18525 In PyNative mode, a core dump occurs after training completes.

Merge pull request !18525 from zhangzhaoju/master_tensor_print
This commit is contained in:
i-robot 2021-06-21 19:08:18 +08:00 committed by Gitee
commit 2d8ba753fe
7 changed files with 21 additions and 17 deletions

View File

@ -16,26 +16,28 @@
#include "minddata/dataset/engine/tdt/tdt_handle.h"
namespace mindspore {
extern std::set<void **> acl_handle_set;
extern std::map<void **, std::thread *> acl_handle_map;
namespace dataset {
// Records the address of a channel handle in the global registries when the channel it
// points to is non-null.
// NOTE(review): this text is a rendered diff whose +/- markers were stripped. The one-argument
// signature and the acl_handle_set insert look like the removed lines, and the two-argument
// signature and the acl_handle_map insert their replacements - confirm against the repository.
// NOTE(review): handle is dereferenced without a null check on the pointer itself; the callers
// visible here pass &acl_handle_, so this presumably relies on callers never passing nullptr.
void TdtHandle::AddHandle(acltdtChannelHandle **handle) {
void TdtHandle::AddHandle(acltdtChannelHandle **handle, std::thread *use_thread) {
// A channel that was never created (null) is not registered.
if (*handle != nullptr) {
acl_handle_set.insert(reinterpret_cast<void **>(handle));
// use_thread may be nullptr (TdtPlugin passes nullptr); DestroyHandle checks for that before
// joining. std::map::insert keeps an existing entry, so registering the same handle address
// again does not replace the thread pointer stored earlier.
acl_handle_map.insert({reinterpret_cast<void **>(handle), use_thread});
}
}
// Removes the registration of a channel handle address from the global registries.
// NOTE(review): this text is a rendered diff whose +/- markers were stripped. The
// acl_handle_set erase looks like the removed line and the acl_handle_map erase its
// replacement - confirm against the repository.
void TdtHandle::DelHandle(acltdtChannelHandle **handle) {
// The registries are keyed by void **, so convert the typed handle address first.
void **void_handle = reinterpret_cast<void **>(handle);
acl_handle_set.erase(void_handle);
acl_handle_map.erase(void_handle);
}
bool TdtHandle::DestroyHandle() {
bool destroy_all = true;
for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) {
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it);
for (auto &item : acl_handle_map) {
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(item.first);
if (*handle != nullptr) {
acltdtStopChannel(*handle);
if (item.second != nullptr && item.second->joinable()) {
item.second->join();
}
if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) {
destroy_all = false;
} else {

View File

@ -17,14 +17,15 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
#include <iostream>
#include <set>
#include <map>
#include <thread>
#include "acl/acl_tdt.h"
namespace mindspore {
namespace dataset {
class TdtHandle {
public:
static void AddHandle(acltdtChannelHandle **handle);
static void AddHandle(acltdtChannelHandle **handle, std::thread *use_thread);
static bool DestroyHandle();

View File

@ -29,7 +29,7 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
if (acl_handle_ == nullptr) {
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
}
TdtHandle::AddHandle(&acl_handle_);
TdtHandle::AddHandle(&acl_handle_, nullptr);
}
TdtPlugin::~TdtPlugin() {

View File

@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include <vector>
#include <thread>
#include "acl/acl_tdt.h"
#include "minddata/dataset/engine/tdt/tdt_handle.h"

View File

@ -17,7 +17,7 @@
#include "utils/log_adapter.h"
namespace mindspore {
std::set<void **> acl_handle_set = std::set<void **>();
std::map<void **, std::thread *> acl_handle_map;
// set default log level to WARNING for all sub modules
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
} // namespace mindspore

View File

@ -22,7 +22,8 @@
#include <string>
#include <sstream>
#include <memory>
#include <set>
#include <map>
#include <thread>
#include <functional>
#include "utils/overload.h"
#include "./securec.h"
@ -42,7 +43,7 @@ static constexpr size_t GetRelPathPos() noexcept {
}
namespace mindspore {
extern std::set<void **> acl_handle_set __attribute__((visibility("default")));
extern std::map<void **, std::thread *> acl_handle_map __attribute__((visibility("default")));
#define FILE_NAME \
(sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
: static_cast<const char *>(__FILE__))

View File

@ -116,14 +116,13 @@ void MsContext::CreateTensorPrintThread(PrintThreadCrt ctr) {
std::string kReceivePrefix = "TF_RECEIVE_";
std::string channel_name = "_npu_log";
acl_handle_ = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
if (acl_handle_ != nullptr) {
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
TdtHandle::AddHandle(&acl_handle_);
} else {
if (acl_handle_ == nullptr) {
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
}
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
std::string print_file_path = get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
acl_tdt_print_ = ctr(print_file_path, acl_handle_);
TdtHandle::AddHandle(&acl_handle_, &acl_tdt_print_);
}
static void JoinAclPrintThread(std::thread *thread) {