From f63a13b2e90ffe63ad8cc9181efeffb6472ad7ee Mon Sep 17 00:00:00 2001 From: changzherui Date: Fri, 23 Apr 2021 08:57:43 +0800 Subject: [PATCH] modify check file name bug --- mindspore/_checkparam.py | 5 +++++ mindspore/train/callback/_checkpoint.py | 20 +++-------------- tests/ut/python/utils/test_callback.py | 29 ++----------------------- 3 files changed, 10 insertions(+), 44 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index e8dfe680f45..142391b60be 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -443,12 +443,17 @@ class Validator: @staticmethod def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): """Check whether file name is legitimate.""" + if not isinstance(target, str): + raise ValueError("Args file_name {} must be string, please check it".format(target)) + if target.endswith("\\") or target.endswith("/"): + raise ValueError("File name cannot be a directory path.") if reg is None: reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" if re.match(reg, target, flag) is None: prim_name = f'in `{prim_name}`' if prim_name else "" raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( target, prim_name, reg, flag)) + return True @staticmethod diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index a8c36c5b6ad..da8528b6593 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -26,6 +26,7 @@ from mindspore._checkparam import Validator from mindspore.train._utils import _make_directory from mindspore.train.serialization import save_checkpoint, _save_graph from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank +from mindspore.parallel._cell_wrapper import destroy_allgather_cell from ._callback import Callback, set_cur_net from ...common.tensor import Tensor @@ -33,17 +34,6 @@ _cur_dir = os.getcwd() _save_dir = _cur_dir -def _check_file_name_prefix(file_name_prefix): - """ - Check file name valid or not. - - File name can't include '/'. This file name naming convention only apply to Linux. - """ - if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0: - return False - return True - - def _chg_ckpt_file_name_if_same_exist(directory, prefix): """Check if there is a file with the same name.""" files = os.listdir(directory) @@ -245,11 +235,8 @@ class ModelCheckpoint(Callback): self._last_time_for_keep = time.time() self._last_triggered_step = 0 - if _check_file_name_prefix(prefix): - self._prefix = prefix - else: - raise ValueError("Prefix {} for checkpoint file name invalid, " - "please check and correct it and then continue.".format(prefix)) + Validator.check_file_name_by_regular(prefix) + self._prefix = prefix if directory is not None: self._directory = _make_directory(directory) @@ -310,7 +297,6 @@ class ModelCheckpoint(Callback): if thread.getName() == "asyn_save_ckpt": thread.join() - from mindspore.parallel._cell_wrapper import destroy_allgather_cell destroy_allgather_cell() def _check_save_ckpt(self, cb_params, force_to_save): diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index 0fc79a5eab5..4a90d962e42 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Momentum from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op -from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist +from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist + class Net(nn.Cell): """Net definition.""" @@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode(): loss_cb.end(run_context) -def test_check_file_name_not_str(): - """Test check file name not str.""" - ret = _check_file_name_prefix(1) - assert not ret - - -def test_check_file_name_back_err(): - """Test check file name back err.""" - ret = _check_file_name_prefix('abc.') - assert ret - - -def test_check_file_name_one_alpha(): - """Test check file name one alpha.""" - ret = _check_file_name_prefix('a') - assert ret - ret = _check_file_name_prefix('_') - assert ret - - -def test_check_file_name_err(): - """Test check file name err.""" - ret = _check_file_name_prefix('_123') - assert ret - - def test_chg_ckpt_file_name_if_same_exist(): """Test chg ckpt file name if same exist.""" _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")