!43932 add docs for callback

Merge pull request !43932 from changzherui/code_docs_1014
This commit is contained in:
i-robot 2022-10-17 07:57:46 +00:00 committed by Gitee
commit 51ae189eec
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 83 additions and 26 deletions

View File

@ -15,6 +15,7 @@ mindspore.train
.. mscnautosummary::
:toctree: train
mindspore.train.BackupAndRestore
mindspore.train.Callback
mindspore.train.CheckpointConfig
mindspore.train.EarlyStopping

View File

@ -0,0 +1,40 @@
mindspore.train.BackupAndRestore
================================
.. py:class:: mindspore.train.BackupAndRestore(backup_dir, save_freq="epoch", delete_checkpoint=True)
在训练过程中备份和恢复训练参数的回调函数。
.. note::
只能在训练过程使用这个方法。
参数:
- **backup_dir** (str) - 保存和恢复checkpoint文件的路径。
- **save_freq** (Union['epoch', int]) - 当设置为'epoch'时在每个epoch进行备份当设置为整数时将在每隔 `save_freq` 个epoch进行备份。默认值'epoch'。
- **delete_checkpoint** (bool) - 如果 `delete_checkpoint=True` 将在训练结束的时候删除备份文件否则保留备份文件。默认值True。
异常:
- **ValueError** - 如果 `backup_dir` 参数不是str类型。
- **ValueError** - 如果 `save_freq` 参数不是'epoch'或str类型。
- **ValueError** - 如果 `delete_checkpoint` 参数不是bool类型。
.. py:method:: on_train_begin(run_context)
在训练开始时加载备份的checkpoint文件。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_epoch_end(run_context)
在每个epoch结束时判断是否需要备份checkpoint文件。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_end(run_context)
在训练结束时判断是否删除备份的checkpoint文件。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`

View File

@ -18,123 +18,123 @@ mindspore.train.Callback
在网络执行之前被调用一次。与 `on_train_begin``on_eval_begin` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: end(run_context)
网络执行后被调用一次。与 `on_train_end``on_eval_end` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: epoch_begin(run_context)
在每个epoch开始之前被调用。与 `on_train_epoch_begin``on_eval_epoch_begin` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: epoch_end(run_context)
在每个epoch结束后被调用。与 `on_train_epoch_end``on_eval_epoch_end` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_begin(run_context)
在网络执行推理之前调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_end(run_context)
网络执行推理之后调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_epoch_begin(run_context)
在推理的epoch开始之前被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_epoch_end(run_context)
在推理的epoch结束后被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_step_begin(run_context)
在推理的每个step开始之前被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_eval_step_end(run_context)
在推理的每个step完成后被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_begin(run_context)
在网络执行训练之前调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_end(run_context)
网络训练执行结束时调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_epoch_begin(run_context)
在训练的每个epoch开始之前被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_epoch_end(run_context)
在训练的每个epoch结束后被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_step_begin(run_context)
在训练的每个step开始之前被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_step_end(run_context)
在训练的每个step完成后被调用。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: step_begin(run_context)
在每个step开始之前被调用。与 `on_train_step_begin``on_eval_step_begin` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: step_end(run_context)
在每个step完成后被调用。与 `on_train_step_end``on_eval_step_end` 方法具有兼容性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`

View File

@ -13,4 +13,4 @@ mindspore.train.LearningRateScheduler
在step结束时更改学习率。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`

View File

@ -26,7 +26,7 @@ mindspore.train.ModelCheckpoint
在训练结束后会保存最后一个step的checkpoint。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: latest_ckpt_file_name
:property:
@ -38,4 +38,4 @@ mindspore.train.ModelCheckpoint
在step结束时保存checkpoint。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`

View File

@ -19,6 +19,7 @@ Callback
:nosignatures:
:template: classtemplate.rst
mindspore.train.BackupAndRestore
mindspore.train.Callback
mindspore.train.CheckpointConfig
mindspore.train.EarlyStopping

View File

@ -32,6 +32,14 @@ class JitConfig:
task_sink (bool): Determines whether to pass the data through dataset channel. Default: True.
**kwargs (dict): A dictionary of keyword arguments that the class needs.
Examples:
>>> from mindspore.common.jit_config import JitConfig
>>>
>>> jitconfig = JitConfig(jit_level="O1")
>>> net = LeNet5()
>>>
>>> net.set_jit_config(jitconfig)
"""
def __init__(self, jit_level="O1", task_sink=True, **kwargs):
if jit_level not in ["O0", "O1", "O2"]:

View File

@ -73,7 +73,7 @@ class BackupAndRestore(Callback):
Load the backup checkpoint file at the beginning of epoch.
Args:
run_context (RunContext): Context of the process running. For more details,
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
if os.path.exists(self.backup_file):
@ -84,10 +84,10 @@ class BackupAndRestore(Callback):
def on_train_epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
Backup checkpoint file at the end of train epoch.
Args:
run_context (RunContext): Context of the process running. For more details,
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
@ -98,6 +98,13 @@ class BackupAndRestore(Callback):
save_checkpoint(train_net, self.backup_file)
def on_train_end(self, run_context):
"""
Deleted checkpoint file at the end of train.
Args:
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
run_context.original_args()
cb_params = run_context.original_args()
cur_epoch_num = cb_params.cur_epoch_num

View File

@ -57,7 +57,7 @@ class TimeMonitor(Callback):
Record time at the beginning of epoch.
Args:
run_context (RunContext): Context of the process running. For more details,
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
self.epoch_time = time.time()
@ -67,7 +67,7 @@ class TimeMonitor(Callback):
Print process cost time at the end of epoch.
Args:
run_context (RunContext): Context of the process running. For more details,
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000