Add multiprocessing support for Mindspore.Dataset.GeneratorDataset

This commit is contained in:
Junhan Hu 2020-04-20 17:26:23 -04:00
parent fb18671b28
commit 78001ac9e6
2 changed files with 245 additions and 2 deletions

View File

@ -25,6 +25,7 @@ import os
import random
import uuid
import multiprocessing
import queue
from enum import Enum
from importlib import import_module
@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x) for x in val])
def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler
"""
indices = sampler.get_indices()
return _sampler_fn_mp(indices, dataset, num_worker)
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
"""
Multiprocessing generator function wrapper for mappable dataset with python sampler
"""
indices = _fetch_py_sampler_indices(sampler, num_samples)
return _sampler_fn_mp(indices, dataset, num_worker)
def _fetch_py_sampler_indices(sampler, num_samples):
"""
Indices fetcher for python sampler
"""
if num_samples is not None:
sampler_iter = iter(sampler)
ret = []
for _ in range(num_samples):
try:
val = next(sampler_iter)
ret.append(val)
except StopIteration:
break
return ret
return [i for i in sampler]
def _fill_worker_indices(workers, indices, idx):
"""
Worker index queue filler, fill worker index queue in round robin order
"""
num_worker = len(workers)
while idx < len(indices):
try:
workers[idx % num_worker].put(indices[idx])
idx += 1
except queue.Full:
break
return idx
def _sampler_fn_mp(indices, dataset, num_worker):
"""
Multiprocessing generator function wrapper master process
"""
workers = []
# Event for end of epoch
eoe = multiprocessing.Event()
# Create workers
for _ in range(num_worker):
worker = _GeneratorWorker(dataset, eoe)
worker.daemon = True
workers.append(worker)
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
# Start all workers
for w in workers:
w.start()
# Fetch results
for i in range(len(indices)):
# Fetch result and put index
try:
result = workers[i % num_worker].get()
except queue.Empty:
raise Exception("Generator worker process timeout")
except KeyboardInterrupt:
for w in workers:
w.terminate()
w.join()
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not eoe.is_set():
eoe.set()
yield tuple([np.array(x) for x in result])
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
"""
Multiprocessing generator worker process loop
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx
class _GeneratorWorker(multiprocessing.Process):
"""
Worker process for multiprocess Generator
"""
def __init__(self, dataset, eoe):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe))
def put(self, item):
"""
Put function for worker index queue. Never block. Raise queue.Full on failure.
"""
self.idx_queue.put_nowait(item)
def get(self):
"""
Get function for worker result queue. Block with timeout.
"""
return self.res_queue.get(timeout=5)
class GeneratorDataset(SourceDataset):
"""
A source dataset that generate data from python by invoking python data source each epoch.
@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset):
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all images).
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
@ -2229,7 +2367,13 @@ class GeneratorDataset(SourceDataset):
sampler_instance.set_num_rows(len(source))
sampler_instance.set_num_samples(num_samples)
sampler_instance.initialize()
if num_parallel_workers > 1:
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
else:
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
else:
if num_parallel_workers > 1:
self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
else:
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
else:

View File

@ -391,6 +391,80 @@ def test_case_13():
i = i + 1
def test_case_14():
"""
Test 1D Generator MP + CPP sampler
"""
logger.info("Test 1D Generator MP : 0 - 63")
source = [(np.array([x]),) for x in range(256)]
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2)
i = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(data["data"], golden)
i = i + 1
if i == 256:
i = 0
def test_case_15():
"""
Test 1D Generator MP + Python sampler
"""
logger.info("Test 1D Generator MP : 0 - 63")
sampler = [x for x in range(256)]
source = [(np.array([x]),) for x in range(256)]
ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2)
i = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(data["data"], golden)
i = i + 1
if i == 256:
i = 0
def test_case_16():
"""
Test multi column generator Mp + CPP sampler
"""
logger.info("Test multi column generator")
source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
# apply dataset operations
data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["col0"], golden)
golden = np.array([i + 1])
assert np.array_equal(item["col1"], golden)
i = i + 1
def test_case_17():
"""
Test multi column generator Mp + Python sampler
"""
logger.info("Test multi column generator")
sampler = [x for x in range(256)]
source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
# apply dataset operations
data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["col0"], golden)
golden = np.array([i + 1])
assert np.array_equal(item["col1"], golden)
i = i + 1
def test_case_error_1():
def generator_np():
for i in range(64):
@ -506,6 +580,25 @@ def test_num_samples_underflow():
count = count + 1
assert count == 64
def manual_test_keyborad_interrupt():
"""
Test keyborad_interrupt
"""
logger.info("Test 1D Generator MP : 0 - 63")
class MyDS():
def __getitem__(self, item):
while True:
pass
def __len__(self):
return 1024
ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
i = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
pass
if __name__ == "__main__":
test_case_0()
@ -522,6 +615,10 @@ if __name__ == "__main__":
test_case_11()
test_case_12()
test_case_13()
test_case_14()
test_case_15()
test_case_16()
test_case_17()
test_case_error_1()
test_case_error_2()
test_case_error_3()
@ -529,3 +626,5 @@ if __name__ == "__main__":
test_sequential_sampler()
test_distributed_sampler()
test_random_sampler()