forked from mindspore-Ecosystem/mindspore
!18525 In pynative mode, after train complete, occur core dump.
Merge pull request !18525 from zhangzhaoju/master_tensor_print
This commit is contained in:
commit
2d8ba753fe
|
@ -16,26 +16,28 @@
|
|||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
|
||||
namespace mindspore {
|
||||
extern std::set<void **> acl_handle_set;
|
||||
extern std::map<void **, std::thread *> acl_handle_map;
|
||||
namespace dataset {
|
||||
|
||||
void TdtHandle::AddHandle(acltdtChannelHandle **handle) {
|
||||
void TdtHandle::AddHandle(acltdtChannelHandle **handle, std::thread *use_thread) {
|
||||
if (*handle != nullptr) {
|
||||
acl_handle_set.insert(reinterpret_cast<void **>(handle));
|
||||
acl_handle_map.insert({reinterpret_cast<void **>(handle), use_thread});
|
||||
}
|
||||
}
|
||||
|
||||
void TdtHandle::DelHandle(acltdtChannelHandle **handle) {
|
||||
void **void_handle = reinterpret_cast<void **>(handle);
|
||||
acl_handle_set.erase(void_handle);
|
||||
acl_handle_map.erase(void_handle);
|
||||
}
|
||||
|
||||
bool TdtHandle::DestroyHandle() {
|
||||
bool destroy_all = true;
|
||||
for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) {
|
||||
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it);
|
||||
for (auto &item : acl_handle_map) {
|
||||
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(item.first);
|
||||
if (*handle != nullptr) {
|
||||
acltdtStopChannel(*handle);
|
||||
if (item.second != nullptr && item.second->joinable()) {
|
||||
item.second->join();
|
||||
}
|
||||
if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) {
|
||||
destroy_all = false;
|
||||
} else {
|
||||
|
|
|
@ -17,14 +17,15 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <thread>
|
||||
#include "acl/acl_tdt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TdtHandle {
|
||||
public:
|
||||
static void AddHandle(acltdtChannelHandle **handle);
|
||||
static void AddHandle(acltdtChannelHandle **handle, std::thread *use_thread);
|
||||
|
||||
static bool DestroyHandle();
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
|
|||
if (acl_handle_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
|
||||
}
|
||||
TdtHandle::AddHandle(&acl_handle_);
|
||||
TdtHandle::AddHandle(&acl_handle_, nullptr);
|
||||
}
|
||||
|
||||
TdtPlugin::~TdtPlugin() {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::set<void **> acl_handle_set = std::set<void **>();
|
||||
std::map<void **, std::thread *> acl_handle_map;
|
||||
// set default log level to WARNING for all sub modules
|
||||
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
#include "utils/overload.h"
|
||||
#include "./securec.h"
|
||||
|
@ -42,7 +43,7 @@ static constexpr size_t GetRelPathPos() noexcept {
|
|||
}
|
||||
|
||||
namespace mindspore {
|
||||
extern std::set<void **> acl_handle_set __attribute__((visibility("default")));
|
||||
extern std::map<void **, std::thread *> acl_handle_map __attribute__((visibility("default")));
|
||||
#define FILE_NAME \
|
||||
(sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
|
||||
: static_cast<const char *>(__FILE__))
|
||||
|
|
|
@ -116,14 +116,13 @@ void MsContext::CreateTensorPrintThread(PrintThreadCrt ctr) {
|
|||
std::string kReceivePrefix = "TF_RECEIVE_";
|
||||
std::string channel_name = "_npu_log";
|
||||
acl_handle_ = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
|
||||
if (acl_handle_ != nullptr) {
|
||||
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
TdtHandle::AddHandle(&acl_handle_);
|
||||
} else {
|
||||
if (acl_handle_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
std::string print_file_path = get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
|
||||
acl_tdt_print_ = ctr(print_file_path, acl_handle_);
|
||||
TdtHandle::AddHandle(&acl_handle_, &acl_tdt_print_);
|
||||
}
|
||||
|
||||
static void JoinAclPrintThread(std::thread *thread) {
|
||||
|
|
Loading…
Reference in New Issue