fix: dataset map multiprocessing without use thread mode

This commit is contained in:
jonyguo 2023-02-03 14:36:43 +08:00
parent 40abe83fc4
commit 4d41a380ae
4 changed files with 91 additions and 15 deletions

View File

@ -275,6 +275,11 @@ Status MapOp::WorkerCompute(const TensorRow &in_row, TensorRow *out_row,
*out_row = TensorRow(TensorRow::kFlagError); *out_row = TensorRow(TensorRow::kFlagError);
return Status::OK(); return Status::OK();
} else { } else {
// if thread had been interrupted, don't care the error
if (TaskManager::FindMe()->Interrupted()) {
MS_LOG(WARNING) << "Current thread had been interrupted by TaskManager, so ignore the error.";
return Status::OK();
}
return rc; return rc;
} }
} }

View File

@ -65,7 +65,13 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
if (output_type_ != DataType::DE_UNKNOWN) { if (output_type_ != DataType::DE_UNKNOWN) {
RETURN_IF_NOT_OK(CastOutput(ret_py_obj, output)); RETURN_IF_NOT_OK(CastOutput(ret_py_obj, output));
} else { } else {
if (py::isinstance<py::tuple>(ret_py_obj)) { // scenario 1: map multi-processing, subprocess stop first and will get none
// scenario 2: thread mode, user pyfunc return none
if (ret_py_obj.is_none()) {
MS_LOG(INFO) << "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
"True, it maybe due to pyfunc execution timeout.";
goto TimeoutError;
} else if (py::isinstance<py::tuple>(ret_py_obj)) {
// In case of a n-m mapping, the return value will be a tuple of numpy arrays // In case of a n-m mapping, the return value will be a tuple of numpy arrays
auto ret_py_tuple = ret_py_obj.cast<py::tuple>(); auto ret_py_tuple = ret_py_obj.cast<py::tuple>();
// Iterate over two containers simultaneously for memory copy // Iterate over two containers simultaneously for memory copy
@ -73,8 +79,8 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
py::object ret_py_ele = ret_py_tuple[i]; py::object ret_py_ele = ret_py_tuple[i];
// Object is none if pyfunc timeout // Object is none if pyfunc timeout
if (ret_py_ele.is_none()) { if (ret_py_ele.is_none()) {
MS_LOG(INFO) << "Expected that PyFunc should return numpy array, got None. If python_multiprocessing is " MS_LOG(INFO) << "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
"True, PyFunc may execute time out."; "True, it maybe due to pyfunc execution timeout.";
goto TimeoutError; goto TimeoutError;
} }
RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_ele, output)); RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_ele, output));
@ -94,8 +100,8 @@ ComputeReturn:
TimeoutError: TimeoutError:
ret = STATUS_ERROR(StatusCode::kMDTimeOut, ret = STATUS_ERROR(StatusCode::kMDTimeOut,
"Expected that PyFunc should return numpy array, got None. If \'python_multiprocessing\' is True, " "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
"PyFunc may execute time out."); "True, it maybe due to pyfunc execution timeout.");
goto ComputeReturn; goto ComputeReturn;
} }

View File

@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \ check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_enable_watchdog, get_seed, set_seed, get_debug_mode get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval
from ..core.datatypes import mstype_to_detype from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none from ..core.validator_helpers import replace_none
from ..core.py_util_helpers import ExceptionHandler from ..core.py_util_helpers import ExceptionHandler
@ -2729,17 +2729,38 @@ class _PythonCallable:
self.pool = pool self.pool = pool
# Python callable index # Python callable index
self.idx = idx self.idx = idx
self.check_interval = get_multiprocessing_timeout_interval()
def __call__(self, *args): def __call__(self, *args):
result = None result = None
if self.pool.is_running() and check_iterator_cleanup() is False: start_time = time.time()
try: count = 1
result = self.pool.execute(self.idx, *args) get_data_from_worker_process = False
except multiprocessing.TimeoutError: while get_data_from_worker_process is False:
pass cost_time = time.time() - start_time
if result is None: if cost_time > (self.check_interval * count):
# Invoke original Python callable in master process in case the pool is gone. logger.warning("It has been waiting for " + str(cost_time) + "s because the multi "
result = self.py_callable(*args) "workers of map operation cost long time to process next data. "
"Worker process list are: " + str(self.pool.get_pids()) + ", you can use "
"\"py-spy dump -p {PID} -l -s \""
"to dump the worker process stack. You can also set the timeout interval by "
"ds.config.set_multiprocessing_interval to adjust the output frequency of this "
"log.")
count += 1
if self.pool.is_running() and check_iterator_cleanup() is False:
try:
result = self.pool.execute(self.idx, *args)
except multiprocessing.TimeoutError:
continue
get_data_from_worker_process = True
else:
# worker process is stopped
logger.warning("The worker process of map operation is stopped. "
"So return None to main thread and break the main thread.")
return None
# got value from worker process
if not isinstance(result, tuple) and get_data_from_worker_process is True:
result = (result,)
return result return result
def to_json(self): def to_json(self):
@ -3123,7 +3144,6 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
atexit.register(self.terminate) atexit.register(self.terminate)
def terminate(self): def terminate(self):
logger.info("Terminating Python Multiprocessing for Op:" + str(self.op_id))
self.close_all_workers() self.close_all_workers()
self.abort_watchdog() self.abort_watchdog()

View File

@ -420,6 +420,50 @@ def test_map_just_exchange_columns():
assert item[2].shape == (300, 300, 3) assert item[2].shape == (300, 300, 3)
class FakeData:
def __init__(self):
self.input_ids = np.ones((128, 128), dtype=np.int32)
self.input_mask = np.ones((128, 128), dtype=np.int32)
def __getitem__(self, index):
return self.input_ids, self.input_mask
def __len__(self):
return 791
def test_map_multiprocessing_without_thread():
"""
Feature: Map op
Description: map with multiprocessing and don't degenerate into threading
Expectation: success
"""
dataset = ds.GeneratorDataset(FakeData(), ["input_ids", "input_mask"])
def long_running_op(col1, col2):
data1 = np.ones([50, 3, 655, 655], dtype=np.float64)
data2 = np.ones([50, 3, 600, 600], dtype=np.float64)
return data1, data2
dataset = dataset.map(operations=long_running_op, input_columns=["input_ids", "input_mask"],
python_multiprocessing=True, num_parallel_workers=2, max_rowsize=10)
assert dataset.get_dataset_size() == 791
assert dataset.output_shapes() == [[50, 3, 655, 655], [50, 3, 600, 600]]
assert dataset.output_types() == [np.float64, np.float64]
assert dataset.get_col_names() == ["input_ids", "input_mask"]
count = 1
for item in dataset.create_tuple_iterator(output_numpy=True, num_epochs=1):
print("count: {}, type: {}, shape: {}".format(count, item[0].dtype, item[0].shape))
assert item[0].dtype == np.float64
assert item[0].shape == (50, 3, 655, 655)
assert len(item) == 2
count += 1
if count > 5:
break
if __name__ == '__main__': if __name__ == '__main__':
test_map_c_transform_exception() test_map_c_transform_exception()
test_map_py_transform_exception() test_map_py_transform_exception()
@ -431,3 +475,4 @@ if __name__ == '__main__':
test_python_map_mp_seed_repeatability() test_python_map_mp_seed_repeatability()
test_map_with_deprecated_parameter() test_map_with_deprecated_parameter()
test_map_just_exchange_columns() test_map_just_exchange_columns()
test_map_multiprocessing_without_thread()