From 0bb80995bc43639401acad1758dc43303586eab8 Mon Sep 17 00:00:00 2001 From: jiangshuqiang <962978787@qq.com> Date: Mon, 25 Jan 2021 11:05:50 +0800 Subject: [PATCH] fix param check for unexpected_format --- mindspore/train/_utils.py | 2 +- mindspore/train/summary/_writer_pool.py | 2 +- mindspore/train/summary/summary_record.py | 14 ++++++++++--- .../train/summary/test_summary_collector.py | 20 +++++++++++++++++++ 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 398e66636a5..2b55911f9fe 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -88,7 +88,7 @@ def _make_directory(path: str): else: logger.debug("The directory(%s) doesn't exist, will create it", path) try: - os.makedirs(path, exist_ok=True) + os.makedirs(path, exist_ok=True, mode=0o700) real_path = path except PermissionError as e: logger.error("No write permission on the directory(%r), error = %r", path, e) diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index f09a47b2f7f..3a40eeefa26 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -138,7 +138,7 @@ class WriterPool(ctx.Process): for writer in self._writers[:]: try: writer.write(plugin, data) - except RuntimeError as exc: + except (RuntimeError, OSError) as exc: logger.error(str(exc)) self._writers.remove(writer) writer.close() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 8a914e10992..3c29b414315 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -36,7 +36,7 @@ _summary_lock = threading.Lock() # cache the summary data _summary_tensor_cache = {} _DEFAULT_EXPORT_OPTIONS = { - 'tensor_format': 'npy', + 'tensor_format': {'npy'}, } @@ -68,14 +68,22 @@ def process_export_options(export_options): check_value_type('export_options', export_options, [dict, type(None)]) - for param_name in export_options: - check_value_type(param_name, param_name, [str]) + for export_option, export_format in export_options.items(): + check_value_type('export_option', export_option, [str]) + check_value_type('export_format', export_format, [str]) unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS) if unexpected_params: raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, ' f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}') + for export_option, export_format in export_options.items(): + unexpected_format = {export_format} - _DEFAULT_EXPORT_OPTIONS.get(export_option) + if unexpected_format: + raise ValueError( + f'For `export_options`, the export_format {unexpected_format} are unsupported for {export_option}, ' + f'expect the follow values: {list(_DEFAULT_EXPORT_OPTIONS.get(export_option))}') + for item in set(export_options): check_value_type(item, export_options.get(item), [str, type(None)]) diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 02644ee4212..96011f1f731 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -28,11 +28,13 @@ from mindspore.train.callback import SummaryCollector from mindspore.train.callback import _InternalCallbackParam from mindspore.train.summary.enums import ModeEnum, PluginEnum from mindspore.train.summary import SummaryRecord +from mindspore.train.summary.summary_record import _DEFAULT_EXPORT_OPTIONS from mindspore.nn import Cell from mindspore.nn.optim.optimizer import Optimizer from mindspore.ops.operations import Add + _VALUE_CACHE = list() @@ -143,6 +145,24 @@ class TestSummaryCollector: assert expected_msg == str(exc.value) + @pytest.mark.parametrize("export_options", [ + { + "tensor_format": "npz" + } + ]) + def test_params_with_tensor_format_type_error(self, export_options): + """Test type error scenario for collect specified data param.""" + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + with pytest.raises(ValueError) as exc: + SummaryCollector(summary_dir, export_options=export_options) + + unexpected_format = {export_options.get("tensor_format")} + expected_msg = f'For `export_options`, the export_format {unexpected_format} are ' \ + f'unsupported for tensor_format, expect the follow values: ' \ + f'{list(_DEFAULT_EXPORT_OPTIONS.get("tensor_format"))}' + + assert expected_msg == str(exc.value) + @pytest.mark.parametrize("export_options", [123]) def test_params_with_export_options_type_error(self, export_options): """Test type error scenario for collect specified data param."""