forked from mindspore-Ecosystem/mindspore
!12202 update tdt interface with acltdt
From: @ms_yan Reviewed-by: @kingxian,@heleiwang Signed-off-by: @kingxian
This commit is contained in:
commit
6f9836f085
|
@ -277,7 +277,7 @@ if(ENABLE_GPUQUE)
|
|||
endif()
|
||||
|
||||
if(ENABLE_TDTQUE)
|
||||
target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT})
|
||||
target_link_libraries(_c_dataengine PRIVATE ${ACL})
|
||||
endif()
|
||||
|
||||
add_dependencies(_c_dataengine _c_mindrecord)
|
||||
|
|
|
@ -181,7 +181,7 @@ std::shared_ptr<PullIterator> Dataset::CreatePullBasedIterator(std::vector<std::
|
|||
#ifndef ENABLE_ANDROID
|
||||
// Function to return a transferred Node that transfers data through a device.
|
||||
bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type,
|
||||
int32_t num_epochs, bool send_epoch_end, int32_t total_batches,
|
||||
int32_t device_id, int32_t num_epochs, bool send_epoch_end, int32_t total_batches,
|
||||
bool create_data_info_queue) {
|
||||
Status rc;
|
||||
|
||||
|
@ -196,7 +196,7 @@ bool Dataset::DeviceQueueCharIF(const std::vector<char> &queue_name, const std::
|
|||
// Add TransferNode IR on top of dataset
|
||||
auto ds =
|
||||
std::make_shared<TransferNode>(shared_from_this()->IRNode(), CharToString(queue_name), CharToString(device_type),
|
||||
send_epoch_end, total_batches, create_data_info_queue);
|
||||
device_id, send_epoch_end, total_batches, create_data_info_queue);
|
||||
|
||||
// Get ToDevice consumer
|
||||
auto consumer = std::make_unique<ToDevice>(num_epochs);
|
||||
|
|
|
@ -274,9 +274,10 @@ PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
|
|||
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
|
||||
"to create a TransferNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
|
||||
auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
|
||||
total_batch, create_data_info_queue);
|
||||
int32_t device_id, bool send_epoch_end, int32_t total_batch,
|
||||
bool create_data_info_queue) {
|
||||
auto transfer = std::make_shared<TransferNode>(
|
||||
self, queue_name, device_type, device_id, send_epoch_end, total_batch, create_data_info_queue);
|
||||
THROW_IF_ERROR(transfer->ValidateParams());
|
||||
return transfer;
|
||||
}));
|
||||
|
|
|
@ -57,6 +57,7 @@ DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, i
|
|||
#endif
|
||||
#ifdef ENABLE_TDTQUE
|
||||
ascend_keep_waiting_ = true;
|
||||
tdtInstancePtr = std::make_shared<TdtPlugin>(channel_name_, device_id_);
|
||||
#endif
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
md_channel_info_ = std::make_shared<MDChannelInfo>(channel_name_);
|
||||
|
@ -200,9 +201,9 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
}
|
||||
if (current_buffer->eoe() && send_epoch_end_) {
|
||||
TensorRow currRow;
|
||||
auto status =
|
||||
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
|
||||
if (status == TdtStatus::FAILED) {
|
||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost,
|
||||
ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
||||
if (status != Status::OK()) {
|
||||
if (stop_send_) {
|
||||
send_finished_ = true;
|
||||
MS_LOG(INFO) << "stop_send received";
|
||||
|
@ -238,7 +239,7 @@ void DeviceQueueOp::WaitContinueSignal() const {
|
|||
|
||||
Status DeviceQueueOp::SendRowToTdt(TensorRow currRow, bool isProfilingEnable, int32_t *tdt_cost) {
|
||||
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, *tdt_cost);
|
||||
if (status == TdtStatus::FAILED) {
|
||||
if (status != Status::OK()) {
|
||||
if (stop_send_) {
|
||||
MS_LOG(INFO) << "stop_send received";
|
||||
return Status::OK();
|
||||
|
|
|
@ -32,20 +32,20 @@ namespace dataset {
|
|||
|
||||
// Constructor for TransferNode
|
||||
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
|
||||
int32_t device_id, bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
|
||||
: prefetch_size_(16),
|
||||
queue_name_(std::move(queue_name)),
|
||||
device_type_(std::move(device_type)),
|
||||
send_epoch_end_(send_epoch_end),
|
||||
total_batch_(total_batch),
|
||||
create_data_info_queue_(create_data_info_queue),
|
||||
device_id_(0) {
|
||||
device_id_(device_id) {
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> TransferNode::Copy() {
|
||||
auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_,
|
||||
create_data_info_queue_);
|
||||
auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, device_id_, send_epoch_end_,
|
||||
total_batch_, create_data_info_queue_);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
@ -104,10 +104,6 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// Get device ID (shard ID) from children
|
||||
device_id_ = 0;
|
||||
RETURN_IF_NOT_OK(this->GetShardId(&device_id_));
|
||||
|
||||
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, create_data_info_queue_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
|
|
|
@ -29,8 +29,8 @@ namespace dataset {
|
|||
class TransferNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
|
||||
int32_t total_batch, bool create_data_info_queue);
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, int32_t device_id,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue);
|
||||
|
||||
/// \brief Destructor
|
||||
~TransferNode() = default;
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-tdt OBJECT
|
||||
tdt_plugin.cc
|
||||
)
|
||||
add_library(engine-tdt OBJECT tdt_plugin.cc tdt_handle.cc)
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>();
|
||||
|
||||
void TdtHandle::AddHandle(acltdtChannelHandle *handle) {
|
||||
if (handle != nullptr) {
|
||||
acl_handle.emplace_back(handle);
|
||||
}
|
||||
}
|
||||
|
||||
bool TdtHandle::DestroyHandle() {
|
||||
for (auto handle : acl_handle) {
|
||||
if (handle != nullptr) {
|
||||
if (acltdtDestroyChannel(handle) != ACL_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "acl/acl_tdt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class TdtHandle {
|
||||
public:
|
||||
static void AddHandle(acltdtChannelHandle *handle);
|
||||
|
||||
static bool DestroyHandle();
|
||||
|
||||
private:
|
||||
TdtHandle() {}
|
||||
|
||||
static std::vector<acltdtChannelHandle *> acl_handle;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_
|
|
@ -23,109 +23,142 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
|
||||
|
||||
std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
|
||||
if (instance_ptr_ == nullptr) {
|
||||
instance_ptr_ = std::shared_ptr<TdtPlugin>(new TdtPlugin);
|
||||
TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
|
||||
// create acl tdt handle
|
||||
acl_handle_ = acltdtCreateChannel(device_id, channel_name.c_str());
|
||||
if (acl_handle_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
|
||||
}
|
||||
return instance_ptr_;
|
||||
TdtHandle::AddHandle(acl_handle_);
|
||||
}
|
||||
|
||||
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
|
||||
tdt::TdtDataType tdt_type) {
|
||||
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
|
||||
std::vector<DataItem> items;
|
||||
double start_time;
|
||||
if (tdt_type == tdt::TDT_TENSOR) {
|
||||
auto ret = translate(ts_row, items);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "TDT converting tensor failed!";
|
||||
return FAILED;
|
||||
}
|
||||
} else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
|
||||
DataItem data_item;
|
||||
data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
|
||||
items.emplace_back(data_item);
|
||||
MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
|
||||
TdtPlugin::~TdtPlugin() {
|
||||
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue.";
|
||||
}
|
||||
}
|
||||
|
||||
Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
|
||||
acltdtTensorType tdt_type) {
|
||||
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
|
||||
|
||||
acltdtDataset *acl_dataset = nullptr;
|
||||
double start_time;
|
||||
auto ret = translate(tdt_type, ts_row, &acl_dataset);
|
||||
if (ret != Status::OK()) {
|
||||
DestroyAclDataset(acl_dataset);
|
||||
RETURN_STATUS_UNEXPECTED("Converting into TDT tensor failed!");
|
||||
}
|
||||
|
||||
if (profiling) {
|
||||
start_time = ProfilingTime::GetCurMilliSecond();
|
||||
}
|
||||
#if ENABLE_D
|
||||
// Data prefetch only when PS mode enables cache.
|
||||
if (items.size() > 0) {
|
||||
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_,
|
||||
items[0].tensorType_)) {
|
||||
return FAILED;
|
||||
if (acltdtGetDatasetSize(acl_dataset) > 0) {
|
||||
acltdtDataItem *item0 = acltdtGetDataItem(acl_dataset, 0);
|
||||
std::string item_type = "unsupported";
|
||||
if (acltdtGetDataTypeFromItem(item0) == ACL_INT32) {
|
||||
item_type = "int32";
|
||||
}
|
||||
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, acltdtGetDataAddrFromItem(item0),
|
||||
acltdtGetDataSizeFromItem(item0), item_type)) {
|
||||
RETURN_STATUS_UNEXPECTED("PrefetchData failed in when pre-processing sending data.");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (tdt::TdtHostPushData(channel_name, items) != 0) {
|
||||
return FAILED;
|
||||
auto status = acltdtSendTensor(acl_handle_, acl_dataset, -1);
|
||||
DestroyAclDataset(acl_dataset);
|
||||
if (status != ACL_SUCCESS) {
|
||||
RETURN_STATUS_UNEXPECTED("Tdt Send data failed.");
|
||||
}
|
||||
if (profiling) {
|
||||
double end_time = ProfilingTime::GetCurMilliSecond();
|
||||
time = (int32_t)(end_time - start_time);
|
||||
}
|
||||
return SUCCESS;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) {
|
||||
Status TdtPlugin::getTdtType(DataType d_type, aclDataType &datatype) {
|
||||
switch (d_type.value()) {
|
||||
case DataType::DE_BOOL:
|
||||
datatype = "bool";
|
||||
datatype = ACL_BOOL;
|
||||
break;
|
||||
case DataType::DE_INT8:
|
||||
datatype = "int8";
|
||||
datatype = ACL_INT8;
|
||||
break;
|
||||
case DataType::DE_UINT8:
|
||||
datatype = "uint8";
|
||||
datatype = ACL_UINT8;
|
||||
break;
|
||||
case DataType::DE_INT16:
|
||||
datatype = "int16";
|
||||
datatype = ACL_INT16;
|
||||
break;
|
||||
case DataType::DE_UINT16:
|
||||
datatype = "uint16";
|
||||
datatype = ACL_UINT16;
|
||||
break;
|
||||
case DataType::DE_INT32:
|
||||
datatype = "int32";
|
||||
datatype = ACL_INT32;
|
||||
break;
|
||||
case DataType::DE_UINT32:
|
||||
datatype = "uint32";
|
||||
datatype = ACL_UINT32;
|
||||
break;
|
||||
case DataType::DE_FLOAT16:
|
||||
datatype = "float16";
|
||||
datatype = ACL_FLOAT16;
|
||||
break;
|
||||
case DataType::DE_FLOAT32:
|
||||
datatype = "float32";
|
||||
datatype = ACL_FLOAT;
|
||||
break;
|
||||
case DataType::DE_FLOAT64:
|
||||
datatype = "float64";
|
||||
datatype = ACL_DOUBLE;
|
||||
break;
|
||||
case DataType::DE_INT64:
|
||||
datatype = "int64";
|
||||
datatype = ACL_INT64;
|
||||
break;
|
||||
case DataType::DE_UINT64:
|
||||
datatype = "uint64";
|
||||
datatype = ACL_UINT64;
|
||||
break;
|
||||
default:
|
||||
return FAILED;
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, got unexpected data type.");
|
||||
}
|
||||
return SUCCESS;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &items) {
|
||||
if (ts_row.size() == 0) {
|
||||
MS_LOG(ERROR) << "TDT the size of row is zero.";
|
||||
return SUCCESS;
|
||||
Status TdtPlugin::translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset) {
|
||||
auto acl_dataset = acltdtCreateDataset();
|
||||
if (acl_dataset == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Create tdt dataset failed.");
|
||||
}
|
||||
for (auto ts : ts_row) {
|
||||
std::string datatype;
|
||||
TdtStatus status = getTdtType(ts->type(), datatype);
|
||||
if (status != SUCCESS) {
|
||||
return status;
|
||||
auto status = AssembleTensor2AclDataset(tdt_type, ts_row, acl_dataset);
|
||||
if (status != Status::OK()) {
|
||||
DestroyAclDataset(acl_dataset);
|
||||
RETURN_STATUS_UNEXPECTED("Assemble tensor row to tdt dataset failed.");
|
||||
}
|
||||
|
||||
*output_acl_dataset = acl_dataset;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TdtPlugin::AssembleTensor2AclDataset(acltdtTensorType tdt_type, const TensorRow &ts_row,
|
||||
acltdtDataset *acl_dataset) {
|
||||
if (tdt_type != ACL_TENSOR_DATA_TENSOR || ts_row.size() == 0) {
|
||||
acltdtDataItem *acl_data = acltdtCreateDataItem(tdt_type, nullptr, 0, ACL_BOOL, nullptr, 0);
|
||||
if (acl_data == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Create data item failed when send data with type:" + std::to_string(tdt_type));
|
||||
}
|
||||
if (acltdtAddDataItem(acl_dataset, acl_data) != ACL_SUCCESS) {
|
||||
if (acltdtDestroyDataItem(acl_data) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy data item failed when send data with type: " << tdt_type;
|
||||
}
|
||||
RETURN_STATUS_UNEXPECTED("Add data item to tdt dataset failed when send data.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (auto ts : ts_row) {
|
||||
aclDataType datatype;
|
||||
acltdtDataItem *acl_data = nullptr;
|
||||
RETURN_IF_NOT_OK(getTdtType(ts->type(), datatype));
|
||||
|
||||
TensorShape tsShape = ts->shape();
|
||||
std::string dataShapes = "[";
|
||||
for (auto dim : tsShape.AsVector()) {
|
||||
|
@ -133,18 +166,46 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
|
|||
}
|
||||
dataShapes.pop_back();
|
||||
(void)dataShapes.append("]");
|
||||
DataItem data_item;
|
||||
data_item.dataType_ = tdt::TDT_TENSOR;
|
||||
data_item.tensorShape_ = dataShapes;
|
||||
data_item.tensorType_ = datatype;
|
||||
data_item.dataLen_ = ts->SizeInBytes();
|
||||
data_item.dataPtr_ =
|
||||
|
||||
std::shared_ptr<void> dataPtr =
|
||||
std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
|
||||
items.emplace_back(data_item);
|
||||
size_t dataLen = ts->SizeInBytes();
|
||||
const dsize_t dims = tsShape.Rank();
|
||||
std::vector<int64_t> dataShape;
|
||||
for (auto i = 0; i < dims; i++) {
|
||||
dataShape.emplace_back(tsShape[i]);
|
||||
}
|
||||
acl_data = acltdtCreateDataItem(ACL_TENSOR_DATA_TENSOR, (tsShape.empty() ? nullptr : &dataShape[0]), dims, datatype,
|
||||
dataPtr.get(), dataLen);
|
||||
if (acl_data == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Create data item failed when send data.");
|
||||
}
|
||||
if (acltdtAddDataItem(acl_dataset, acl_data) != ACL_SUCCESS) {
|
||||
if (acltdtDestroyDataItem(acl_data) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destroy data item failed when send data with type ACL_TENSOR_DATA_TENSOR.";
|
||||
}
|
||||
RETURN_STATUS_UNEXPECTED("Add data item to tdt dataset failed when send data.");
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
|
||||
<< ", data length is " << ts->Size() << ".";
|
||||
}
|
||||
return SUCCESS;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TdtPlugin::DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item) {
|
||||
if (include_data_item) {
|
||||
for (size_t i = 0; i < acltdtGetDatasetSize(acl_dataset); i++) {
|
||||
if (acltdtDestroyDataItem(acltdtGetDataItem(acl_dataset, i)) != ACL_SUCCESS) {
|
||||
RETURN_STATUS_UNEXPECTED("Destroy data item failed when send data.");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (acltdtDestroyDataset(acl_dataset) != ACL_SUCCESS) {
|
||||
RETURN_STATUS_UNEXPECTED("Destroy tdt dataset failed when send data.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,33 +22,40 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
enum TdtStatus { SUCCESS, FAILED };
|
||||
|
||||
using tdt::DataItem;
|
||||
|
||||
class TdtPlugin {
|
||||
public:
|
||||
static std::shared_ptr<TdtPlugin> GetInstance();
|
||||
|
||||
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
|
||||
tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);
|
||||
Status hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
|
||||
acltdtTensorType tdt_type = ACL_TENSOR_DATA_TENSOR);
|
||||
|
||||
TdtPlugin(const std::string &channel_name, int32_t device_id);
|
||||
|
||||
~TdtPlugin();
|
||||
|
||||
private:
|
||||
TdtPlugin() {}
|
||||
Status DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true);
|
||||
|
||||
TdtStatus getTdtType(DataType d_type, std::string &datatype);
|
||||
Status AssembleTensor2AclDataset(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset *acl_dataset);
|
||||
|
||||
TdtStatus translate(const TensorRow &ts_row, std::vector<DataItem> &items);
|
||||
Status getTdtType(DataType d_type, aclDataType &datatype);
|
||||
|
||||
Status translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset);
|
||||
|
||||
void *tdt_handle_ = nullptr;
|
||||
|
||||
acltdtChannelHandle *acl_handle_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -156,15 +156,17 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// of data transmission per time is 256M.
|
||||
/// \param[in] queue_name Channel name (default="", create new unique name).
|
||||
/// \param[in] device_type Type of device (default="", get from MSContext).
|
||||
/// \param[in] device_id id of device (default=1, get from MSContext).
|
||||
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
|
||||
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
|
||||
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
|
||||
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
|
||||
/// of data or not(default=false).
|
||||
/// \return Returns true if no error encountered else false.
|
||||
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
|
||||
bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false) {
|
||||
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), num_epochs, send_epoch_end,
|
||||
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
|
||||
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
|
||||
bool create_data_info_queue = false) {
|
||||
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
|
||||
total_batches, create_data_info_queue);
|
||||
}
|
||||
|
||||
|
@ -458,8 +460,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
|
||||
|
||||
// Char interface(CharIF) of DeviceQueue
|
||||
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t num_epochs,
|
||||
bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
|
||||
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
|
||||
int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
|
||||
|
||||
// Char interface(CharIF) of Save
|
||||
bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
|
||||
|
|
|
@ -23,8 +23,9 @@
|
|||
#include "minddata/dataset/util/services.h"
|
||||
#endif
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "tdt/status.h"
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -163,11 +164,10 @@ Status Task::Join(WaitFlag blocking) {
|
|||
if (wait_times > 5 && my_name_.find("DeviceQueueOp") != std::string::npos) {
|
||||
MS_LOG(WARNING) << "Wait " << wait_times << " seconds, "
|
||||
<< "the task: " << my_name_ << " will be destroyed by TdtHostDestory.";
|
||||
int32_t destory_status = tdt::TdtHostDestroy();
|
||||
if (destory_status != TDT_OK_CODE) {
|
||||
MS_LOG(WARNING) << "Destroy tsd failed, status = " << destory_status << ".";
|
||||
if (!TdtHandle::DestroyHandle()) {
|
||||
MS_LOG(WARNING) << "Destroy tdt channel failed.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Destroy tsd success.";
|
||||
MS_LOG(INFO) << "Destroy tdt channel success.";
|
||||
}
|
||||
|
||||
// just wait 30 seconds
|
||||
|
|
|
@ -16,7 +16,8 @@ else()
|
|||
endif()
|
||||
|
||||
if(ENABLE_D)
|
||||
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc")
|
||||
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc"
|
||||
"../../minddata/dataset/engine/tdt/tdt_handle.cc")
|
||||
endif()
|
||||
|
||||
if(ENABLE_CPU)
|
||||
|
|
|
@ -53,8 +53,8 @@
|
|||
#include "runtime/device/ascend/profiling/profiling_callback_register.h"
|
||||
#include "backend/kernel_compiler/hccl/hccl_context.h"
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "tdt/status.h"
|
||||
#include "minddata/dataset/engine/tdt/tdt_handle.h"
|
||||
using mindspore::dataset::TdtHandle;
|
||||
#endif
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
#include "debug/rdr/running_data_recorder.h"
|
||||
|
@ -692,11 +692,10 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
|||
#ifdef ENABLE_TDTQUE
|
||||
// Run task error, we should call TdtHostDestroy to release tdt to avoid DeviceQueueOp hostPush hung
|
||||
// case1: cpu usage 100% cause thread/process exit, but some tdt thread remain in backend
|
||||
int32_t destory_status = tdt::TdtHostDestroy();
|
||||
if (destory_status != TDT_OK_CODE) {
|
||||
MS_LOG(WARNING) << "Destroy tsd failed, status = " << destory_status << ".";
|
||||
if (!TdtHandle::DestroyHandle()) {
|
||||
MS_LOG(WARNING) << "Destroy tdt channel failed.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Destroy tsd success.";
|
||||
MS_LOG(INFO) << "Destroy tdt channel success.";
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include <atomic>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
|
@ -46,7 +45,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
}
|
||||
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
|
||||
MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
|
||||
MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
return true;
|
||||
}
|
||||
|
@ -56,10 +55,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
return true;
|
||||
}
|
||||
|
||||
unsigned int device_id;
|
||||
unsigned int rank_size = 1;
|
||||
|
||||
device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
uint32_t rank_size = 1;
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
||||
auto rank_size_env = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size_env.empty()) {
|
||||
|
@ -81,14 +78,14 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
}
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||
if (initStatus != TDT_OK_CODE) {
|
||||
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
|
||||
if (acl_handle == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
|
||||
return false;
|
||||
}
|
||||
ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
|
||||
ms_context_ptr->acl_tdt_print = std::thread(TensorPrint(acl_handle));
|
||||
#endif
|
||||
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
|
||||
MS_LOG(INFO) << "Get the acltdt handle successful, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
return true;
|
||||
}
|
||||
|
@ -103,28 +100,34 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
|
||||
if (stopStatus != TDT_OK_CODE) {
|
||||
MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << ".";
|
||||
return false;
|
||||
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
|
||||
aclError stopStatus = acltdtStopChannel(acl_handle);
|
||||
if (stopStatus != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed stop acl data channel for host queue ";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op ";
|
||||
|
||||
py::gil_scoped_release gil_release;
|
||||
int32_t destroyStatus = tdt::TdtHostDestroy();
|
||||
if (destroyStatus != TDT_OK_CODE) {
|
||||
MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << ".";
|
||||
return false;
|
||||
aclError destrodStatus = acltdtDestroyChannel(acl_handle);
|
||||
if (destrodStatus != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op ";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op ";
|
||||
}
|
||||
try {
|
||||
if (ms_context_ptr->tdt_print_.joinable()) {
|
||||
MS_LOG(INFO) << "join tdt host receive process";
|
||||
ms_context_ptr->tdt_print_.join();
|
||||
if (ms_context_ptr->acl_tdt_print.joinable()) {
|
||||
MS_LOG(INFO) << "join acl tdt host receive process";
|
||||
ms_context_ptr->acl_tdt_print.join();
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
|
||||
}
|
||||
#endif
|
||||
auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto ret = rtDeviceReset(device_id);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
||||
|
@ -133,10 +136,9 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
|
||||
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
|
@ -230,7 +232,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
} else {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
|
||||
}
|
||||
// Disable the global variable acc, only enable it whlie adding training graph in pipeline
|
||||
// Disable the global variable acc, only enable it while adding training graph in pipeline
|
||||
(*ge_options)["ge.exec.variable_acc"] = "0";
|
||||
#endif
|
||||
}
|
||||
|
@ -308,6 +310,7 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
(void)OpenTsd(ms_context_ptr);
|
||||
(void)InitGe(ms_context_ptr);
|
||||
ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
#include "utils/tensorprint_utils.h"
|
||||
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "tdt/tsd_client.h"
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "tdt/data_common.h"
|
||||
#include "runtime/dev.h"
|
||||
#endif
|
||||
|
@ -35,8 +35,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace context {
|
||||
bool OpenTsd(const std::shared_ptr<MsContext> &inst_context);
|
||||
bool CloseTsd(const std::shared_ptr<MsContext> &inst_context, bool force = false);
|
||||
bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr);
|
||||
bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force = false);
|
||||
void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
|
||||
void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
|
||||
void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,75 +23,48 @@
|
|||
#include "pybind11/pybind11.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#ifndef NO_DLIB
|
||||
#include "tdt/tsd_client.h"
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "tdt/data_common.h"
|
||||
#endif
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
const char kShapeSeperator[] = ",";
|
||||
const char kShapeScalar[] = "[0]";
|
||||
const char kShapeNone[] = "[]";
|
||||
static std::map<std::string, TypeId> print_type_map = {
|
||||
{"int8", TypeId::kNumberTypeInt8}, {"uint8", TypeId::kNumberTypeUInt8},
|
||||
{"int16", TypeId::kNumberTypeInt16}, {"uint16", TypeId::kNumberTypeUInt16},
|
||||
{"int32", TypeId::kNumberTypeInt32}, {"uint32", TypeId::kNumberTypeUInt32},
|
||||
{"int64", TypeId::kNumberTypeInt64}, {"uint64", TypeId::kNumberTypeUInt64},
|
||||
{"bfloat16", TypeId::kNumberTypeFloat16}, {"float", TypeId::kNumberTypeFloat32},
|
||||
{"double", TypeId::kNumberTypeFloat64}, {"bool", TypeId::kNumberTypeBool}};
|
||||
|
||||
static std::map<std::string, size_t> type_size_map = {
|
||||
{"int8", sizeof(int8_t)}, {"uint8", sizeof(uint8_t)}, {"int16", sizeof(int16_t)},
|
||||
{"uint16", sizeof(uint16_t)}, {"int32", sizeof(int32_t)}, {"uint32", sizeof(uint32_t)},
|
||||
{"int64", sizeof(int64_t)}, {"uint64", sizeof(uint64_t)}, {"bfloat16", sizeof(float) / 2},
|
||||
{"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}};
|
||||
#ifndef NO_DLIB
|
||||
static std::map<aclDataType, TypeId> print_acl_data_type_map = {
|
||||
{ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8},
|
||||
{ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16},
|
||||
{ACL_INT32, TypeId::kNumberTypeInt32}, {ACL_UINT32, TypeId::kNumberTypeUInt32},
|
||||
{ACL_INT64, TypeId::kNumberTypeInt64}, {ACL_UINT64, TypeId::kNumberTypeUInt64},
|
||||
{ACL_FLOAT16, TypeId::kNumberTypeFloat16}, {ACL_FLOAT, TypeId::kNumberTypeFloat32},
|
||||
{ACL_DOUBLE, TypeId::kNumberTypeFloat64}, {ACL_BOOL, TypeId::kNumberTypeBool}};
|
||||
|
||||
std::string GetParseType(const std::string &tensorType_) {
|
||||
static const std::map<std::string, std::string> print_parse_map = {
|
||||
{"int8", "Int8"}, {"uint8", "Uint8"}, {"int16", "Int16"}, {"uint16", "Uint16"},
|
||||
{"int32", "Int32"}, {"uint32", "Uint32"}, {"int64", "Int64"}, {"uint64", "Uint64"},
|
||||
{"bfloat16", "Float16"}, {"float", "Float32"}, {"double", "Float64"}, {"bool", "Bool"}};
|
||||
auto type_iter = print_parse_map.find(tensorType_);
|
||||
if (type_iter == print_parse_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << tensorType_;
|
||||
static std::map<aclDataType, size_t> acl_data_type_size_map = {
|
||||
{ACL_INT8, sizeof(int8_t)}, {ACL_UINT8, sizeof(uint8_t)}, {ACL_INT16, sizeof(int16_t)},
|
||||
{ACL_UINT16, sizeof(uint16_t)}, {ACL_INT32, sizeof(int32_t)}, {ACL_UINT32, sizeof(uint32_t)},
|
||||
{ACL_INT64, sizeof(int64_t)}, {ACL_UINT64, sizeof(uint64_t)}, {ACL_FLOAT16, sizeof(float) / 2},
|
||||
{ACL_FLOAT, sizeof(float)}, {ACL_DOUBLE, sizeof(double)}, {ACL_BOOL, sizeof(bool)}};
|
||||
|
||||
std::string GetParseType(const aclDataType &acl_data_type) {
|
||||
static const std::map<aclDataType, std::string> print_tensor_parse_map = {
|
||||
{ACL_INT8, "Int8"}, {ACL_UINT8, "Uint8"}, {ACL_INT16, "Int16"}, {ACL_UINT16, "Uint16"},
|
||||
{ACL_INT32, "Int32"}, {ACL_UINT32, "Uint32"}, {ACL_INT64, "Int64"}, {ACL_UINT64, "Uint64"},
|
||||
{ACL_FLOAT16, "Float16"}, {ACL_FLOAT, "Float32"}, {ACL_DOUBLE, "Float64"}, {ACL_BOOL, "Bool"}};
|
||||
auto type_iter = print_tensor_parse_map.find(acl_data_type);
|
||||
if (type_iter == print_tensor_parse_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << acl_data_type;
|
||||
}
|
||||
return type_iter->second;
|
||||
}
|
||||
|
||||
bool ParseTensorShape(const std::string &input_shape_str, ShapeVector *const tensor_shape, size_t *dims) {
|
||||
if (tensor_shape == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dims);
|
||||
std::string shape_str = input_shape_str;
|
||||
if (shape_str.size() <= 2) {
|
||||
return false;
|
||||
}
|
||||
(void)shape_str.erase(shape_str.begin());
|
||||
shape_str.pop_back();
|
||||
shape_str += kShapeSeperator;
|
||||
string::size_type pos_begin = 0;
|
||||
string::size_type pos_end = shape_str.find(kShapeSeperator);
|
||||
while (pos_end != std::string::npos) {
|
||||
string dim_str = shape_str.substr(pos_begin, pos_end - pos_begin);
|
||||
tensor_shape->emplace_back(std::stoi(dim_str));
|
||||
(*dims) = (*dims) * std::stoul(dim_str);
|
||||
pos_begin = pos_end + sizeof(kShapeSeperator) - 1;
|
||||
pos_end = shape_str.find(kShapeSeperator, pos_begin);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *const print_tensor,
|
||||
const size_t &memory_size) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(print_tensor);
|
||||
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_ptr);
|
||||
auto cp_ret =
|
||||
memcpy_s(tensor_data_ptr, static_cast<size_t>(print_tensor->data().nbytes()), str_data_ptr, memory_size);
|
||||
|
||||
size_t dest_size = static_cast<size_t>(print_tensor->data().nbytes());
|
||||
size_t target_size = memory_size;
|
||||
|
||||
auto cp_ret = memcpy_s(tensor_data_ptr, dest_size, str_data_ptr, target_size);
|
||||
if (cp_ret != EOK) {
|
||||
MS_LOG(ERROR) << "Print op Failed to copy the memory to py::tensor " << cp_ret;
|
||||
return false;
|
||||
|
@ -100,10 +73,10 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
|
||||
void PrintScalarToString(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(acl_data_type) << ", value=";
|
||||
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
|
||||
if constexpr (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
|
||||
const int int_data = static_cast<int>(*data_ptr);
|
||||
|
@ -113,11 +86,12 @@ void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, st
|
|||
}
|
||||
}
|
||||
|
||||
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
|
||||
void PrintScalarToBoolString(const char *str_data_ptr, const aclDataType &acl_data_type,
|
||||
std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(acl_data_type) << ", value=";
|
||||
if (*data_ptr) {
|
||||
*buf << "True)\n";
|
||||
} else {
|
||||
|
@ -125,89 +99,99 @@ void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type
|
|||
}
|
||||
}
|
||||
|
||||
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
|
||||
void convertDataItem2Scalar(const char *str_data_ptr, const aclDataType &acl_data_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
auto type_iter = print_type_map.find(tensor_type);
|
||||
auto type_iter = print_acl_data_type_map.find(acl_data_type);
|
||||
auto type_id = type_iter->second;
|
||||
if (type_id == TypeId::kNumberTypeBool) {
|
||||
PrintScalarToBoolString(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToBoolString(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt8) {
|
||||
PrintScalarToString<int8_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<int8_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt8) {
|
||||
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<uint8_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt16) {
|
||||
PrintScalarToString<int16_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<int16_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt16) {
|
||||
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<uint16_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt32) {
|
||||
PrintScalarToString<int32_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<int32_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt32) {
|
||||
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<uint32_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt64) {
|
||||
PrintScalarToString<int64_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<int64_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt64) {
|
||||
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<uint64_t>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat16) {
|
||||
PrintScalarToString<float16>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<float16>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat32) {
|
||||
PrintScalarToString<float>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<float>(str_data_ptr, acl_data_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat64) {
|
||||
PrintScalarToString<double>(str_data_ptr, tensor_type, buf);
|
||||
PrintScalarToString<double>(str_data_ptr, acl_data_type, buf);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupported data type: " << tensor_type << ".";
|
||||
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupported data type: " << GetParseType(acl_data_type) << ".";
|
||||
}
|
||||
}
|
||||
|
||||
bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
|
||||
auto type_iter = type_size_map.find(tensor_type);
|
||||
if (type_iter == type_size_map.end()) {
|
||||
bool judgeLengthValid(const size_t str_len, const aclDataType &acl_data_type) {
|
||||
auto type_iter = acl_data_type_size_map.find(acl_data_type);
|
||||
if (type_iter == acl_data_type_size_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
|
||||
}
|
||||
return str_len == type_iter->second;
|
||||
}
|
||||
|
||||
#ifndef NO_DLIB
|
||||
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
||||
bool ConvertDataset2Tensor(acltdtDataset *acl_dataset) {
|
||||
// Acquire Python GIL
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
std::ostringstream buf;
|
||||
bool ret_end_sequence = false;
|
||||
for (auto &item : items) {
|
||||
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) {
|
||||
|
||||
size_t acl_dataset_size = acltdtGetDatasetSize(acl_dataset);
|
||||
|
||||
for (size_t i = 0; i < acl_dataset_size; i++) {
|
||||
acltdtDataItem *item = acltdtGetDataItem(acl_dataset, i);
|
||||
if (acltdtGetTensorTypeFromItem(item) == ACL_TENSOR_DATA_END_OF_SEQUENCE) {
|
||||
ret_end_sequence = true;
|
||||
MS_LOG(INFO) << "end of sequence" << std::endl;
|
||||
break;
|
||||
}
|
||||
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
|
||||
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
|
||||
|
||||
size_t dim_num = acltdtGetDimNumFromItem(item);
|
||||
void *acl_addr = acltdtGetDataAddrFromItem(item);
|
||||
size_t acl_data_size = acltdtGetDataSizeFromItem(item);
|
||||
aclDataType acl_data_type = acltdtGetDataTypeFromItem(item);
|
||||
char *acl_data = reinterpret_cast<char *>(acl_addr);
|
||||
acl_data = const_cast<char *>(reinterpret_cast<std::string *>(acl_data)->c_str());
|
||||
MS_EXCEPTION_IF_NULL(acl_data);
|
||||
|
||||
ShapeVector tensorShape;
|
||||
tensorShape.resize(dim_num);
|
||||
|
||||
if (acltdtGetDimsFromItem(item, tensorShape.data(), dim_num) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "ACL failed get dim-size from acl channel data";
|
||||
}
|
||||
|
||||
if ((tensorShape.size() == 1 && tensorShape[0] == 0) || tensorShape.size() == 0) {
|
||||
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
}
|
||||
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
|
||||
convertDataItem2Scalar(acl_data, acl_data_type, &buf);
|
||||
continue;
|
||||
}
|
||||
|
||||
ShapeVector tensor_shape;
|
||||
size_t totaldims = 1;
|
||||
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
|
||||
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.tensorType_ == "string") {
|
||||
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
|
||||
if (acl_data_type == ACL_STRING) {
|
||||
std::string data(reinterpret_cast<const char *>(acl_data), acl_data_size);
|
||||
buf << data << std::endl;
|
||||
} else {
|
||||
auto type_iter = print_type_map.find(item.tensorType_);
|
||||
if (type_iter == print_type_map.end()) {
|
||||
MS_LOG(ERROR) << "type of tensor need to print is not support " << item.tensorType_;
|
||||
auto type_iter = print_acl_data_type_map.find(acl_data_type);
|
||||
if (type_iter == print_acl_data_type_map.end()) {
|
||||
MS_LOG(ERROR) << "type of tensor need to print is not support " << GetParseType(acl_data_type);
|
||||
continue;
|
||||
}
|
||||
auto type_id = type_iter->second;
|
||||
mindspore::tensor::Tensor print_tensor(type_id, tensor_shape);
|
||||
auto memory_size = totaldims * type_size_map[item.tensorType_];
|
||||
if (PrintTensorToString(str_data_ptr->data(), &print_tensor, memory_size)) {
|
||||
mindspore::tensor::Tensor print_tensor(type_id, tensorShape);
|
||||
if (PrintTensorToString(acl_data, &print_tensor, acl_data_size)) {
|
||||
buf << print_tensor.ToStringNoLimit() << std::endl;
|
||||
}
|
||||
}
|
||||
|
@ -216,44 +200,63 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
|||
return ret_end_sequence;
|
||||
}
|
||||
|
||||
bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print,
|
||||
std::fstream *output) {
|
||||
bool SaveDataset2File(acltdtDataset *acl_dataset, const std::string &print_file_path, prntpb::Print print,
|
||||
std::fstream *output) {
|
||||
bool ret_end_thread = false;
|
||||
for (auto &item : items) {
|
||||
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) {
|
||||
|
||||
for (size_t i = 0; i < acltdtGetDatasetSize(acl_dataset); i++) {
|
||||
acltdtDataItem *item = acltdtGetDataItem(acl_dataset, i);
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
acltdtTensorType acl_tensor_type = acltdtGetTensorTypeFromItem(item);
|
||||
|
||||
if (acl_tensor_type == ACL_TENSOR_DATA_END_OF_SEQUENCE) {
|
||||
MS_LOG(INFO) << "Acl channel received end-of-sequence for print op.";
|
||||
ret_end_thread = true;
|
||||
break;
|
||||
} else if (acl_tensor_type == ACL_TENSOR_DATA_ABNORMAL) {
|
||||
MS_LOG(INFO) << "Acl channel received abnormal for print op.";
|
||||
return true;
|
||||
} else if (acl_tensor_type == ACL_TENSOR_DATA_UNDEFINED) {
|
||||
MS_LOG(INFO) << "Acl channel received undefined message type for print op.";
|
||||
return false;
|
||||
}
|
||||
|
||||
prntpb::Print_Value *value = print.add_value();
|
||||
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
|
||||
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
|
||||
size_t dim_num = acltdtGetDimNumFromItem(item);
|
||||
void *acl_addr = acltdtGetDataAddrFromItem(item);
|
||||
size_t acl_data_size = acltdtGetDataSizeFromItem(item);
|
||||
aclDataType acl_data_type = acltdtGetDataTypeFromItem(item);
|
||||
char *acl_data = reinterpret_cast<char *>(acl_addr);
|
||||
MS_EXCEPTION_IF_NULL(acl_data);
|
||||
|
||||
ShapeVector tensorShape;
|
||||
tensorShape.resize(dim_num);
|
||||
|
||||
if (acltdtGetDimsFromItem(item, tensorShape.data(), dim_num) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "ACL failed get dim-size from acl channel data";
|
||||
}
|
||||
|
||||
if ((tensorShape.size() == 1 && tensorShape[0] == 0) || tensorShape.size() == 0) {
|
||||
if (!judgeLengthValid(acl_data_size, acl_data_type)) {
|
||||
MS_LOG(ERROR) << "Print op receive data length is invalid.";
|
||||
ret_end_thread = true;
|
||||
}
|
||||
}
|
||||
|
||||
ShapeVector tensor_shape;
|
||||
size_t totaldims = 1;
|
||||
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
|
||||
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
|
||||
ret_end_thread = true;
|
||||
}
|
||||
|
||||
if (item.tensorType_ == "string") {
|
||||
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
|
||||
if (acl_data_type == ACL_STRING) {
|
||||
std::string data(reinterpret_cast<const char *>(acl_data), acl_data_size);
|
||||
value->set_desc(data);
|
||||
} else {
|
||||
auto parse_type = GetParseType(item.tensorType_);
|
||||
auto parse_type = GetParseType(acl_data_type);
|
||||
prntpb::TensorProto *tensor = value->mutable_tensor();
|
||||
if (!(item.tensorShape_ == kShapeScalar) && !(item.tensorShape_ == kShapeNone)) {
|
||||
for (const auto &dim : tensor_shape) {
|
||||
if (tensorShape.size() > 1 || (tensorShape.size() == 1 && tensorShape[0] != 1)) {
|
||||
for (const auto &dim : tensorShape) {
|
||||
tensor->add_dims(static_cast<::google::protobuf::int64>(dim));
|
||||
}
|
||||
}
|
||||
|
||||
tensor->set_tensor_type(parse_type);
|
||||
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
|
||||
std::string data(reinterpret_cast<const char *>(acl_data), acl_data_size);
|
||||
tensor->set_tensor_content(data);
|
||||
}
|
||||
|
||||
|
@ -274,29 +277,37 @@ void TensorPrint::operator()() {
|
|||
std::string print_file_path = ms_context->get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
|
||||
if (print_file_path == "") {
|
||||
while (true) {
|
||||
std::vector<tdt::DataItem> bundle;
|
||||
if (tdt::TdtHostPopData("_npu_log", bundle) != 0) {
|
||||
acltdtDataset *acl_dataset = acltdtCreateDataset();
|
||||
if (acl_dataset == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed create acl dateaset.";
|
||||
}
|
||||
if (acltdtReceiveTensor(acl_handle_, acl_dataset, -1 /* no timeout */) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Acltdt receive tensor failed";
|
||||
break;
|
||||
}
|
||||
if (ConvertDataItem2Tensor(bundle)) {
|
||||
if (ConvertDataset2Tensor(acl_dataset)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::fstream output(print_file_path, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||
while (true) {
|
||||
std::vector<tdt::DataItem> bundle;
|
||||
if (tdt::TdtHostPopData("_npu_log", bundle) != 0) {
|
||||
acltdtDataset *acl_dataset = acltdtCreateDataset();
|
||||
if (acl_dataset == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed create acl dateaset.";
|
||||
}
|
||||
if (acltdtReceiveTensor(acl_handle_, acl_dataset, -1 /* no timeout */) != ACL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Acltdt receive tensor failed";
|
||||
break;
|
||||
}
|
||||
if (SaveDataItem2File(bundle, print_file_path, print, &output)) {
|
||||
if (SaveDataset2File(acl_dataset, print_file_path, print, &output)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
output.close();
|
||||
std::string path_string = print_file_path;
|
||||
if (chmod(common::SafeCStr(path_string), S_IRUSR) == -1) {
|
||||
MS_LOG(ERROR) << "Modify file:" << print_file_path << " to r fail.";
|
||||
MS_LOG(ERROR) << "Modify file:" << print_file_path << " to fail.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,9 +20,10 @@
|
|||
#include <map>
|
||||
#include "ir/dtype/type.h"
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#include "tdt/tsd_client.h"
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "tdt/data_common.h"
|
||||
#include "tdt/tdt_host_interface.h"
|
||||
#include "proto/print.pb.h"
|
||||
#include "utils/ms_context.h"
|
||||
#endif
|
||||
|
@ -32,7 +33,11 @@ class TensorPrint {
|
|||
TensorPrint() {}
|
||||
~TensorPrint() = default;
|
||||
#ifndef NO_DLIB
|
||||
explicit TensorPrint(acltdtChannelHandle *acl_handle) { acl_handle_ = acl_handle; }
|
||||
void operator()();
|
||||
|
||||
private:
|
||||
acltdtChannelHandle *acl_handle_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -57,6 +57,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
} else {
|
||||
set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
|
||||
}
|
||||
|
||||
set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
|
||||
set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
|
||||
set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
|
||||
|
@ -123,4 +124,22 @@ bool MsContext::enable_dump_ir() const {
|
|||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() {
|
||||
if (acl_handle == nullptr) {
|
||||
std::string kReceivePrefix = "TF_RECEIVE_";
|
||||
std::string channel_name = "_npu_log";
|
||||
uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
|
||||
if (acl_handle == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name;
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name;
|
||||
return acl_handle;
|
||||
}
|
||||
return acl_handle;
|
||||
}
|
||||
#endif
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,10 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#include "utils/ms_utils.h"
|
||||
#ifndef NO_DLIB
|
||||
#include "acl/acl_tdt.h"
|
||||
#endif
|
||||
namespace mindspore {
|
||||
enum MsBackendPolicy {
|
||||
kMsBackendGeOnly = 0,
|
||||
|
@ -132,11 +135,13 @@ class MsContext {
|
|||
bool enable_dump_ir() const;
|
||||
std::string backend_policy() const;
|
||||
bool set_backend_policy(const std::string &policy);
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
acltdtChannelHandle *get_acl_tdt_channel_handle();
|
||||
#endif
|
||||
static void device_seter(DeviceSeter device) { seter_ = device; }
|
||||
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
|
||||
|
||||
std::thread tdt_print_;
|
||||
std::thread acl_tdt_print;
|
||||
|
||||
template <typename T>
|
||||
void set_param(MsCtxParam param, const T &value) {
|
||||
|
@ -171,6 +176,9 @@ class MsContext {
|
|||
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
|
||||
|
||||
MsBackendPolicy backend_policy_;
|
||||
#ifdef ENABLE_TDTQUE
|
||||
acltdtChannelHandle *acl_handle = nullptr;
|
||||
#endif
|
||||
};
|
||||
|
||||
// set method implementation for type bool/int/uint32_t/float/std::string
|
||||
|
|
|
@ -2722,10 +2722,11 @@ class TransferDataset(Dataset):
|
|||
|
||||
def parse(self, children=None):
|
||||
total_batch = 0
|
||||
device_id = context.get_context("device_id")
|
||||
if hasattr(self.children[0], "__total_batch__"):
|
||||
total_batch = self.children[0].__total_batch__
|
||||
return cde.TransferNode(children[0], self.queue_name, self.device_type, self._send_epoch_end, total_batch,
|
||||
self._create_data_info_queue)
|
||||
return cde.TransferNode(children[0], self.queue_name, self.device_type, device_id, self._send_epoch_end,
|
||||
total_batch, self._create_data_info_queue)
|
||||
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
raise RuntimeError("TransferDataset is not iterable.")
|
||||
|
|
Loading…
Reference in New Issue