diff --git a/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h b/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h index 143c371c5ef..faebdf47bfa 100644 --- a/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h +++ b/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h @@ -105,6 +105,8 @@ class BACKEND_EXPORT DataQueueMgr { bool IsClosed() const; + bool IsCreated(const std::string &channel_name) const; + bool Destroy(); // call for Release GPU Resources @@ -117,9 +119,9 @@ class BACKEND_EXPORT DataQueueMgr { size_t Capacity(const std::string &channel_name); - private: - inline bool isCreated(const std::string &channel_name) const; + void Manage(const std::string &channel_name, const std::shared_ptr &queue); + private: DataQueueMgr(const DataQueueMgr &) = delete; DataQueueMgr &operator=(const DataQueueMgr &) = delete; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc index 660980e5328..7e6160a8a07 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc @@ -77,8 +77,8 @@ DataQueueOp::DataQueueOp(const std::string channel_name, DeviceType device_type, #ifdef WITH_BACKEND MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); if (MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { - (void)device::DataQueueMgr::GetInstance().Create(channel_name, {}, 0); - ascend_data_queue_ = device::DataQueueMgr::GetInstance().GetDataQueue(channel_name)->Queue(); + ascend_data_queue_ = + device::DataQueueMgr::GetInstance().CreateDataQueue(kAscendDevice, channel_name, dynamic_shape_, 0, {}); } #endif #ifdef ENABLE_DUMP_IR diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.cc index 28037d22795..26a1768e1d6 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.cc @@ -200,12 +200,15 @@ AscendTdtQueue::AscendTdtQueue(const std::string &channel_name) : DataQueue(chan } tdt_handle::AddHandle(&acl_handle_, nullptr); } - wingman_queue_ = std::make_shared(); + + // a wingman of tdt to help with transferring data shapes on host + auto wingman_queue = std::make_shared(); std::shared_ptr data_queue = std::make_shared(channel_name); - auto rt = wingman_queue_->Create(data_queue); + auto rt = wingman_queue->Create(data_queue); if (rt != DataQueueStatus::SUCCESS) { MS_LOG(EXCEPTION) << "Wingman queue: " << channel_name << "create failed: " << rt; } + DataQueueMgr::GetInstance().Manage(channel_name, wingman_queue); } AscendTdtQueue::~AscendTdtQueue() { @@ -218,6 +221,9 @@ AscendTdtQueue::~AscendTdtQueue() { acl_handle_ = nullptr; } } + if (DataQueueMgr::GetInstance().IsCreated(channel_name_)) { + DataQueueMgr::GetInstance().Free(channel_name_); + } } bool AscendTdtQueue::IsOpen() const { return !tdt_handle::IsClosed(); } @@ -257,8 +263,9 @@ DataQueueStatus AscendTdtQueue::Push(std::vector data) { MS_LOG(EXCEPTION) << "Tdt Send data failed. The details refer to 'Ascend Error Message'." << ascend::GetErrorMessage(true); } - if (wingman_queue_->IsOpen() && !data.empty()) { - wingman_queue_->Push(data); + auto wingman = DataQueueMgr::GetInstance().GetDataQueue(channel_name_); + if (wingman != nullptr && wingman->IsOpen() && !data.empty()) { + wingman->Push(data); } return DataQueueStatus::SUCCESS; } @@ -649,14 +656,7 @@ void WingmanQueue::Close() { std::shared_ptr GetTdtWingManQueue(const std::shared_ptr &node) { if (common::AnfAlgo::GetCNodeName(node) != kGetNextOpName) return nullptr; auto queue_name = common::AnfAlgo::GetNodeAttr(node, "shared_name"); - auto data_queue = DataQueueMgr::GetInstance().GetDataQueue(queue_name); - if (data_queue) { - auto tdt_queue = dynamic_cast(data_queue->Queue().get()); - if (tdt_queue) { - return tdt_queue->GetWingMan(); - } - } - return nullptr; + return DataQueueMgr::GetInstance().GetDataQueue(queue_name); } void CloseTdtWingManQueue(const std::shared_ptr &node) { @@ -670,10 +670,6 @@ void CloseTdtWingManQueue(const std::shared_ptr &node) { namespace { std::shared_ptr CreateAscendDataQueue(const std::string &channel_name, bool dynamic_shape, size_t capacity, const std::vector &) { - if (dynamic_shape) { - return std::make_shared(channel_name, capacity); - } - int32_t is_heterogeneous = 0; (void)rtGetIsHeterogenous(&is_heterogeneous); if (is_heterogeneous != 0) { diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.h index e9c56bd58ee..a002d29ed68 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_data_queue.h @@ -83,7 +83,6 @@ class AscendTdtQueue : public DataQueue { DataQueueStatus Push(std::vector data) override; DataQueueStatus Front(std::vector *data) const override { return DataQueueStatus::SUCCESS; } DataQueueStatus Pop() override { return DataQueueStatus::SUCCESS; } - std::shared_ptr GetWingMan() { return wingman_queue_; } private: void DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true) const; @@ -91,7 +90,6 @@ class AscendTdtQueue : public DataQueue { void ParseType(aclDataType acl_data_type, std::string *data_type) const; bool Translate(const std::vector &data, acltdtDataset **output_acl_dataset) const; - std::shared_ptr wingman_queue_; acltdtChannelHandle *acl_handle_; uint32_t device_id_; }; diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/dynamic_aicpu_kernel_mod.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/dynamic_aicpu_kernel_mod.cc index 31670e17ce7..c943a79e70d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/dynamic_aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/dynamic_aicpu_kernel_mod.cc @@ -191,7 +191,8 @@ void DynamicAicpuOpKernelMod::SyncData() { MS_LOG(EXCEPTION) << "The cnode is not dynamic shape:" << cnode->fullname_with_scope(); } - if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE) { + if (unknow_type_ != device::ascend::UnknowShapeOpType::DEPEND_COMPUTE || + common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " update op skip."; return; } diff --git a/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc b/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc index 4e0831cf36b..944c7a433ca 100644 --- a/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc +++ b/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc @@ -49,6 +49,10 @@ std::shared_ptr DataQueueMgr::CreateDataQueue(const std::string &devi return iter->second(channel_name, dynamic_shape, capacity, shape); } +void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr &queue) { + (void)name_queue_map_.insert(std::make_pair(channel_name, queue)); +} + DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector &shape, const size_t capacity) { MS_LOG(INFO) << "Data queue: " << channel_name << " created"; @@ -181,7 +185,7 @@ bool DataQueueMgr::IsInit() const { return init_; } bool DataQueueMgr::IsClosed() const { return closed_; } -inline bool DataQueueMgr::isCreated(const std::string &channel_name) const { +bool DataQueueMgr::IsCreated(const std::string &channel_name) const { return name_queue_map_.find(channel_name) != name_queue_map_.end(); } diff --git a/mindspore/python/mindspore/train/dataset_helper.py b/mindspore/python/mindspore/train/dataset_helper.py index 82bfddab4aa..f7f7f60afeb 100644 --- a/mindspore/python/mindspore/train/dataset_helper.py +++ b/mindspore/python/mindspore/train/dataset_helper.py @@ -133,6 +133,9 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name): if network.get_inputs() and None not in network.get_inputs(): _check_inputs(network, dataset_shapes, dataset_types) min_shapes, max_shapes = None, None + elif context.get_context("mode") == context.PYNATIVE_MODE: + dataset_shapes = tuple([(-2,)] * len(dataset_shapes)) + min_shapes, max_shapes = None, None network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name, min_shapes, max_shapes) return network