From 7821b2355f3b6d6c2cf9582f6d880c24b3eced64 Mon Sep 17 00:00:00 2001 From: changzherui Date: Tue, 22 Feb 2022 23:47:50 +0800 Subject: [PATCH] add ckpt check and mod cn api --- docs/api/api_python/mindspore/mindspore.save_checkpoint.rst | 1 + .../train/mindspore.train.callback.LossMonitor.rst | 3 ++- mindspore/python/mindspore/train/serialization.py | 5 ++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/api/api_python/mindspore/mindspore.save_checkpoint.rst b/docs/api/api_python/mindspore/mindspore.save_checkpoint.rst index 1845a34a18e..72fd74204fc 100644 --- a/docs/api/api_python/mindspore/mindspore.save_checkpoint.rst +++ b/docs/api/api_python/mindspore/mindspore.save_checkpoint.rst @@ -18,3 +18,4 @@ mindspore.save_checkpoint **异常:** - **TypeError** – 如果参数 `save_obj` 类型不为nn.Cell或者list,且如果参数 `integrated_save` 及 `async_save` 非bool类型。 + - **TypeError** – 如果参数 `ckpt_file_name` 不是str类型。 diff --git a/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst b/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst index c9a0bd7790a..001ced26eba 100644 --- a/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst +++ b/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst @@ -10,11 +10,12 @@ **参数:** - **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值:1。 + - **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,LossMonitor将监控该数值之后epoch的loss值。默认值:0。 **异常:** - **ValueError** - 当 `per_print_times` 不是整数或小于零。 - + - **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。 .. 
py:method:: step_end(run_context) diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index bdbc621c459..19415cb9bb3 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -244,6 +244,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, if not isinstance(ckpt_file_name, str): raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be " "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name))) + ckpt_file_name = os.path.realpath(ckpt_file_name) + if os.path.isdir(ckpt_file_name): + raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, " + "it should be a file name.".format(ckpt_file_name)) if not ckpt_file_name.endswith('.ckpt'): ckpt_file_name += ".ckpt" integrated_save = Validator.check_bool(integrated_save) @@ -298,7 +302,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, data = param["data"].asnumpy().reshape(-1) data_list[key].append(data) - ckpt_file_name = os.path.realpath(ckpt_file_name) if async_save: data_copy = copy.deepcopy(data_list) thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")