!42676 use dynamic GetNext in Pynative mode
Merge pull request !42676 from wYann/data_queue
commit 326279350e
@@ -105,6 +105,8 @@ class BACKEND_EXPORT DataQueueMgr {
 
   bool IsClosed() const;
 
+  bool IsCreated(const std::string &channel_name) const;
+
   bool Destroy();
 
   // call for Release GPU Resources
@@ -117,9 +119,9 @@ class BACKEND_EXPORT DataQueueMgr {
 
   size_t Capacity(const std::string &channel_name);
 
- private:
-  inline bool isCreated(const std::string &channel_name) const;
-
+  void Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue);
+
+ private:
   DataQueueMgr(const DataQueueMgr &) = delete;
   DataQueueMgr &operator=(const DataQueueMgr &) = delete;
 
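For orientation, here is a minimal standalone sketch of the registration pattern these header changes introduce: a singleton manager keeps a name-to-queue map, Manage() registers a queue that was built elsewhere, and IsCreated()/GetDataQueue() let later code find it by channel name. The types and names below (MiniQueue, MiniQueueMgr, "queue_0") are illustrative stand-ins, not MindSpore's BlockingQueue or DataQueueMgr.

#include <map>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-ins only; these are not MindSpore's BlockingQueue or DataQueueMgr.
struct MiniQueue {
  std::vector<std::vector<int>> shapes;  // host-side shape records, one entry per pushed batch
};

class MiniQueueMgr {
 public:
  static MiniQueueMgr &GetInstance() {
    static MiniQueueMgr instance;
    return instance;
  }
  // Register a queue that was created elsewhere (the role Manage() plays in this diff).
  void Manage(const std::string &channel_name, const std::shared_ptr<MiniQueue> &queue) {
    (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
  }
  bool IsCreated(const std::string &channel_name) const {
    return name_queue_map_.find(channel_name) != name_queue_map_.end();
  }
  std::shared_ptr<MiniQueue> GetDataQueue(const std::string &channel_name) const {
    auto iter = name_queue_map_.find(channel_name);
    return iter == name_queue_map_.end() ? nullptr : iter->second;
  }

 private:
  std::map<std::string, std::shared_ptr<MiniQueue>> name_queue_map_;
};

int main() {
  // Producer side (what the AscendTdtQueue constructor does below): build a queue and register it.
  auto wingman = std::make_shared<MiniQueue>();
  MiniQueueMgr::GetInstance().Manage("queue_0", wingman);

  // Consumer side (what Push()/GetTdtWingManQueue() do): look the queue up by channel name.
  if (MiniQueueMgr::GetInstance().IsCreated("queue_0")) {
    auto queue = MiniQueueMgr::GetInstance().GetDataQueue("queue_0");
    queue->shapes.push_back({32, 3, 224, 224});  // record one batch's host-side shape
  }
  return 0;
}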
@@ -77,8 +77,8 @@ DataQueueOp::DataQueueOp(const std::string channel_name, DeviceType device_type,
 #ifdef WITH_BACKEND
   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
   if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
-    (void)device::DataQueueMgr::GetInstance().Create(channel_name, {}, 0);
-    ascend_data_queue_ = device::DataQueueMgr::GetInstance().GetDataQueue(channel_name)->Queue();
+    ascend_data_queue_ =
+      device::DataQueueMgr::GetInstance().CreateDataQueue(kAscendDevice, channel_name, dynamic_shape_, 0, {});
   }
 #endif
 #ifdef ENABLE_DUMP_IR
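Reading of this hunk: rather than calling Create() and then digging the device queue out with GetDataQueue(channel_name)->Queue(), the dataset op now builds it in one step through CreateDataQueue(), and it forwards dynamic_shape_ so the manager can pick a queue implementation that matches a dynamic-shape pipeline (the tail of CreateDataQueue() appears in a later hunk).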
@@ -200,12 +200,15 @@ AscendTdtQueue::AscendTdtQueue(const std::string &channel_name) : DataQueue(chan
     }
     tdt_handle::AddHandle(&acl_handle_, nullptr);
   }
-  wingman_queue_ = std::make_shared<BlockingQueue>();
+
+  // a wingman of tdt to help with transferring data shapes on host
+  auto wingman_queue = std::make_shared<BlockingQueue>();
   std::shared_ptr<DataQueue> data_queue = std::make_shared<WingmanQueue>(channel_name);
-  auto rt = wingman_queue_->Create(data_queue);
+  auto rt = wingman_queue->Create(data_queue);
   if (rt != DataQueueStatus::SUCCESS) {
     MS_LOG(EXCEPTION) << "Wingman queue: " << channel_name << "create failed: " << rt;
   }
+  DataQueueMgr::GetInstance().Manage(channel_name, wingman_queue);
 }
 
 AscendTdtQueue::~AscendTdtQueue() {
@@ -218,6 +221,9 @@ AscendTdtQueue::~AscendTdtQueue() {
       acl_handle_ = nullptr;
     }
   }
+  if (DataQueueMgr::GetInstance().IsCreated(channel_name_)) {
+    DataQueueMgr::GetInstance().Free(channel_name_);
+  }
 }
 
 bool AscendTdtQueue::IsOpen() const { return !tdt_handle::IsClosed(); }
@@ -257,8 +263,9 @@ DataQueueStatus AscendTdtQueue::Push(std::vector<DataQueueItem> data) {
     MS_LOG(EXCEPTION) << "Tdt Send data failed. The details refer to 'Ascend Error Message'."
                       << ascend::GetErrorMessage(true);
   }
-  if (wingman_queue_->IsOpen() && !data.empty()) {
-    wingman_queue_->Push(data);
+  auto wingman = DataQueueMgr::GetInstance().GetDataQueue(channel_name_);
+  if (wingman != nullptr && wingman->IsOpen() && !data.empty()) {
+    wingman->Push(data);
   }
   return DataQueueStatus::SUCCESS;
 }
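Note on the Push() change: since AscendTdtQueue no longer owns a wingman_queue_ member (it is removed in the header hunk further down), the wingman is fetched from DataQueueMgr by channel name on every push, and the nullptr check guards the case where no wingman was registered. Mirroring the DataQueueItem metadata into this host-side queue is what lets a dynamic GetNext learn each batch's shapes without a device round trip, per the "transferring data shapes on host" comment added above.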
@@ -649,14 +656,7 @@ void WingmanQueue::Close() {
 std::shared_ptr<BlockingQueue> GetTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
   if (common::AnfAlgo::GetCNodeName(node) != kGetNextOpName) return nullptr;
   auto queue_name = common::AnfAlgo::GetNodeAttr<std::string>(node, "shared_name");
-  auto data_queue = DataQueueMgr::GetInstance().GetDataQueue(queue_name);
-  if (data_queue) {
-    auto tdt_queue = dynamic_cast<device::AscendTdtQueue *>(data_queue->Queue().get());
-    if (tdt_queue) {
-      return tdt_queue->GetWingMan();
-    }
-  }
-  return nullptr;
+  return DataQueueMgr::GetInstance().GetDataQueue(queue_name);
 }
 
 void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
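With the wingman registered under the channel name in DataQueueMgr, GetTdtWingManQueue() no longer needs to dynamic_cast the stored queue to AscendTdtQueue and call GetWingMan(); it simply returns whatever BlockingQueue the manager holds for the GetNext node's "shared_name" attribute, or nullptr if nothing was registered.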
@@ -670,10 +670,6 @@ void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
 namespace {
 std::shared_ptr<DataQueue> CreateAscendDataQueue(const std::string &channel_name, bool dynamic_shape, size_t capacity,
                                                  const std::vector<size_t> &) {
-  if (dynamic_shape) {
-    return std::make_shared<AscendDataQueueDynamic>(channel_name, capacity);
-  }
-
   int32_t is_heterogeneous = 0;
   (void)rtGetIsHeterogenous(&is_heterogeneous);
   if (is_heterogeneous != 0) {
@@ -83,7 +83,6 @@ class AscendTdtQueue : public DataQueue {
   DataQueueStatus Push(std::vector<DataQueueItem> data) override;
   DataQueueStatus Front(std::vector<DataQueueItem> *data) const override { return DataQueueStatus::SUCCESS; }
   DataQueueStatus Pop() override { return DataQueueStatus::SUCCESS; }
-  std::shared_ptr<BlockingQueue> GetWingMan() { return wingman_queue_; }
 
  private:
   void DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true) const;
@@ -91,7 +90,6 @@ class AscendTdtQueue : public DataQueue {
   void ParseType(aclDataType acl_data_type, std::string *data_type) const;
   bool Translate(const std::vector<DataQueueItem> &data, acltdtDataset **output_acl_dataset) const;
 
-  std::shared_ptr<BlockingQueue> wingman_queue_;
   acltdtChannelHandle *acl_handle_;
   uint32_t device_id_;
 };
@@ -191,7 +191,8 @@ void DynamicAicpuOpKernelMod::SyncData() {
     MS_LOG(EXCEPTION) << "The cnode is not dynamic shape:" << cnode->fullname_with_scope();
   }
 
-  if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE) {
+  if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE ||
+      common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
     MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " update op skip.";
     return;
   }
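A likely reading of this hunk (the rest of SyncData() is not shown here): the post-launch shape update is skipped not only for DEPEND_COMPUTE kernels but now also for GetNext, whose output shapes are instead supplied on the host through the wingman queue introduced above.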
@@ -49,6 +49,10 @@ std::shared_ptr<DataQueue> DataQueueMgr::CreateDataQueue(const std::string &devi
   return iter->second(channel_name, dynamic_shape, capacity, shape);
 }
 
+void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue) {
+  (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
+}
+
 DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector<size_t> &shape,
                                      const size_t capacity) {
   MS_LOG(INFO) << "Data queue: " << channel_name << " created";
@@ -181,7 +185,7 @@ bool DataQueueMgr::IsInit() const { return init_; }
 
 bool DataQueueMgr::IsClosed() const { return closed_; }
 
-inline bool DataQueueMgr::isCreated(const std::string &channel_name) const {
+bool DataQueueMgr::IsCreated(const std::string &channel_name) const {
   return name_queue_map_.find(channel_name) != name_queue_map_.end();
 }
 
@@ -133,6 +133,9 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
     if network.get_inputs() and None not in network.get_inputs():
         _check_inputs(network, dataset_shapes, dataset_types)
         min_shapes, max_shapes = None, None
+    elif context.get_context("mode") == context.PYNATIVE_MODE:
+        dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
+        min_shapes, max_shapes = None, None
     network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
                                               queue_name, min_shapes, max_shapes)
     return network
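On the Python side, PyNative mode now generates the dataset-sink network with every dataset shape replaced by (-2,); in MindSpore's shape convention a -2 entry denotes a fully dynamic (unknown-rank) shape, so the generated GetNext is dynamic and the concrete shapes are resolved at run time, which is what the wingman-queue changes above supply.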