diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7a8e261d9d..53a9fc73c1 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3256,14 +3256,13 @@ class SamplerFn: # Event for end of epoch if multi_process is True: self.eoe = multiprocessing.Event() - self.eof = multiprocessing.Event() else: self.eoe = threading.Event() self.eof = threading.Event() # Create workers for _ in range(num_worker): if multi_process is True: - worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) + worker = _GeneratorWorkerMp(dataset, self.eoe) else: worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) worker.daemon = True @@ -3304,15 +3303,40 @@ class SamplerFn: def __del__(self): self.eoe.set() - self.eof.set() if self.multi_process is False: + self.eof.set() for w in self.workers: w.join() -def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): +def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): """ - Multiprocessing or multithread generator worker process loop. + Multiprocessing generator worker process loop + """ + while True: + # Fetch index, block + try: + idx = idx_queue.get() + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + if idx is None: + # When the queue is out of scope from master process, a None item can be fetched from the queue. + # Upon receiving None, worker process should check if EOE is set. + assert eoe.is_set(), "" + return + # Fetch data, any exception from __getitem__ will terminate worker and timeout master process + result = dataset[idx] + # Send data, block + try: + result_queue.put(result) + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + del result, idx + + +def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): + """ + Multithread generator worker process loop. """ while True: # Fetch index, block @@ -3360,7 +3384,7 @@ class _GeneratorWorkerMt(threading.Thread): def __init__(self, dataset, eoe, eof): self.idx_queue = queue.Queue(16) self.res_queue = queue.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) def put(self, item): """ @@ -3380,10 +3404,10 @@ class _GeneratorWorkerMp(multiprocessing.Process): Worker process for multiprocess Generator. """ - def __init__(self, dataset, eoe, eof): + def __init__(self, dataset, eoe): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) def put(self, item): """