forked from mindspore-Ecosystem/mindspore
fix: dataset map multiprocessing without falling back to thread mode
This commit is contained in:
parent 40abe83fc4
commit 4d41a380ae
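For context, the user-facing scenario this change targets is a dataset.map() call with python_multiprocessing=True whose Python function is slow enough to hit the worker-pool timeout; before this fix the pyfunc could silently fall back to running in the main process (thread mode). A minimal sketch of that setup, mirroring the test added at the bottom of this diff (FakeData and the shapes are taken from that test):

import numpy as np
import mindspore.dataset as ds

class FakeData:
    def __init__(self):
        self.input_ids = np.ones((128, 128), dtype=np.int32)
        self.input_mask = np.ones((128, 128), dtype=np.int32)

    def __getitem__(self, index):
        return self.input_ids, self.input_mask

    def __len__(self):
        return 791

def long_running_op(col1, col2):
    # Deliberately heavy per-row transform so the worker processes stay busy.
    return (np.ones([50, 3, 655, 655], dtype=np.float64),
            np.ones([50, 3, 600, 600], dtype=np.float64))

dataset = ds.GeneratorDataset(FakeData(), ["input_ids", "input_mask"])
# With this fix, the transform keeps running in the worker processes instead of
# degenerating into thread mode when a worker takes a long time to answer.
dataset = dataset.map(operations=long_running_op, input_columns=["input_ids", "input_mask"],
                      python_multiprocessing=True, num_parallel_workers=2, max_rowsize=10)

for item in dataset.create_tuple_iterator(output_numpy=True, num_epochs=1):
    print(item[0].shape)  # (50, 3, 655, 655)
    break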
@@ -275,6 +275,11 @@ Status MapOp::WorkerCompute(const TensorRow &in_row, TensorRow *out_row,
       *out_row = TensorRow(TensorRow::kFlagError);
       return Status::OK();
     } else {
+      // if the thread has been interrupted, ignore the error
+      if (TaskManager::FindMe()->Interrupted()) {
+        MS_LOG(WARNING) << "Current thread has been interrupted by TaskManager, so the error is ignored.";
+        return Status::OK();
+      }
       return rc;
     }
   }
@@ -65,7 +65,13 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
     if (output_type_ != DataType::DE_UNKNOWN) {
       RETURN_IF_NOT_OK(CastOutput(ret_py_obj, output));
     } else {
-      if (py::isinstance<py::tuple>(ret_py_obj)) {
+      // scenario 1: map multiprocessing, the subprocess stopped first and None is returned
+      // scenario 2: thread mode, the user pyfunc returned None
+      if (ret_py_obj.is_none()) {
+        MS_LOG(INFO) << "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
+                        "True, it may be due to pyfunc execution timeout.";
+        goto TimeoutError;
+      } else if (py::isinstance<py::tuple>(ret_py_obj)) {
         // In case of a n-m mapping, the return value will be a tuple of numpy arrays
         auto ret_py_tuple = ret_py_obj.cast<py::tuple>();
         // Iterate over two containers simultaneously for memory copy
@@ -73,8 +79,8 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
           py::object ret_py_ele = ret_py_tuple[i];
           // Object is none if pyfunc timeout
           if (ret_py_ele.is_none()) {
-            MS_LOG(INFO) << "Expected that PyFunc should return numpy array, got None. If python_multiprocessing is "
-                            "True, PyFunc may execute time out.";
+            MS_LOG(INFO) << "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
+                            "True, it may be due to pyfunc execution timeout.";
             goto TimeoutError;
           }
           RETURN_IF_NOT_OK(ConvertNumpyToTensor(ret_py_ele, output));
@@ -94,8 +100,8 @@ ComputeReturn:

 TimeoutError:
   ret = STATUS_ERROR(StatusCode::kMDTimeOut,
-                     "Expected that PyFunc should return numpy array, got None. If \'python_multiprocessing\' is True, "
-                     "PyFunc may execute time out.");
+                     "Expect pyfunc to return numpy array(s), but got None. If python_multiprocessing is "
+                     "True, it may be due to pyfunc execution timeout.");
   goto ComputeReturn;
 }

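The PyFuncOp changes above turn a None return from the Python function into a kMDTimeOut status instead of letting it propagate as a malformed row. A minimal sketch of the thread-mode case (scenario 2 in the comment above); the assumption that the status surfaces as a RuntimeError on the Python iterator side is mine, not something stated in this diff:

import numpy as np
import mindspore.dataset as ds

def bad_pyfunc(col):
    # Returns None instead of a numpy array, which Compute() above now routes to the TimeoutError label.
    return None

data = ds.NumpySlicesDataset({"col": np.arange(4)}, shuffle=False)
data = data.map(operations=bad_pyfunc, input_columns=["col"], python_multiprocessing=False)

try:
    for _ in data.create_tuple_iterator(output_numpy=True, num_epochs=1):
        pass
except RuntimeError as err:  # assumed surfacing of the kMDTimeOut status in Python
    print(err)  # "Expect pyfunc to return numpy array(s), but got None. ..."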
@@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
     check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch
 from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
-    get_enable_watchdog, get_seed, set_seed, get_debug_mode
+    get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval
 from ..core.datatypes import mstype_to_detype
 from ..core.validator_helpers import replace_none
 from ..core.py_util_helpers import ExceptionHandler
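The newly imported get_multiprocessing_timeout_interval reads the warning interval used by the waiting loop below. A minimal sketch of adjusting that interval from user code; set_multiprocessing_timeout_interval is assumed to be the public counterpart of the getter imported here:

import mindspore.dataset as ds

# Assumed public setter paired with get_multiprocessing_timeout_interval.
ds.config.set_multiprocessing_timeout_interval(300)  # warn every 300 seconds while waiting on workers
print(ds.config.get_multiprocessing_timeout_interval())  # 300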
@@ -2729,17 +2729,38 @@ class _PythonCallable:
         self.pool = pool
         # Python callable index
         self.idx = idx
+        self.check_interval = get_multiprocessing_timeout_interval()

     def __call__(self, *args):
-        result = None
-        if self.pool.is_running() and check_iterator_cleanup() is False:
-            try:
-                result = self.pool.execute(self.idx, *args)
-            except multiprocessing.TimeoutError:
-                pass
-        if result is None:
-            # Invoke original Python callable in master process in case the pool is gone.
-            result = self.py_callable(*args)
+        start_time = time.time()
+        count = 1
+        get_data_from_worker_process = False
+        while get_data_from_worker_process is False:
+            cost_time = time.time() - start_time
+            if cost_time > (self.check_interval * count):
+                logger.warning("It has been waiting for " + str(cost_time) + "s because the worker "
+                               "processes of the map operation are taking a long time to process the next data. "
+                               "The worker process list is: " + str(self.pool.get_pids()) + ". You can use "
+                               "\"py-spy dump -p {PID} -l -s\" "
+                               "to dump the worker process stack. You can also set the timeout interval with "
+                               "ds.config.set_multiprocessing_timeout_interval to adjust the output frequency "
+                               "of this log.")
+                count += 1
+            if self.pool.is_running() and check_iterator_cleanup() is False:
+                try:
+                    result = self.pool.execute(self.idx, *args)
+                except multiprocessing.TimeoutError:
+                    continue
+                get_data_from_worker_process = True
+            else:
+                # the worker process has stopped
+                logger.warning("The worker process of the map operation has stopped, "
+                               "so return None to the main thread to stop it.")
+                return None
+        # got value from worker process
+        if not isinstance(result, tuple) and get_data_from_worker_process is True:
             result = (result,)
         return result

     def to_json(self):
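The rewritten __call__ above boils down to: keep polling the worker pool and, every check_interval seconds, warn with the worker PIDs instead of silently falling back to running the pyfunc in the main process. A stripped-down sketch of the same waiting pattern built on a plain multiprocessing.Pool (heavy_fn and poll_with_warnings are illustrative names, not MindSpore APIs):

import multiprocessing
import time


def heavy_fn(x):
    # Stand-in for a slow user pyfunc running in a worker process.
    time.sleep(2)
    return x * x


def poll_with_warnings(pool, func, arg, check_interval=5):
    # Wait on an async result, logging a warning every check_interval seconds.
    async_result = pool.apply_async(func, (arg,))
    start_time = time.time()
    count = 1
    while True:
        try:
            return async_result.get(timeout=1)  # short poll, like pool.execute() above
        except multiprocessing.TimeoutError:
            cost_time = time.time() - start_time
            if cost_time > check_interval * count:
                pids = [p.pid for p in multiprocessing.active_children()]
                print("Still waiting after %.1fs; worker pids: %s" % (cost_time, pids))
                count += 1


if __name__ == "__main__":
    with multiprocessing.Pool(2) as pool:
        print(poll_with_warnings(pool, heavy_fn, 3))  # prints 9 after about 2 seconds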
@@ -3123,7 +3144,6 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
         atexit.register(self.terminate)

     def terminate(self):
         logger.info("Terminating Python Multiprocessing for Op:" + str(self.op_id))
         self.close_all_workers()
         self.abort_watchdog()

@@ -420,6 +420,50 @@ def test_map_just_exchange_columns():
     assert item[2].shape == (300, 300, 3)


+class FakeData:
+    def __init__(self):
+        self.input_ids = np.ones((128, 128), dtype=np.int32)
+        self.input_mask = np.ones((128, 128), dtype=np.int32)
+
+    def __getitem__(self, index):
+        return self.input_ids, self.input_mask
+
+    def __len__(self):
+        return 791
+
+
+def test_map_multiprocessing_without_thread():
+    """
+    Feature: Map op
+    Description: map with multiprocessing and do not fall back to thread mode
+    Expectation: success
+    """
+
+    dataset = ds.GeneratorDataset(FakeData(), ["input_ids", "input_mask"])
+
+    def long_running_op(col1, col2):
+        data1 = np.ones([50, 3, 655, 655], dtype=np.float64)
+        data2 = np.ones([50, 3, 600, 600], dtype=np.float64)
+        return data1, data2
+
+    dataset = dataset.map(operations=long_running_op, input_columns=["input_ids", "input_mask"],
+                          python_multiprocessing=True, num_parallel_workers=2, max_rowsize=10)
+    assert dataset.get_dataset_size() == 791
+    assert dataset.output_shapes() == [[50, 3, 655, 655], [50, 3, 600, 600]]
+    assert dataset.output_types() == [np.float64, np.float64]
+    assert dataset.get_col_names() == ["input_ids", "input_mask"]
+
+    count = 1
+    for item in dataset.create_tuple_iterator(output_numpy=True, num_epochs=1):
+        print("count: {}, type: {}, shape: {}".format(count, item[0].dtype, item[0].shape))
+        assert item[0].dtype == np.float64
+        assert item[0].shape == (50, 3, 655, 655)
+        assert len(item) == 2
+        count += 1
+        if count > 5:
+            break
+
+
 if __name__ == '__main__':
     test_map_c_transform_exception()
     test_map_py_transform_exception()
@@ -431,3 +475,4 @@ if __name__ == '__main__':
     test_python_map_mp_seed_repeatability()
     test_map_with_deprecated_parameter()
     test_map_just_exchange_columns()
+    test_map_multiprocessing_without_thread()