!42676 use dynamic GetNext in Pynative mode

Merge pull request !42676 from wYann/data_queue
This commit is contained in:
i-robot 2022-10-10 12:10:07 +00:00 committed by Gitee
commit 326279350e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 28 additions and 24 deletions

View File

@ -105,6 +105,8 @@ class BACKEND_EXPORT DataQueueMgr {
bool IsClosed() const;
bool IsCreated(const std::string &channel_name) const;
bool Destroy();
// Call this to release GPU resources.
@ -117,9 +119,9 @@ class BACKEND_EXPORT DataQueueMgr {
size_t Capacity(const std::string &channel_name);
private:
inline bool isCreated(const std::string &channel_name) const;
void Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue);
private:
DataQueueMgr(const DataQueueMgr &) = delete;
DataQueueMgr &operator=(const DataQueueMgr &) = delete;

View File

@ -77,8 +77,8 @@ DataQueueOp::DataQueueOp(const std::string channel_name, DeviceType device_type,
#ifdef WITH_BACKEND
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
(void)device::DataQueueMgr::GetInstance().Create(channel_name, {}, 0);
ascend_data_queue_ = device::DataQueueMgr::GetInstance().GetDataQueue(channel_name)->Queue();
ascend_data_queue_ =
device::DataQueueMgr::GetInstance().CreateDataQueue(kAscendDevice, channel_name, dynamic_shape_, 0, {});
}
#endif
#ifdef ENABLE_DUMP_IR

View File

@ -200,12 +200,15 @@ AscendTdtQueue::AscendTdtQueue(const std::string &channel_name) : DataQueue(chan
}
tdt_handle::AddHandle(&acl_handle_, nullptr);
}
wingman_queue_ = std::make_shared<BlockingQueue>();
// a wingman of tdt to help with transferring data shapes on host
auto wingman_queue = std::make_shared<BlockingQueue>();
std::shared_ptr<DataQueue> data_queue = std::make_shared<WingmanQueue>(channel_name);
auto rt = wingman_queue_->Create(data_queue);
auto rt = wingman_queue->Create(data_queue);
if (rt != DataQueueStatus::SUCCESS) {
MS_LOG(EXCEPTION) << "Wingman queue: " << channel_name << "create failed: " << rt;
}
DataQueueMgr::GetInstance().Manage(channel_name, wingman_queue);
}
AscendTdtQueue::~AscendTdtQueue() {
@ -218,6 +221,9 @@ AscendTdtQueue::~AscendTdtQueue() {
acl_handle_ = nullptr;
}
}
if (DataQueueMgr::GetInstance().IsCreated(channel_name_)) {
DataQueueMgr::GetInstance().Free(channel_name_);
}
}
// The queue is reported as open for as long as tdt_handle::IsClosed() is false.
bool AscendTdtQueue::IsOpen() const {
  if (tdt_handle::IsClosed()) {
    return false;
  }
  return true;
}
@ -257,8 +263,9 @@ DataQueueStatus AscendTdtQueue::Push(std::vector<DataQueueItem> data) {
MS_LOG(EXCEPTION) << "Tdt Send data failed. The details refer to 'Ascend Error Message'."
<< ascend::GetErrorMessage(true);
}
if (wingman_queue_->IsOpen() && !data.empty()) {
wingman_queue_->Push(data);
auto wingman = DataQueueMgr::GetInstance().GetDataQueue(channel_name_);
if (wingman != nullptr && wingman->IsOpen() && !data.empty()) {
wingman->Push(data);
}
return DataQueueStatus::SUCCESS;
}
@ -649,14 +656,7 @@ void WingmanQueue::Close() {
std::shared_ptr<BlockingQueue> GetTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
if (common::AnfAlgo::GetCNodeName(node) != kGetNextOpName) return nullptr;
auto queue_name = common::AnfAlgo::GetNodeAttr<std::string>(node, "shared_name");
auto data_queue = DataQueueMgr::GetInstance().GetDataQueue(queue_name);
if (data_queue) {
auto tdt_queue = dynamic_cast<device::AscendTdtQueue *>(data_queue->Queue().get());
if (tdt_queue) {
return tdt_queue->GetWingMan();
}
}
return nullptr;
return DataQueueMgr::GetInstance().GetDataQueue(queue_name);
}
void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
@ -670,10 +670,6 @@ void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
namespace {
std::shared_ptr<DataQueue> CreateAscendDataQueue(const std::string &channel_name, bool dynamic_shape, size_t capacity,
const std::vector<size_t> &) {
if (dynamic_shape) {
return std::make_shared<AscendDataQueueDynamic>(channel_name, capacity);
}
int32_t is_heterogeneous = 0;
(void)rtGetIsHeterogenous(&is_heterogeneous);
if (is_heterogeneous != 0) {

View File

@ -83,7 +83,6 @@ class AscendTdtQueue : public DataQueue {
DataQueueStatus Push(std::vector<DataQueueItem> data) override;
DataQueueStatus Front(std::vector<DataQueueItem> *data) const override { return DataQueueStatus::SUCCESS; }
DataQueueStatus Pop() override { return DataQueueStatus::SUCCESS; }
std::shared_ptr<BlockingQueue> GetWingMan() { return wingman_queue_; }
private:
void DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true) const;
@ -91,7 +90,6 @@ class AscendTdtQueue : public DataQueue {
void ParseType(aclDataType acl_data_type, std::string *data_type) const;
bool Translate(const std::vector<DataQueueItem> &data, acltdtDataset **output_acl_dataset) const;
std::shared_ptr<BlockingQueue> wingman_queue_;
acltdtChannelHandle *acl_handle_;
uint32_t device_id_;
};

View File

@ -191,7 +191,8 @@ void DynamicAicpuOpKernelMod::SyncData() {
MS_LOG(EXCEPTION) << "The cnode is not dynamic shape:" << cnode->fullname_with_scope();
}
if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE) {
if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE ||
common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " update op skip.";
return;
}

View File

@ -49,6 +49,10 @@ std::shared_ptr<DataQueue> DataQueueMgr::CreateDataQueue(const std::string &devi
return iter->second(channel_name, dynamic_shape, capacity, shape);
}
// Registers `queue` under `channel_name` in name_queue_map_.
// The outcome of the insertion is deliberately discarded, so the caller is not
// told whether a new entry was actually added.
// NOTE(review): if name_queue_map_ has unique keys, an entry that already exists
// for `channel_name` is left untouched -- confirm against the member declaration.
void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue) {
  (void)name_queue_map_.emplace(channel_name, queue);
}
DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector<size_t> &shape,
const size_t capacity) {
MS_LOG(INFO) << "Data queue: " << channel_name << " created";
@ -181,7 +185,7 @@ bool DataQueueMgr::IsInit() const { return init_; }
bool DataQueueMgr::IsClosed() const { return closed_; }
inline bool DataQueueMgr::isCreated(const std::string &channel_name) const {
// Returns true when name_queue_map_ holds an entry keyed by `channel_name`.
bool DataQueueMgr::IsCreated(const std::string &channel_name) const {
  const bool registered = name_queue_map_.count(channel_name) != 0;
  return registered;
}

View File

@ -133,6 +133,9 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
if network.get_inputs() and None not in network.get_inputs():
_check_inputs(network, dataset_shapes, dataset_types)
min_shapes, max_shapes = None, None
elif context.get_context("mode") == context.PYNATIVE_MODE:
dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
min_shapes, max_shapes = None, None
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
queue_name, min_shapes, max_shapes)
return network