forked from mindspore-Ecosystem/mindspore
modify check file name bug
This commit is contained in:
parent
668d079942
commit
f63a13b2e9
|
@ -443,12 +443,17 @@ class Validator:
|
|||
@staticmethod
|
||||
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
|
||||
"""Check whether file name is legitimate."""
|
||||
if not isinstance(target, str):
|
||||
raise ValueError("Args file_name {} must be string, please check it".format(target))
|
||||
if target.endswith("\\") or target.endswith("/"):
|
||||
raise ValueError("File name cannot be a directory path.")
|
||||
if reg is None:
|
||||
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
|
||||
if re.match(reg, target, flag) is None:
|
||||
prim_name = f'in `{prim_name}`' if prim_name else ""
|
||||
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
|
||||
target, prim_name, reg, flag))
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore._checkparam import Validator
|
|||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
from ._callback import Callback, set_cur_net
|
||||
from ...common.tensor import Tensor
|
||||
|
||||
|
@ -33,17 +34,6 @@ _cur_dir = os.getcwd()
|
|||
_save_dir = _cur_dir
|
||||
|
||||
|
||||
def _check_file_name_prefix(file_name_prefix):
|
||||
"""
|
||||
Check file name valid or not.
|
||||
|
||||
File name can't include '/'. This file name naming convention only apply to Linux.
|
||||
"""
|
||||
if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _chg_ckpt_file_name_if_same_exist(directory, prefix):
|
||||
"""Check if there is a file with the same name."""
|
||||
files = os.listdir(directory)
|
||||
|
@ -245,11 +235,8 @@ class ModelCheckpoint(Callback):
|
|||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = 0
|
||||
|
||||
if _check_file_name_prefix(prefix):
|
||||
self._prefix = prefix
|
||||
else:
|
||||
raise ValueError("Prefix {} for checkpoint file name invalid, "
|
||||
"please check and correct it and then continue.".format(prefix))
|
||||
Validator.check_file_name_by_regular(prefix)
|
||||
self._prefix = prefix
|
||||
|
||||
if directory is not None:
|
||||
self._directory = _make_directory(directory)
|
||||
|
@ -310,7 +297,6 @@ class ModelCheckpoint(Callback):
|
|||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
def _check_save_ckpt(self, cb_params, force_to_save):
|
||||
|
|
|
@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
|
|||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
|
||||
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
|
||||
from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition."""
|
||||
|
@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode():
|
|||
loss_cb.end(run_context)
|
||||
|
||||
|
||||
def test_check_file_name_not_str():
|
||||
"""Test check file name not str."""
|
||||
ret = _check_file_name_prefix(1)
|
||||
assert not ret
|
||||
|
||||
|
||||
def test_check_file_name_back_err():
|
||||
"""Test check file name back err."""
|
||||
ret = _check_file_name_prefix('abc.')
|
||||
assert ret
|
||||
|
||||
|
||||
def test_check_file_name_one_alpha():
|
||||
"""Test check file name one alpha."""
|
||||
ret = _check_file_name_prefix('a')
|
||||
assert ret
|
||||
ret = _check_file_name_prefix('_')
|
||||
assert ret
|
||||
|
||||
|
||||
def test_check_file_name_err():
|
||||
"""Test check file name err."""
|
||||
ret = _check_file_name_prefix('_123')
|
||||
assert ret
|
||||
|
||||
|
||||
def test_chg_ckpt_file_name_if_same_exist():
|
||||
"""Test chg ckpt file name if same exist."""
|
||||
_chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")
|
||||
|
|
Loading…
Reference in New Issue