add ckpt check and mod cn api

This commit is contained in:
changzherui 2022-02-22 23:47:50 +08:00
parent 6390b71296
commit 7821b2355f
3 changed files with 7 additions and 2 deletions

View File

@ -18,3 +18,4 @@ mindspore.save_checkpoint
**异常:**
- **TypeError** 如果参数 `save_obj` 类型不为nn.Cell或者list且如果参数 `integrated_save``async_save` 非bool类型。
- **TypeError** 如果参数 `ckpt_file_name` 不是str类型。

View File

@ -10,11 +10,12 @@
**参数:**
- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值1。
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch如何设置了该参数LossMonitor将监控该数值之后epoch的loss值。默认值0。
**异常:**
- **ValueError** - 当 `per_print_times` 不是整数或小于零。
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
.. py:method:: step_end(run_context)

View File

@ -244,6 +244,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
if not isinstance(ckpt_file_name, str):
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
ckpt_file_name = os.path.realpath(ckpt_file_name)
if os.path.isdir(ckpt_file_name):
raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, "
"it should be a file name.".format(ckpt_file_name))
if not ckpt_file_name.endswith('.ckpt'):
ckpt_file_name += ".ckpt"
integrated_save = Validator.check_bool(integrated_save)
@ -298,7 +302,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
data = param["data"].asnumpy().reshape(-1)
data_list[key].append(data)
ckpt_file_name = os.path.realpath(ckpt_file_name)
if async_save:
data_copy = copy.deepcopy(data_list)
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")