!39727 move tdt print thread from core to backend
Merge pull request !39727 from zhoufeng/delete-macro
This commit is contained in:
commit
6da7a9abd9
|
@ -14,19 +14,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
|
||||
#include <atomic>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include "ir/tensor.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "mindspore/core/utils/file_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "proto/print.pb.h"
|
||||
#include "minddata/dataset/engine/device_queue_impl/tdt/tdt_handle.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
static std::map<aclDataType, TypeId> print_acl_data_type_map = {
|
||||
namespace mindspore::device::ascend {
|
||||
namespace {
|
||||
acltdtChannelHandle *g_acl_handle = nullptr;
|
||||
std::thread g_acl_tdt_print = {};
|
||||
|
||||
const std::map<aclDataType, TypeId> kPrintAclDataTypeMap = {
|
||||
{ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8},
|
||||
{ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16},
|
||||
{ACL_INT32, TypeId::kNumberTypeInt32}, {ACL_UINT32, TypeId::kNumberTypeUInt32},
|
||||
|
@ -34,19 +43,20 @@ static std::map<aclDataType, TypeId> print_acl_data_type_map = {
|
|||
{ACL_FLOAT16, TypeId::kNumberTypeFloat16}, {ACL_FLOAT, TypeId::kNumberTypeFloat32},
|
||||
{ACL_DOUBLE, TypeId::kNumberTypeFloat64}, {ACL_BOOL, TypeId::kNumberTypeBool}};
|
||||
|
||||
static std::map<aclDataType, size_t> acl_data_type_size_map = {
|
||||
const std::map<aclDataType, size_t> kAclDataTypeSizeMap = {
|
||||
{ACL_INT8, sizeof(int8_t)}, {ACL_UINT8, sizeof(uint8_t)}, {ACL_INT16, sizeof(int16_t)},
|
||||
{ACL_UINT16, sizeof(uint16_t)}, {ACL_INT32, sizeof(int32_t)}, {ACL_UINT32, sizeof(uint32_t)},
|
||||
{ACL_INT64, sizeof(int64_t)}, {ACL_UINT64, sizeof(uint64_t)}, {ACL_FLOAT16, sizeof(float) / 2},
|
||||
{ACL_FLOAT, sizeof(float)}, {ACL_DOUBLE, sizeof(double)}, {ACL_BOOL, sizeof(bool)}};
|
||||
|
||||
const std::map<aclDataType, std::string> kPrintTensorParseMap = {
|
||||
{ACL_INT8, "Int8"}, {ACL_UINT8, "UInt8"}, {ACL_INT16, "Int16"}, {ACL_UINT16, "UInt16"},
|
||||
{ACL_INT32, "Int32"}, {ACL_UINT32, "UInt32"}, {ACL_INT64, "Int64"}, {ACL_UINT64, "UInt64"},
|
||||
{ACL_FLOAT16, "Float16"}, {ACL_FLOAT, "Float32"}, {ACL_DOUBLE, "Float64"}, {ACL_BOOL, "Bool"}};
|
||||
|
||||
std::string GetParseType(const aclDataType &acl_data_type) {
|
||||
static const std::map<aclDataType, std::string> print_tensor_parse_map = {
|
||||
{ACL_INT8, "Int8"}, {ACL_UINT8, "UInt8"}, {ACL_INT16, "Int16"}, {ACL_UINT16, "UInt16"},
|
||||
{ACL_INT32, "Int32"}, {ACL_UINT32, "UInt32"}, {ACL_INT64, "Int64"}, {ACL_UINT64, "UInt64"},
|
||||
{ACL_FLOAT16, "Float16"}, {ACL_FLOAT, "Float32"}, {ACL_DOUBLE, "Float64"}, {ACL_BOOL, "Bool"}};
|
||||
auto type_iter = print_tensor_parse_map.find(acl_data_type);
|
||||
if (type_iter == print_tensor_parse_map.end()) {
|
||||
auto type_iter = kPrintTensorParseMap.find(acl_data_type);
|
||||
if (type_iter == kPrintTensorParseMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << acl_data_type;
|
||||
}
|
||||
return type_iter->second;
|
||||
|
@ -100,7 +110,7 @@ void PrintScalarToBoolString(const char *str_data_ptr, const aclDataType &acl_da
|
|||
void ConvertDataItem2Scalar(const void *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
auto type_iter = print_acl_data_type_map.find(acl_data_type);
|
||||
auto type_iter = kPrintAclDataTypeMap.find(acl_data_type);
|
||||
auto type_id = type_iter->second;
|
||||
if (type_id == TypeId::kNumberTypeBool) {
|
||||
PrintScalarToBoolString(reinterpret_cast<const char *>(str_data_ptr), acl_data_type, buf);
|
||||
|
@ -132,8 +142,8 @@ void ConvertDataItem2Scalar(const void *str_data_ptr, const aclDataType &acl_dat
|
|||
}
|
||||
|
||||
bool judgeLengthValid(const size_t str_len, const aclDataType &acl_data_type) {
|
||||
auto type_iter = acl_data_type_size_map.find(acl_data_type);
|
||||
if (type_iter == acl_data_type_size_map.end()) {
|
||||
auto type_iter = kAclDataTypeSizeMap.find(acl_data_type);
|
||||
if (type_iter == kAclDataTypeSizeMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
|
||||
}
|
||||
return str_len == type_iter->second;
|
||||
|
@ -163,14 +173,14 @@ bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) {
|
|||
acl_data = const_cast<char *>(reinterpret_cast<std::string *>(acl_data)->c_str());
|
||||
MS_EXCEPTION_IF_NULL(acl_data);
|
||||
|
||||
ShapeVector tensorShape;
|
||||
tensorShape.resize(dim_num);
|
||||
ShapeVector tensor_shape;
|
||||
tensor_shape.resize(dim_num);
|
||||
|
||||
if (acltdtGetDimsFromItem(item, tensorShape.data(), dim_num) != ACL_SUCCESS) {
|
||||
if (acltdtGetDimsFromItem(item, tensor_shape.data(), dim_num) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "ACL failed to get dim-size from acl channel data";
|
||||
}
|
||||
|
||||
if ((tensorShape.size() == 1 && tensorShape[0] == 0) || tensorShape.size() == 0) {
|
||||
if ((tensor_shape.size() == 1 && tensor_shape[0] == 0) || tensor_shape.size() == 0) {
|
||||
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
}
|
||||
|
@ -182,13 +192,13 @@ bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) {
|
|||
std::string data(reinterpret_cast<const char *>(acl_data), acl_data_size);
|
||||
buf << data << std::endl;
|
||||
} else {
|
||||
auto type_iter = print_acl_data_type_map.find(acl_data_type);
|
||||
if (type_iter == print_acl_data_type_map.end()) {
|
||||
auto type_iter = kPrintAclDataTypeMap.find(acl_data_type);
|
||||
if (type_iter == kPrintAclDataTypeMap.end()) {
|
||||
MS_LOG(ERROR) << "type of tensor need to print is not support " << GetParseType(acl_data_type);
|
||||
continue;
|
||||
}
|
||||
auto type_id = type_iter->second;
|
||||
mindspore::tensor::Tensor print_tensor(type_id, tensorShape);
|
||||
mindspore::tensor::Tensor print_tensor(type_id, tensor_shape);
|
||||
if (PrintTensorToString(acl_data, &print_tensor, acl_data_size)) {
|
||||
buf << print_tensor.ToStringNoLimit() << std::endl;
|
||||
}
|
||||
|
@ -227,14 +237,14 @@ bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_
|
|||
acl_data = const_cast<char *>(reinterpret_cast<std::string *>(acl_data)->c_str());
|
||||
MS_EXCEPTION_IF_NULL(acl_data);
|
||||
|
||||
ShapeVector tensorShape;
|
||||
tensorShape.resize(dim_num);
|
||||
ShapeVector tensor_shape;
|
||||
tensor_shape.resize(dim_num);
|
||||
|
||||
if (acltdtGetDimsFromItem(item, tensorShape.data(), dim_num) != ACL_SUCCESS) {
|
||||
if (acltdtGetDimsFromItem(item, tensor_shape.data(), dim_num) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "ACL failed to get dim-size from acl channel data";
|
||||
}
|
||||
|
||||
if ((tensorShape.size() == 1 && tensorShape[0] == 0) || tensorShape.size() == 0) {
|
||||
if ((tensor_shape.size() == 1 && tensor_shape[0] == 0) || tensor_shape.size() == 0) {
|
||||
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
|
||||
MS_LOG(ERROR) << "Print op receive data length is invalid.";
|
||||
ret_end_thread = true;
|
||||
|
@ -247,8 +257,8 @@ bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_
|
|||
} else {
|
||||
auto parse_type = GetParseType(acl_data_type);
|
||||
prntpb::TensorProto *tensor = value->mutable_tensor();
|
||||
if (tensorShape.size() > 1 || (tensorShape.size() == 1 && tensorShape[0] != 1)) {
|
||||
for (const auto &dim : tensorShape) {
|
||||
if (tensor_shape.size() > 1 || (tensor_shape.size() == 1 && tensor_shape[0] != 1)) {
|
||||
for (const auto &dim : tensor_shape) {
|
||||
tensor->add_dims(static_cast<::google::protobuf::int64>(dim));
|
||||
}
|
||||
}
|
||||
|
@ -346,6 +356,18 @@ void TensorPrintOut2File(const acltdtChannelHandle *acl_handle, const std::strin
|
|||
ChangeFileMode(print_file_path, S_IRUSR);
|
||||
}
|
||||
|
||||
void JoinAclPrintThread(std::thread *thread) {
|
||||
try {
|
||||
if (thread->joinable()) {
|
||||
MS_LOG(INFO) << "join acl tdt host receive process";
|
||||
thread->join();
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TensorPrint::operator()() {
|
||||
if (print_file_path_ == "") {
|
||||
TensorPrintStdOut(acl_handle_);
|
||||
|
@ -353,4 +375,51 @@ void TensorPrint::operator()() {
|
|||
TensorPrintOut2File(acl_handle_, print_file_path_);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
void CreateTensorPrintThread(const PrintThreadCrt &ctr) {
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
return;
|
||||
}
|
||||
uint32_t device_id = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
std::string kReceivePrefix = "TF_RECEIVE_";
|
||||
std::string channel_name = "_npu_log";
|
||||
g_acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
|
||||
if (g_acl_handle == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
std::string print_file_path = MsContext::GetInstance()->get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
|
||||
g_acl_tdt_print = ctr(print_file_path, g_acl_handle);
|
||||
dataset::TdtHandle::AddHandle(&g_acl_handle, &g_acl_tdt_print);
|
||||
}
|
||||
|
||||
void DestroyTensorPrintThread() {
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
return;
|
||||
}
|
||||
// if TdtHandle::DestroyHandle called at taskmanager, all g_acl_handle will be set to nullptr;
|
||||
// but not joined the print thread, so add a protection to join the thread.
|
||||
if (g_acl_handle == nullptr) {
|
||||
MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
|
||||
JoinAclPrintThread(&g_acl_tdt_print);
|
||||
return;
|
||||
}
|
||||
aclError stop_status = acltdtStopChannel(g_acl_handle);
|
||||
if (stop_status != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed stop acl data channel and the stop_status is " << stop_status << std::endl;
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
|
||||
JoinAclPrintThread(&g_acl_tdt_print);
|
||||
aclError destroyed_status = acltdtDestroyChannel(g_acl_handle);
|
||||
if (destroyed_status != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed destroy acl channel and the destroyed_status is " << destroyed_status << std::endl;
|
||||
return;
|
||||
}
|
||||
dataset::TdtHandle::DelHandle(&g_acl_handle);
|
||||
MS_LOG(INFO) << "Succeed destroy acl channel";
|
||||
}
|
||||
} // namespace mindspore::device::ascend
|
||||
|
|
|
@ -19,16 +19,15 @@
|
|||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ir/dtype/type.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "tdt/tsd_client.h"
|
||||
#include "tdt/data_common.h"
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "proto/print.pb.h"
|
||||
#include "include/common/visible.h"
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
|
||||
namespace mindspore {
|
||||
class COMMON_EXPORT TensorPrint {
|
||||
extern "C" {
|
||||
struct acltdtChannelHandle;
|
||||
} // extern "C"
|
||||
|
||||
namespace mindspore::device::ascend {
|
||||
class TensorPrint {
|
||||
public:
|
||||
explicit TensorPrint(const std::string &path, const acltdtChannelHandle *acl_handle)
|
||||
: print_file_path_(path), acl_handle_(acl_handle) {}
|
||||
|
@ -39,5 +38,9 @@ class COMMON_EXPORT TensorPrint {
|
|||
std::string print_file_path_;
|
||||
const acltdtChannelHandle *acl_handle_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
using PrintThreadCrt = std::function<std::thread(std::string &, acltdtChannelHandle *)>;
|
||||
void CreateTensorPrintThread(const PrintThreadCrt &ctr);
|
||||
void DestroyTensorPrintThread();
|
||||
} // namespace mindspore::device::ascend
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORPRINT_UTILS_H_
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "plugin/device/ascend/hal/device/tensorprint_utils.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "runtime/dev.h"
|
||||
#include "runtime/config.h"
|
||||
#include "toolchain/plog.h"
|
||||
#include "common/util/error_manager/error_manager.h"
|
||||
#include "plugin/device/ascend/hal/device/distribute/ascend_collective.h"
|
||||
|
@ -242,9 +243,11 @@ void AscendDeprecatedInterface::DumpProfileParallelStrategy(const FuncGraphPtr &
|
|||
}
|
||||
|
||||
bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ms_context_ptr);
|
||||
// set MS_CTX_ENABLE_GE_HETEROGENOUS true if ge heterogeneous mode
|
||||
int32_t is_heterogeneous = 0;
|
||||
(void)rtGetIsHeterogenous(&is_heterogeneous);
|
||||
ms_context_ptr->set_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS, is_heterogeneous == 1);
|
||||
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return true;
|
||||
|
@ -298,15 +301,13 @@ bool AscendDeprecatedInterface::OpenTsd(const std::shared_ptr<MsContext> &ms_con
|
|||
auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
|
||||
return std::thread(TensorPrint(path, acl_handle));
|
||||
};
|
||||
ms_context_ptr->CreateTensorPrintThread(thread_crt);
|
||||
CreateTensorPrintThread(thread_crt);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ms_context_ptr);
|
||||
MS_LOG(INFO) << "Start to close tsd, ref = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
return true;
|
||||
|
@ -316,7 +317,7 @@ bool AscendDeprecatedInterface::CloseTsd(const std::shared_ptr<MsContext> &ms_co
|
|||
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
pybind11::gil_scoped_release gil_release;
|
||||
ms_context_ptr->DestroyTensorPrintThread();
|
||||
DestroyTensorPrintThread();
|
||||
#endif
|
||||
if (ErrorManager::GetInstance().Init() != 0) {
|
||||
MS_LOG(WARNING) << "Init ascend error manager failed, some ascend error log may be left out.";
|
||||
|
|
|
@ -127,10 +127,6 @@ bool MsContext::set_backend_policy(const std::string &policy) {
|
|||
auto enable_ge = mindspore::common::GetEnv("MS_ENABLE_GE");
|
||||
if (enable_ge == "1") {
|
||||
policy_new = "ge";
|
||||
// set MS_CTX_ENABLE_GE_HETEROGENOUS true if ge heterogeneous mode
|
||||
int32_t is_heterogeneous = 0;
|
||||
(void)rtGetIsHeterogenous(&is_heterogeneous);
|
||||
set_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS, is_heterogeneous == 1);
|
||||
}
|
||||
#endif
|
||||
if (policy_map_.find(policy_new) == policy_map_.end()) {
|
||||
|
@ -142,64 +138,6 @@ bool MsContext::set_backend_policy(const std::string &policy) {
|
|||
return true;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
void MsContext::CreateTensorPrintThread(const PrintThreadCrt &ctr) {
|
||||
if (get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
return;
|
||||
}
|
||||
uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
std::string kReceivePrefix = "TF_RECEIVE_";
|
||||
std::string channel_name = "_npu_log";
|
||||
acl_handle_ = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
|
||||
if (acl_handle_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
std::string print_file_path = get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
|
||||
acl_tdt_print_ = ctr(print_file_path, acl_handle_);
|
||||
TdtHandle::AddHandle(&acl_handle_, &acl_tdt_print_);
|
||||
}
|
||||
|
||||
static void JoinAclPrintThread(std::thread *thread) {
|
||||
try {
|
||||
if (thread->joinable()) {
|
||||
MS_LOG(INFO) << "join acl tdt host receive process";
|
||||
thread->join();
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void MsContext::DestroyTensorPrintThread() {
|
||||
if (get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
|
||||
return;
|
||||
}
|
||||
// if TdtHandle::DestroyHandle called at taskmanger, all acl_handle_ will be set to nullptr;
|
||||
// but not joined the print thread, so add a protect to join the thread.
|
||||
if (acl_handle_ == nullptr) {
|
||||
MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
|
||||
JoinAclPrintThread(&acl_tdt_print_);
|
||||
return;
|
||||
}
|
||||
aclError stopStatus = acltdtStopChannel(acl_handle_);
|
||||
if (stopStatus != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl;
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
|
||||
JoinAclPrintThread(&acl_tdt_print_);
|
||||
aclError destroydStatus = acltdtDestroyChannel(acl_handle_);
|
||||
if (destroydStatus != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl;
|
||||
return;
|
||||
}
|
||||
TdtHandle::DelHandle(&acl_handle_);
|
||||
MS_LOG(INFO) << "Succeed destroy acl channel";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
std::string MsContext::backend_policy() const {
|
||||
auto res = std::find_if(
|
||||
policy_map_.begin(), policy_map_.end(),
|
||||
|
|
|
@ -24,15 +24,6 @@
|
|||
#include <functional>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "minddata/dataset/engine/device_queue_impl/tdt/tdt_handle.h"
|
||||
using mindspore::dataset::TdtHandle;
|
||||
#endif
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "runtime/dev.h"
|
||||
#include "runtime/config.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
enum MsBackendPolicy {
|
||||
|
@ -154,11 +145,7 @@ class MS_CORE_API MsContext {
|
|||
bool enable_dump_ir() const;
|
||||
std::string backend_policy() const;
|
||||
bool set_backend_policy(const std::string &policy);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
using PrintThreadCrt = std::function<std::thread(std::string &, acltdtChannelHandle *)>;
|
||||
void CreateTensorPrintThread(const PrintThreadCrt &ctr);
|
||||
void DestroyTensorPrintThread();
|
||||
#endif
|
||||
|
||||
static void device_seter(const DeviceSeter &device) { seter_ = device; }
|
||||
static void device_type_seter(const DeviceTypeSeter &device_type) { device_type_seter_ = device_type; }
|
||||
|
||||
|
@ -194,10 +181,6 @@ class MS_CORE_API MsContext {
|
|||
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
|
||||
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
|
||||
MsBackendPolicy backend_policy_;
|
||||
#ifdef ENABLE_TDTQUE
|
||||
acltdtChannelHandle *acl_handle_ = nullptr;
|
||||
std::thread acl_tdt_print_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// set method implementation for type bool/int/uint32_t/float/std::string
|
||||
|
|
Loading…
Reference in New Issue