Prevent the multiprocess from capturing KeyboardInterrupt

This commit is contained in:
ougongchang 2020-12-29 16:30:31 +08:00
parent 94a9ff7719
commit e5529230bf
3 changed files with 14 additions and 9 deletions

View File

@ -202,7 +202,7 @@ def check_value_type(arg_name, arg_value, valid_types):
if not is_valid:
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
f'bug got {type(arg_value).__name__}.')
f'but got {type(arg_value).__name__}.')
def read_proto(file_name, proto_format="MINDIR"):

View File

@ -15,6 +15,7 @@
"""Write events to disk in a base directory."""
import os
import time
import signal
from collections import deque
import mindspore.log as logger
@ -77,6 +78,10 @@ class WriterPool(ctx.Process):
os.environ['GOTO_NUM_THREADS'] = '2'
os.environ['OMP_NUM_THREADS'] = '2'
# Prevent the multiprocess from capturing KeyboardInterrupt,
# which causes the main process to fail to exit.
signal.signal(signal.SIGINT, signal.SIG_IGN)
with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
deq = deque()
while True:

View File

@ -118,7 +118,7 @@ class TestSummaryCollector:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
f'bug got {type(collect_freq).__name__}.'
f'but got {type(collect_freq).__name__}.'
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("action", [None, 123, '', '123'])
@ -128,7 +128,7 @@ class TestSummaryCollector:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
f"bug got {type(action).__name__}."
f"but got {type(action).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [123])
@ -139,7 +139,7 @@ class TestSummaryCollector:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(collect_specified_data).__name__}."
f"but got {type(collect_specified_data).__name__}."
assert expected_msg == str(exc.value)
@ -159,7 +159,7 @@ class TestSummaryCollector:
param_name = list(collect_specified_data)[0]
expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
f"bug got {type(param_name).__name__}."
f"but got {type(param_name).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [
@ -183,7 +183,7 @@ class TestSummaryCollector:
param_value = collect_specified_data[param_name]
expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'bug got {type(param_value).__name__}.'
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
@ -216,18 +216,18 @@ class TestSummaryCollector:
if not isinstance(custom_lineage_data, dict):
expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(custom_lineage_data).__name__}."
f"but got {type(custom_lineage_data).__name__}."
else:
param_name = list(custom_lineage_data)[0]
param_value = custom_lineage_data[param_name]
if not isinstance(param_name, str):
arg_name = f'custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
f'bug got {type(param_name).__name__}.'
f'but got {type(param_name).__name__}.'
else:
arg_name = f'the value of custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
f'bug got {type(param_value).__name__}.'
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)