forked from mindspore-Ecosystem/mindspore
!1480 gpu iterator weak ref support
Merge pull request !1480 from panfengfeng/iterator_gpu_weak_ref
This commit is contained in:
commit
e7936dedeb
|
@ -26,10 +26,6 @@
|
|||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "tdt/tsd_client.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
|
||||
|
@ -167,9 +163,15 @@ Status DeviceQueueOp::SendDataToGPU() {
|
|||
is_break_loop = true;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
if (!TaskManager::FindMe()->Interrupted())
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
else
|
||||
is_break_loop = true;
|
||||
}
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
if (!TaskManager::FindMe()->Interrupted())
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
else
|
||||
is_break_loop = true;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
|
||||
|
@ -191,7 +193,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
|
|||
items.push_back(data_item);
|
||||
}
|
||||
|
||||
while (!GpuBufferMgr::GetInstance().IsClosed()) {
|
||||
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
|
||||
RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row));
|
||||
auto ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
|
||||
if (ret) {
|
||||
|
|
|
@ -172,9 +172,7 @@ bool GpuBufferMgr::CloseNotify() {
|
|||
{
|
||||
std::lock_guard<std::mutex> lk(close_mutex_);
|
||||
// set closed_ to be true, all the dataset retry can be jumped out of the while
|
||||
closed_ = true; // set closed_ to be true, all the dataset retry can be jumped out of the while
|
||||
// notify all the waiting dataset threads
|
||||
close_confirm_cond_.notify_all(); // notify all the waiting dataset threads
|
||||
closed_ = true;
|
||||
}
|
||||
|
||||
// wait for the dataset threads' ack
|
||||
|
@ -188,16 +186,6 @@ bool GpuBufferMgr::CloseNotify() {
|
|||
return result;
|
||||
}
|
||||
|
||||
void GpuBufferMgr::CloseConfirm() {
|
||||
// lock scope
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(close_mutex_);
|
||||
// dataset threads wait for the closed_ flag from false to true
|
||||
close_confirm_cond_.wait(
|
||||
lk, [this] { return closed_; }); // dataset threads wait for the closed_ flag from false to true
|
||||
}
|
||||
|
||||
sema.Signal();
|
||||
}
|
||||
void GpuBufferMgr::CloseConfirm() { sema.Signal(); }
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -119,7 +119,6 @@ class GpuBufferMgr {
|
|||
bool closed_;
|
||||
std::mutex mutex_;
|
||||
std::mutex close_mutex_;
|
||||
std::condition_variable close_confirm_cond_;
|
||||
// how many queues opened by dataset
|
||||
int open_by_dataset_;
|
||||
Semaphore sema;
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
from importlib import import_module
|
||||
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
|
@ -25,10 +24,6 @@ from mindspore._c_dataengine import OpName
|
|||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
|
||||
try:
|
||||
context = import_module("mindspore.context")
|
||||
except ModuleNotFoundError:
|
||||
context = None
|
||||
|
||||
ITERATORS_LIST = list()
|
||||
|
||||
|
@ -36,18 +31,9 @@ ITERATORS_LIST = list()
|
|||
def _cleanup():
|
||||
"""Release all the Iterator."""
|
||||
for itr_ref in ITERATORS_LIST:
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
itr_ref.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
|
@ -101,14 +87,7 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
ITERATORS_LIST.append(self)
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
|
|
Loading…
Reference in New Issue