!42676 use dynamic GetNext in Pynative mode
Merge pull request !42676 from wYann/data_queue
This commit is contained in:
commit
326279350e
|
@ -105,6 +105,8 @@ class BACKEND_EXPORT DataQueueMgr {
|
|||
|
||||
bool IsClosed() const;
|
||||
|
||||
bool IsCreated(const std::string &channel_name) const;
|
||||
|
||||
bool Destroy();
|
||||
|
||||
// call for Release GPU Resources
|
||||
|
@ -117,9 +119,9 @@ class BACKEND_EXPORT DataQueueMgr {
|
|||
|
||||
size_t Capacity(const std::string &channel_name);
|
||||
|
||||
private:
|
||||
inline bool isCreated(const std::string &channel_name) const;
|
||||
void Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue);
|
||||
|
||||
private:
|
||||
DataQueueMgr(const DataQueueMgr &) = delete;
|
||||
DataQueueMgr &operator=(const DataQueueMgr &) = delete;
|
||||
|
||||
|
|
|
@ -77,8 +77,8 @@ DataQueueOp::DataQueueOp(const std::string channel_name, DeviceType device_type,
|
|||
#ifdef WITH_BACKEND
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
(void)device::DataQueueMgr::GetInstance().Create(channel_name, {}, 0);
|
||||
ascend_data_queue_ = device::DataQueueMgr::GetInstance().GetDataQueue(channel_name)->Queue();
|
||||
ascend_data_queue_ =
|
||||
device::DataQueueMgr::GetInstance().CreateDataQueue(kAscendDevice, channel_name, dynamic_shape_, 0, {});
|
||||
}
|
||||
#endif
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
|
|
@ -200,12 +200,15 @@ AscendTdtQueue::AscendTdtQueue(const std::string &channel_name) : DataQueue(chan
|
|||
}
|
||||
tdt_handle::AddHandle(&acl_handle_, nullptr);
|
||||
}
|
||||
wingman_queue_ = std::make_shared<BlockingQueue>();
|
||||
|
||||
// a wingman of tdt to help with transferring data shapes on host
|
||||
auto wingman_queue = std::make_shared<BlockingQueue>();
|
||||
std::shared_ptr<DataQueue> data_queue = std::make_shared<WingmanQueue>(channel_name);
|
||||
auto rt = wingman_queue_->Create(data_queue);
|
||||
auto rt = wingman_queue->Create(data_queue);
|
||||
if (rt != DataQueueStatus::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Wingman queue: " << channel_name << "create failed: " << rt;
|
||||
}
|
||||
DataQueueMgr::GetInstance().Manage(channel_name, wingman_queue);
|
||||
}
|
||||
|
||||
AscendTdtQueue::~AscendTdtQueue() {
|
||||
|
@ -218,6 +221,9 @@ AscendTdtQueue::~AscendTdtQueue() {
|
|||
acl_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (DataQueueMgr::GetInstance().IsCreated(channel_name_)) {
|
||||
DataQueueMgr::GetInstance().Free(channel_name_);
|
||||
}
|
||||
}
|
||||
|
||||
bool AscendTdtQueue::IsOpen() const { return !tdt_handle::IsClosed(); }
|
||||
|
@ -257,8 +263,9 @@ DataQueueStatus AscendTdtQueue::Push(std::vector<DataQueueItem> data) {
|
|||
MS_LOG(EXCEPTION) << "Tdt Send data failed. The details refer to 'Ascend Error Message'."
|
||||
<< ascend::GetErrorMessage(true);
|
||||
}
|
||||
if (wingman_queue_->IsOpen() && !data.empty()) {
|
||||
wingman_queue_->Push(data);
|
||||
auto wingman = DataQueueMgr::GetInstance().GetDataQueue(channel_name_);
|
||||
if (wingman != nullptr && wingman->IsOpen() && !data.empty()) {
|
||||
wingman->Push(data);
|
||||
}
|
||||
return DataQueueStatus::SUCCESS;
|
||||
}
|
||||
|
@ -649,14 +656,7 @@ void WingmanQueue::Close() {
|
|||
std::shared_ptr<BlockingQueue> GetTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
|
||||
if (common::AnfAlgo::GetCNodeName(node) != kGetNextOpName) return nullptr;
|
||||
auto queue_name = common::AnfAlgo::GetNodeAttr<std::string>(node, "shared_name");
|
||||
auto data_queue = DataQueueMgr::GetInstance().GetDataQueue(queue_name);
|
||||
if (data_queue) {
|
||||
auto tdt_queue = dynamic_cast<device::AscendTdtQueue *>(data_queue->Queue().get());
|
||||
if (tdt_queue) {
|
||||
return tdt_queue->GetWingMan();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
return DataQueueMgr::GetInstance().GetDataQueue(queue_name);
|
||||
}
|
||||
|
||||
void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
|
||||
|
@ -670,10 +670,6 @@ void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node) {
|
|||
namespace {
|
||||
std::shared_ptr<DataQueue> CreateAscendDataQueue(const std::string &channel_name, bool dynamic_shape, size_t capacity,
|
||||
const std::vector<size_t> &) {
|
||||
if (dynamic_shape) {
|
||||
return std::make_shared<AscendDataQueueDynamic>(channel_name, capacity);
|
||||
}
|
||||
|
||||
int32_t is_heterogeneous = 0;
|
||||
(void)rtGetIsHeterogenous(&is_heterogeneous);
|
||||
if (is_heterogeneous != 0) {
|
||||
|
|
|
@ -83,7 +83,6 @@ class AscendTdtQueue : public DataQueue {
|
|||
DataQueueStatus Push(std::vector<DataQueueItem> data) override;
|
||||
DataQueueStatus Front(std::vector<DataQueueItem> *data) const override { return DataQueueStatus::SUCCESS; }
|
||||
DataQueueStatus Pop() override { return DataQueueStatus::SUCCESS; }
|
||||
std::shared_ptr<BlockingQueue> GetWingMan() { return wingman_queue_; }
|
||||
|
||||
private:
|
||||
void DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true) const;
|
||||
|
@ -91,7 +90,6 @@ class AscendTdtQueue : public DataQueue {
|
|||
void ParseType(aclDataType acl_data_type, std::string *data_type) const;
|
||||
bool Translate(const std::vector<DataQueueItem> &data, acltdtDataset **output_acl_dataset) const;
|
||||
|
||||
std::shared_ptr<BlockingQueue> wingman_queue_;
|
||||
acltdtChannelHandle *acl_handle_;
|
||||
uint32_t device_id_;
|
||||
};
|
||||
|
|
|
@ -191,7 +191,8 @@ void DynamicAicpuOpKernelMod::SyncData() {
|
|||
MS_LOG(EXCEPTION) << "The cnode is not dynamic shape:" << cnode->fullname_with_scope();
|
||||
}
|
||||
|
||||
if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE) {
|
||||
if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE ||
|
||||
common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
|
||||
MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " update op skip.";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -49,6 +49,10 @@ std::shared_ptr<DataQueue> DataQueueMgr::CreateDataQueue(const std::string &devi
|
|||
return iter->second(channel_name, dynamic_shape, capacity, shape);
|
||||
}
|
||||
|
||||
void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue) {
|
||||
(void)name_queue_map_.insert(std::make_pair(channel_name, queue));
|
||||
}
|
||||
|
||||
DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector<size_t> &shape,
|
||||
const size_t capacity) {
|
||||
MS_LOG(INFO) << "Data queue: " << channel_name << " created";
|
||||
|
@ -181,7 +185,7 @@ bool DataQueueMgr::IsInit() const { return init_; }
|
|||
|
||||
bool DataQueueMgr::IsClosed() const { return closed_; }
|
||||
|
||||
inline bool DataQueueMgr::isCreated(const std::string &channel_name) const {
|
||||
bool DataQueueMgr::IsCreated(const std::string &channel_name) const {
|
||||
return name_queue_map_.find(channel_name) != name_queue_map_.end();
|
||||
}
|
||||
|
||||
|
|
|
@ -133,6 +133,9 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
|
|||
if network.get_inputs() and None not in network.get_inputs():
|
||||
_check_inputs(network, dataset_shapes, dataset_types)
|
||||
min_shapes, max_shapes = None, None
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
|
||||
min_shapes, max_shapes = None, None
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
|
||||
queue_name, min_shapes, max_shapes)
|
||||
return network
|
||||
|
|
Loading…
Reference in New Issue