modify check file name bug

This commit is contained in:
changzherui 2021-04-23 08:57:43 +08:00
parent 668d079942
commit f63a13b2e9
3 changed files with 10 additions and 44 deletions

View File

@ -443,12 +443,17 @@ class Validator:
@staticmethod
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
"""Check whether file name is legitimate."""
if not isinstance(target, str):
raise ValueError("Args file_name {} must be string, please check it".format(target))
if target.endswith("\\") or target.endswith("/"):
raise ValueError("File name cannot be a directory path.")
if reg is None:
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
target, prim_name, reg, flag))
return True
@staticmethod

View File

@ -26,6 +26,7 @@ from mindspore._checkparam import Validator
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor
@ -33,17 +34,6 @@ _cur_dir = os.getcwd()
_save_dir = _cur_dir
def _check_file_name_prefix(file_name_prefix):
"""
Check file name valid or not.
File name can't include '/'. This file name naming convention only apply to Linux.
"""
if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
return False
return True
def _chg_ckpt_file_name_if_same_exist(directory, prefix):
"""Check if there is a file with the same name."""
files = os.listdir(directory)
@ -245,11 +235,8 @@ class ModelCheckpoint(Callback):
self._last_time_for_keep = time.time()
self._last_triggered_step = 0
if _check_file_name_prefix(prefix):
self._prefix = prefix
else:
raise ValueError("Prefix {} for checkpoint file name invalid, "
"please check and correct it and then continue.".format(prefix))
Validator.check_file_name_by_regular(prefix)
self._prefix = prefix
if directory is not None:
self._directory = _make_directory(directory)
@ -310,7 +297,6 @@ class ModelCheckpoint(Callback):
if thread.getName() == "asyn_save_ckpt":
thread.join()
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save):

View File

@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist
class Net(nn.Cell):
"""Net definition."""
@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode():
loss_cb.end(run_context)
def test_check_file_name_not_str():
"""Test check file name not str."""
ret = _check_file_name_prefix(1)
assert not ret
def test_check_file_name_back_err():
"""Test check file name back err."""
ret = _check_file_name_prefix('abc.')
assert ret
def test_check_file_name_one_alpha():
"""Test check file name one alpha."""
ret = _check_file_name_prefix('a')
assert ret
ret = _check_file_name_prefix('_')
assert ret
def test_check_file_name_err():
"""Test check file name err."""
ret = _check_file_name_prefix('_123')
assert ret
def test_chg_ckpt_file_name_if_same_exist():
"""Test chg ckpt file name if same exist."""
_chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")