forked from mindspore-Ecosystem/mindspore
add ckpt check and mod cn api
This commit is contained in:
parent
6390b71296
commit
7821b2355f
|
@ -18,3 +18,4 @@ mindspore.save_checkpoint
|
|||
**异常:**
|
||||
|
||||
- **TypeError** – 如果参数 `save_obj` 类型不为nn.Cell或者list,或者参数 `integrated_save` 或 `async_save` 非bool类型。
|
||||
- **TypeError** – 如果参数 `ckpt_file_name` 不是str类型。
|
||||
|
|
|
@ -10,11 +10,12 @@
|
|||
**参数:**
|
||||
|
||||
- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值:1。
|
||||
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,LossMonitor将监控该数值之后epoch的loss值。默认值:0。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 当 `per_print_times` 不是整数或小于零。
|
||||
|
||||
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
|
||||
|
||||
.. py:method:: step_end(run_context)
|
||||
|
||||
|
|
|
@ -244,6 +244,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|||
if not isinstance(ckpt_file_name, str):
|
||||
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
|
||||
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
|
||||
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
||||
if os.path.isdir(ckpt_file_name):
|
||||
raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, "
|
||||
"it should be a file name.".format(ckpt_file_name))
|
||||
if not ckpt_file_name.endswith('.ckpt'):
|
||||
ckpt_file_name += ".ckpt"
|
||||
integrated_save = Validator.check_bool(integrated_save)
|
||||
|
@ -298,7 +302,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|||
data = param["data"].asnumpy().reshape(-1)
|
||||
data_list[key].append(data)
|
||||
|
||||
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
||||
if async_save:
|
||||
data_copy = copy.deepcopy(data_list)
|
||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
|
||||
|
|
Loading…
Reference in New Issue