diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 5aa26ab3024..1a29b7b932b 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -202,7 +202,7 @@ def check_value_type(arg_name, arg_value, valid_types): if not is_valid: raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, ' - f'bug got {type(arg_value).__name__}.') + f'but got {type(arg_value).__name__}.') def read_proto(file_name, proto_format="MINDIR"): diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index a583b170830..4e91308214e 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -15,6 +15,7 @@ """Write events to disk in a base directory.""" import os import time +import signal from collections import deque import mindspore.log as logger @@ -77,6 +78,10 @@ class WriterPool(ctx.Process): os.environ['GOTO_NUM_THREADS'] = '2' os.environ['OMP_NUM_THREADS'] = '2' + # Prevent the multiprocess from capturing KeyboardInterrupt, + # which causes the main process to fail to exit. + signal.signal(signal.SIGINT, signal.SIG_IGN) + with ctx.Pool(min(ctx.cpu_count(), 32)) as pool: deq = deque() while True: diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 3349cf8287b..2513a07ad79 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -118,7 +118,7 @@ class TestSummaryCollector: with pytest.raises(TypeError) as exc: SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq) expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \ - f'bug got {type(collect_freq).__name__}.' + f'but got {type(collect_freq).__name__}.' assert expected_msg == str(exc.value) @pytest.mark.parametrize("action", [None, 123, '', '123']) @@ -128,7 +128,7 @@ class TestSummaryCollector: with pytest.raises(TypeError) as exc: SummaryCollector(summary_dir=summary_dir, keep_default_action=action) expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \ - f"bug got {type(action).__name__}." + f"but got {type(action).__name__}." assert expected_msg == str(exc.value) @pytest.mark.parametrize("collect_specified_data", [123]) @@ -139,7 +139,7 @@ class TestSummaryCollector: SummaryCollector(summary_dir, collect_specified_data=collect_specified_data) expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \ - f"bug got {type(collect_specified_data).__name__}." + f"but got {type(collect_specified_data).__name__}." assert expected_msg == str(exc.value) @@ -159,7 +159,7 @@ class TestSummaryCollector: param_name = list(collect_specified_data)[0] expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \ - f"bug got {type(param_name).__name__}." + f"but got {type(param_name).__name__}." assert expected_msg == str(exc.value) @pytest.mark.parametrize("collect_specified_data", [ @@ -183,7 +183,7 @@ class TestSummaryCollector: param_value = collect_specified_data[param_name] expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']" expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \ - f'bug got {type(param_value).__name__}.' + f'but got {type(param_value).__name__}.' assert expected_msg == str(exc.value) @@ -216,18 +216,18 @@ class TestSummaryCollector: if not isinstance(custom_lineage_data, dict): expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \ - f"bug got {type(custom_lineage_data).__name__}." + f"but got {type(custom_lineage_data).__name__}." else: param_name = list(custom_lineage_data)[0] param_value = custom_lineage_data[param_name] if not isinstance(param_name, str): arg_name = f'custom_lineage_data -> {param_name}' expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \ - f'bug got {type(param_name).__name__}.' + f'but got {type(param_name).__name__}.' else: arg_name = f'the value of custom_lineage_data -> {param_name}' expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \ - f'bug got {type(param_value).__name__}.' + f'but got {type(param_value).__name__}.' assert expected_msg == str(exc.value)