forked from mindspore-Ecosystem/mindspore
Add multiprocessing support for mindspore.dataset.GeneratorDataset
This commit is contained in:
parent
fb18671b28
commit
78001ac9e6
|
@ -25,6 +25,7 @@ import os
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import queue
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset):
|
||||||
yield tuple([np.array(x) for x in val])
|
yield tuple([np.array(x) for x in val])
|
||||||
|
|
||||||
|
|
||||||
|
def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
    """
    # Materialize the full epoch's index list up front; the worker pool then
    # consumes it via the shared master loop.
    sample_ids = sampler.get_indices()
    return _sampler_fn_mp(sample_ids, dataset, num_worker)
|
||||||
|
|
||||||
|
|
||||||
|
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with python sampler.
    """
    # Pull (at most num_samples) indices out of the python sampler first,
    # then hand them to the shared multiprocessing master loop.
    sample_ids = _fetch_py_sampler_indices(sampler, num_samples)
    return _sampler_fn_mp(sample_ids, dataset, num_worker)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_py_sampler_indices(sampler, num_samples):
|
||||||
|
"""
|
||||||
|
Indices fetcher for python sampler
|
||||||
|
"""
|
||||||
|
if num_samples is not None:
|
||||||
|
sampler_iter = iter(sampler)
|
||||||
|
ret = []
|
||||||
|
for _ in range(num_samples):
|
||||||
|
try:
|
||||||
|
val = next(sampler_iter)
|
||||||
|
ret.append(val)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
return ret
|
||||||
|
return [i for i in sampler]
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_worker_indices(workers, indices, idx):
|
||||||
|
"""
|
||||||
|
Worker index queue filler, fill worker index queue in round robin order
|
||||||
|
"""
|
||||||
|
num_worker = len(workers)
|
||||||
|
while idx < len(indices):
|
||||||
|
try:
|
||||||
|
workers[idx % num_worker].put(indices[idx])
|
||||||
|
idx += 1
|
||||||
|
except queue.Full:
|
||||||
|
break
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
def _sampler_fn_mp(indices, dataset, num_worker):
    """
    Multiprocessing generator function wrapper master process.

    Spawns num_worker _GeneratorWorker processes, distributes `indices`
    across their bounded index queues in round-robin order, and yields one
    item per index.  Because index i goes to worker i % num_worker, draining
    the result queues in the same round-robin order preserves the sampler's
    ordering.
    """
    workers = []
    # Event for end of epoch
    eoe = multiprocessing.Event()

    # Create workers
    for _ in range(num_worker):
        worker = _GeneratorWorker(dataset, eoe)
        # Daemonize so stray workers die with the master process.
        worker.daemon = True
        workers.append(worker)

    # Fill initial index queues
    idx_cursor = 0
    idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)

    # Start all workers
    for w in workers:
        w.start()

    # Fetch results
    for i in range(len(indices)):
        # Fetch result and put index
        try:
            # get() blocks with a timeout; a dead/hung worker surfaces here
            # as queue.Empty rather than hanging the master forever.
            result = workers[i % num_worker].get()
        except queue.Empty:
            raise Exception("Generator worker process timeout")
        except KeyboardInterrupt:
            for w in workers:
                w.terminate()
                w.join()
            raise Exception("Generator worker receives KeyboardInterrupt")
        # Top up the index queues with any indices that did not fit earlier.
        if idx_cursor < len(indices):
            idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
        # Set eoe event once all indices are sent
        if idx_cursor == len(indices) and not eoe.is_set():
            eoe.set()
        yield tuple([np.array(x) for x in result])
|
||||||
|
|
||||||
|
|
||||||
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
|
||||||
|
"""
|
||||||
|
Multiprocessing generator worker process loop
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
# Fetch index, block
|
||||||
|
try:
|
||||||
|
idx = idx_queue.get()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||||
|
if idx is None:
|
||||||
|
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
||||||
|
# Upon receiving None, worker process should check if EOE is set.
|
||||||
|
assert eoe.is_set(), ""
|
||||||
|
return
|
||||||
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
||||||
|
result = dataset[idx]
|
||||||
|
# Send data, block
|
||||||
|
try:
|
||||||
|
result_queue.put(result)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||||
|
del result, idx
|
||||||
|
|
||||||
|
|
||||||
|
class _GeneratorWorker(multiprocessing.Process):
    """
    Worker process for multiprocess Generator.

    Owns a bounded index queue (master -> worker) and a bounded result
    queue (worker -> master); the process body is _generator_worker_loop.
    """

    def __init__(self, dataset, eoe):
        # Bounded queues (16 slots each) keep master and worker pacing
        # loosely coupled without unbounded memory growth.
        self.idx_queue = multiprocessing.Queue(16)
        self.res_queue = multiprocessing.Queue(16)
        super().__init__(
            target=_generator_worker_loop,
            args=(dataset, self.idx_queue, self.res_queue, eoe),
        )

    def get(self):
        """
        Get function for worker result queue. Block with timeout.
        """
        return self.res_queue.get(timeout=5)

    def put(self, item):
        """
        Put function for worker index queue. Never block. Raise queue.Full on failure.
        """
        self.idx_queue.put_nowait(item)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorDataset(SourceDataset):
|
class GeneratorDataset(SourceDataset):
|
||||||
"""
|
"""
|
||||||
A source dataset that generate data from python by invoking python data source each epoch.
|
A source dataset that generate data from python by invoking python data source each epoch.
|
||||||
|
@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset):
|
||||||
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
|
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
|
||||||
num_samples (int, optional): The number of samples to be included in the dataset
|
num_samples (int, optional): The number of samples to be included in the dataset
|
||||||
(default=None, all images).
|
(default=None, all images).
|
||||||
|
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
|
||||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
|
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
|
||||||
(default=None, expected order behavior shown in the table).
|
(default=None, expected order behavior shown in the table).
|
||||||
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
|
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
|
||||||
|
@ -2229,7 +2367,13 @@ class GeneratorDataset(SourceDataset):
|
||||||
sampler_instance.set_num_rows(len(source))
|
sampler_instance.set_num_rows(len(source))
|
||||||
sampler_instance.set_num_samples(num_samples)
|
sampler_instance.set_num_samples(num_samples)
|
||||||
sampler_instance.initialize()
|
sampler_instance.initialize()
|
||||||
|
if num_parallel_workers > 1:
|
||||||
|
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
|
||||||
|
else:
|
||||||
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
|
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
|
||||||
|
else:
|
||||||
|
if num_parallel_workers > 1:
|
||||||
|
self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
|
||||||
else:
|
else:
|
||||||
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
|
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -391,6 +391,80 @@ def test_case_13():
|
||||||
i = i + 1
|
i = i + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_14():
    """
    Test 1D Generator MP + CPP sampler
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(),
                              num_parallel_workers=4).repeat(2)
    expected = 0
    for row in ds1.create_dict_iterator():  # each row is a dictionary
        assert np.array_equal(row["data"], np.array([expected]))
        # repeat(2) replays the epoch, so wrap the counter at 256.
        expected = (expected + 1) % 256
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_15():
    """
    Test 1D Generator MP + Python sampler
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    sampler = list(range(256))
    source = [(np.array([x]),) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler,
                              num_parallel_workers=4).repeat(2)
    expected = 0
    for row in ds1.create_dict_iterator():  # each row is a dictionary
        assert np.array_equal(row["data"], np.array([expected]))
        # repeat(2) replays the epoch, so wrap the counter at 256.
        expected = (expected + 1) % 256
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_16():
    """
    Test multi column generator Mp + CPP sampler
    """
    logger.info("Test multi column generator")

    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler())

    for idx, item in enumerate(data1.create_dict_iterator()):  # each item is a dictionary
        assert np.array_equal(item["col0"], np.array([idx]))
        assert np.array_equal(item["col1"], np.array([idx + 1]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_17():
    """
    Test multi column generator Mp + Python sampler
    """
    logger.info("Test multi column generator")

    sampler = list(range(256))
    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)

    for idx, item in enumerate(data1.create_dict_iterator()):  # each item is a dictionary
        assert np.array_equal(item["col0"], np.array([idx]))
        assert np.array_equal(item["col1"], np.array([idx + 1]))
|
||||||
|
|
||||||
|
|
||||||
def test_case_error_1():
|
def test_case_error_1():
|
||||||
def generator_np():
|
def generator_np():
|
||||||
for i in range(64):
|
for i in range(64):
|
||||||
|
@ -506,6 +580,25 @@ def test_num_samples_underflow():
|
||||||
count = count + 1
|
count = count + 1
|
||||||
assert count == 64
|
assert count == 64
|
||||||
|
|
||||||
|
def manual_test_keyborad_interrupt():
    """
    Manual test for keyboard interrupt handling.

    Each worker's __getitem__ spins forever, so iteration never produces a
    row; run this by hand and press Ctrl-C to exercise the
    KeyboardInterrupt paths in the multiprocessing generator.
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    class MyDS():
        def __getitem__(self, item):
            # Deliberate infinite loop: the item is never produced, forcing
            # the tester to interrupt with Ctrl-C.
            while True:
                pass

        def __len__(self):
            return 1024

    ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2)
    # Removed the unused local counter; the loop body intentionally does
    # nothing while waiting for the manual interrupt.
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_case_0()
|
test_case_0()
|
||||||
|
@ -522,6 +615,10 @@ if __name__ == "__main__":
|
||||||
test_case_11()
|
test_case_11()
|
||||||
test_case_12()
|
test_case_12()
|
||||||
test_case_13()
|
test_case_13()
|
||||||
|
test_case_14()
|
||||||
|
test_case_15()
|
||||||
|
test_case_16()
|
||||||
|
test_case_17()
|
||||||
test_case_error_1()
|
test_case_error_1()
|
||||||
test_case_error_2()
|
test_case_error_2()
|
||||||
test_case_error_3()
|
test_case_error_3()
|
||||||
|
@ -529,3 +626,5 @@ if __name__ == "__main__":
|
||||||
test_sequential_sampler()
|
test_sequential_sampler()
|
||||||
test_distributed_sampler()
|
test_distributed_sampler()
|
||||||
test_random_sampler()
|
test_random_sampler()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue