forked from mindspore-Ecosystem/mindspore
Add python multiprocessing support for Mindspore.dataset
parent 822a3160e4
commit b13e7bc31a
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import uuid
+import multiprocessing
 from enum import Enum
 from importlib import import_module
 
@@ -231,7 +232,7 @@ class Dataset:
 
     @check_map
     def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-            num_parallel_workers=None):
+            num_parallel_workers=None, python_multiprocessing=False):
         """
         Applies each operation in operations to this dataset.
 
@@ -270,6 +271,8 @@ class Dataset:
                 same).
             num_parallel_workers (int, optional): Number of threads used to process the dataset in
                 parallel (default=None, the value from the config will be used).
+            python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes.
+                This option could be beneficial if the python operation is computationally heavy (default=False).
 
         Returns:
             MapDataset, dataset after mapping operation.
@@ -383,7 +386,8 @@ class Dataset:
             >>> columns_order = ["mod7", "mod3", "col1"]
             >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
         """
-        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
+        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
+                          python_multiprocessing)
 
     @check_repeat
     def repeat(self, count=None):
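A minimal usage sketch of the new flag, assuming a small in-memory source via GeneratorDataset (the generator, column names, and heavy_pyfunc are illustrative, not taken from this commit):

import numpy as np
import mindspore.dataset as ds

def gen():
    # Tiny in-memory source standing in for real data.
    for i in range(8):
        yield (np.array([i], dtype=np.int64),)

def heavy_pyfunc(x):
    # Stand-in for a CPU-bound python transform worth moving off the main process.
    return x * x

data = ds.GeneratorDataset(gen, ["col0"])
# With python_multiprocessing=True, each of the 4 workers runs heavy_pyfunc in its
# own process, so the GIL no longer serializes the python operation.
data = data.map(input_columns="col0", output_columns="out", operations=heavy_pyfunc,
                num_parallel_workers=4, python_multiprocessing=True)

for row in data.create_dict_iterator():
    print(row["out"])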
@@ -1041,6 +1045,55 @@ class ShuffleDataset(DatasetOp):
         return args
 
 
+# Pyfunc collection for multiprocess pyfunc
+# This global variable will only be used within subprocesses
+_GLOBAL_PYFUNC_LIST = []
+
+
+# Pyfunc worker init function
+# The Python multiprocessing library forbids sending lambda functions through a pipe.
+# This init function allows us to add all python functions to a global collection and then fork afterwards.
+def _pyfunc_worker_init(pyfunc_list):
+    global _GLOBAL_PYFUNC_LIST
+    _GLOBAL_PYFUNC_LIST = pyfunc_list
+
+
+# Pyfunc worker execution function
+# All exceptions will be raised to the main process
+def _pyfunc_worker_exec(index, *args):
+    try:
+        return _GLOBAL_PYFUNC_LIST[index](*args)
+    except KeyboardInterrupt:
+        raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+
+
+# PythonCallable wrapper for multiprocess pyfunc
+class _PythonCallable:
+    """
+    Internal python function wrapper for multiprocessing pyfunc.
+    """
+    def __init__(self, py_callable, idx, pool=None):
+        # Original python callable from user.
+        self.py_callable = py_callable
+        # Process pool created for the current iterator.
+        self.pool = pool
+        # Python callable index into the subprocess _GLOBAL_PYFUNC_LIST
+        self.idx = idx
+
+    def __call__(self, *args):
+        if self.pool is not None:
+            try:
+                # This call sends the tensors along with the Python callable index to the process pool.
+                # It blocks and yields the GIL; the current thread reacquires the GIL once the result returns.
+                return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
+            except KeyboardInterrupt:
+                self.pool.terminate()
+                self.pool.join()
+                raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+        # Invoke the original python callable in the master process in case the pool is gone.
+        return self.py_callable(*args)
+
+
 class MapDataset(DatasetOp):
     """
     The result of applying Map operator to the input Dataset.
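The initializer trick above deserves a standalone illustration: multiprocessing pickles whatever crosses its pipes, and lambdas are not picklable, so the callables are installed once per worker via the pool initializer and afterwards addressed only by index. A minimal sketch of the same pattern using just the standard library (it relies on the fork start method, the Linux default, since under spawn the initargs themselves would need to be picklable):

import multiprocessing

_FUNCS = []

def _worker_init(funcs):
    # Runs once inside each worker; stashes the callables in a module-level global.
    global _FUNCS
    _FUNCS = funcs

def _worker_exec(index, *args):
    # Only a picklable index and the arguments travel through the pipe.
    return _FUNCS[index](*args)

if __name__ == "__main__":
    funcs = [lambda x: x + x, lambda x: x * x]
    with multiprocessing.Pool(processes=2, initializer=_worker_init,
                              initargs=(funcs,)) as pool:
        print(pool.apply(_worker_exec, [0, 21]))  # 42
        print(pool.apply(_worker_exec, [1, 8]))   # 64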
@@ -1060,13 +1113,15 @@ class MapDataset(DatasetOp):
             The argument is mandatory if len(input_columns) != len(output_columns).
         num_parallel_workers (int, optional): Number of workers to process the Dataset
             in parallel (default=None).
+        python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes.
+            This option could be beneficial if the python operation is computationally heavy (default=False).
 
     Raises:
         ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
     """
 
     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                 num_parallel_workers=None):
+                 num_parallel_workers=None, python_multiprocessing=False):
         super().__init__(num_parallel_workers)
         self.input.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
@@ -1087,6 +1142,8 @@ class MapDataset(DatasetOp):
 
         input_dataset.output.append(self)
         self._input_indexs = input_dataset.input_indexs
+        self.python_multiprocessing = python_multiprocessing
+        self.process_pool = None
 
     def get_args(self):
         args = super().get_args()
@@ -1104,6 +1161,40 @@ class MapDataset(DatasetOp):
         """
         return self.input[0].get_dataset_size()
 
+    # Iterator bootstrap will be called on iterator construction.
+    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
+    # This method will create a per-iterator process pool and bind pyfunc execution to the pool.
+    def iterator_bootstrap(self):
+        """
+        Per iterator bootstrap callback.
+        """
+        if self.python_multiprocessing:
+            iter_specific_operations = []
+            callable_list = []
+
+            # Pass #1, look for python callables and build the list
+            for op in self.operations:
+                if callable(op):
+                    callable_list.append(op)
+
+            if callable_list:
+                # Construct the pool with the callable list
+                # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses
+                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
+                                                         initializer=_pyfunc_worker_init,
+                                                         initargs=(callable_list,))
+                # Pass #2, replace python callables with pool-backed wrappers
+                idx = 0
+                for op in self.operations:
+                    if callable(op):
+                        # Wrap the python callable into a _PythonCallable
+                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
+                        idx += 1
+                    else:
+                        # C++ ops remain the same
+                        iter_specific_operations.append(op)
+                self.operations = iter_specific_operations
+
 
 class RepeatDataset(DatasetOp):
     """
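The two-pass structure of iterator_bootstrap can be read in isolation: pass one collects the python callables in pipeline order, pass two swaps each for a pool-backed wrapper while C++ ops pass through untouched. A schematic, self-contained sketch (the string ops and the Wrapper class are placeholders, not MindSpore types):

class Wrapper:
    # Placeholder for _PythonCallable: remembers only the callable's index,
    # which is what actually gets sent to the worker processes.
    def __init__(self, fn, idx, pool):
        self.fn, self.idx, self.pool = fn, idx, pool

ops = ["CppDecodeOp", (lambda x: x + 1), "CppResizeOp", (lambda x: x * 2)]

callable_list = [op for op in ops if callable(op)]  # pass 1: gather pyfuncs in order
pool = "process-pool"                               # stands in for multiprocessing.Pool

rewritten, idx = [], 0
for op in ops:                                      # pass 2: rewrite the pipeline
    if callable(op):
        rewritten.append(Wrapper(op, idx, pool))
        idx += 1
    else:
        rewritten.append(op)                        # C++ ops remain the same

assert [type(op).__name__ for op in rewritten] == ["str", "Wrapper", "str", "Wrapper"]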
@@ -63,6 +63,10 @@ def _alter_node(node):
         return new_shuffle
 
     if isinstance(node, de.MapDataset):
+        if node.python_multiprocessing:
+            # Bootstrap can only be performed on a copy of the original dataset node.
+            # Bootstrapping the original dataset node would make all iterators share the same process pool.
+            node.iterator_bootstrap()
         if node.columns_order is not None:
             # Remove the connection between the parent node and the current node because we are inserting a node.
             if node.output:
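Why bootstrap only a copy: if iterator_bootstrap ran on the original MapDataset node, every iterator created from it afterwards would share one process pool and one set of wrapped operations. A toy sketch of the rule (ToyMapNode is illustrative; the real engine deep-copies the dataset tree before _alter_node runs):

import copy

class ToyMapNode:
    # Illustrative stand-in for MapDataset.
    def __init__(self):
        self.python_multiprocessing = True
        self.process_pool = None

    def iterator_bootstrap(self):
        self.process_pool = object()  # stands in for multiprocessing.Pool

definition = ToyMapNode()

# One private copy per iterator, bootstrapped independently:
it1, it2 = copy.deepcopy(definition), copy.deepcopy(definition)
it1.iterator_bootstrap()
it2.iterator_bootstrap()

assert definition.process_pool is None           # the definition stays pool-free
assert it1.process_pool is not it2.process_pool  # pools are per-iterator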
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -181,6 +182,106 @@ def test_case_6():
         i = i + 4
 
 
+def test_case_7():
+    """
+    Test PyFunc
+    """
+    logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_case_8():
+    """
+    Test PyFunc
+    """
+    logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x, x + y, x + y + 1)")
+
+    col = ["col0", "col1"]
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
+                      operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
+                      python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i, i + 1], [i + 2, i + 3]])
+        assert np.array_equal(item["out0"], golden)
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out1"], golden)
+        golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
+        assert np.array_equal(item["out2"], golden)
+        i = i + 4
+
+
+def test_case_9():
+    """
+    Test PyFunc
+    """
+    logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns="col0", output_columns="out",
+                      operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_pyfunc_exception():
+    logger.info("Test PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4)
+        for _ in data1:
+            pass
+    assert "Pyfunc Throw" in str(info.value)
+
+
+def test_pyfunc_exception_multiprocess():
+    logger.info("Test Multiprocess PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("MP Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4, python_multiprocessing=True)
+        for _ in data1:
+            pass
+    assert "MP Pyfunc Throw" in str(info.value)
+
+
 if __name__ == "__main__":
     test_case_0()
     test_case_1()
@@ -189,3 +290,8 @@ if __name__ == "__main__":
     test_case_4()
     test_case_5()
     test_case_6()
+    test_case_7()
+    test_case_8()
+    test_case_9()
+    test_pyfunc_exception()
+    test_pyfunc_exception_multiprocess()
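The multiprocess exception test works because pool.apply re-raises a worker-side exception in the parent process; that behaviour can be checked with the standard library alone:

import multiprocessing

def boom(_):
    raise ValueError("Pyfunc Throw")

if __name__ == "__main__":
    with multiprocessing.Pool(2) as pool:
        try:
            pool.apply(boom, [0])
        except ValueError as err:
            # The worker's exception is pickled and re-raised here in the parent.
            print("propagated:", err)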