forked from mindspore-Ecosystem/mindspore
!43932 add docs for callback
Merge pull request !43932 from changzherui/code_docs_1014
This commit is contained in:
commit
51ae189eec
|
@ -15,6 +15,7 @@ mindspore.train
|
|||
.. mscnautosummary::
|
||||
:toctree: train
|
||||
|
||||
mindspore.train.BackupAndRestore
|
||||
mindspore.train.Callback
|
||||
mindspore.train.CheckpointConfig
|
||||
mindspore.train.EarlyStopping
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
mindspore.train.BackupAndRestore
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.train.BackupAndRestore(backup_dir, save_freq="epoch", delete_checkpoint=True)
|
||||
|
||||
在训练过程中备份和恢复训练参数的回调函数。
|
||||
|
||||
.. note::
|
||||
只能在训练过程使用这个方法。
|
||||
|
||||
参数:
|
||||
- **backup_dir** (str) - 保存和恢复checkpoint文件的路径。
|
||||
- **save_freq** (Union['epoch', int]) - 当设置为'epoch'时,在每个epoch进行备份,当设置为整数时,将在每隔 `save_freq` 个epoch进行备份。默认值:'epoch'。
|
||||
- **delete_checkpoint** (bool) - 如果 `delete_checkpoint=True` ,将在训练结束的时候删除备份文件,否则保留备份文件。默认值:True。
|
||||
|
||||
异常:
|
||||
- **ValueError** - 如果 `backup_dir` 参数不是str类型。
|
||||
- **ValueError** - 如果 `save_freq` 参数不是'epoch'或str类型。
|
||||
- **ValueError** - 如果 `delete_checkpoint` 参数不是bool类型。
|
||||
|
||||
.. py:method:: on_train_begin(run_context)
|
||||
|
||||
在训练开始时,加载备份的checkpoint文件。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
在每个epoch结束时,判断是否需要备份checkpoint文件。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_end(run_context)
|
||||
|
||||
在训练结束时,判断是否删除备份的checkpoint文件。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
|
@ -18,123 +18,123 @@ mindspore.train.Callback
|
|||
在网络执行之前被调用一次。与 `on_train_begin` 和 `on_eval_begin` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: end(run_context)
|
||||
|
||||
网络执行后被调用一次。与 `on_train_end` 和 `on_eval_end` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: epoch_begin(run_context)
|
||||
|
||||
在每个epoch开始之前被调用。与 `on_train_epoch_begin` 和 `on_eval_epoch_begin` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: epoch_end(run_context)
|
||||
|
||||
在每个epoch结束后被调用。与 `on_train_epoch_end` 和 `on_eval_epoch_end` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_begin(run_context)
|
||||
|
||||
在网络执行推理之前调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_end(run_context)
|
||||
|
||||
网络执行推理之后调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_epoch_begin(run_context)
|
||||
|
||||
在推理的epoch开始之前被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_epoch_end(run_context)
|
||||
|
||||
在推理的epoch结束后被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_step_begin(run_context)
|
||||
|
||||
在推理的每个step开始之前被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_step_end(run_context)
|
||||
|
||||
在推理的每个step完成后被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_begin(run_context)
|
||||
|
||||
在网络执行训练之前调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_end(run_context)
|
||||
|
||||
网络训练执行结束时调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_begin(run_context)
|
||||
|
||||
在训练的每个epoch开始之前被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
在训练的每个epoch结束后被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_step_begin(run_context)
|
||||
|
||||
在训练的每个step开始之前被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_step_end(run_context)
|
||||
|
||||
在训练的每个step完成后被调用。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: step_begin(run_context)
|
||||
|
||||
在每个step开始之前被调用。与 `on_train_step_begin` 和 `on_eval_step_begin` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: step_end(run_context)
|
||||
|
||||
在每个step完成后被调用。与 `on_train_step_end` 和 `on_eval_step_end` 方法具有兼容性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -13,4 +13,4 @@ mindspore.train.LearningRateScheduler
|
|||
在step结束时更改学习率。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -26,7 +26,7 @@ mindspore.train.ModelCheckpoint
|
|||
在训练结束后,会保存最后一个step的checkpoint。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: latest_ckpt_file_name
|
||||
:property:
|
||||
|
@ -38,4 +38,4 @@ mindspore.train.ModelCheckpoint
|
|||
在step结束时保存checkpoint。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -19,6 +19,7 @@ Callback
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.train.BackupAndRestore
|
||||
mindspore.train.Callback
|
||||
mindspore.train.CheckpointConfig
|
||||
mindspore.train.EarlyStopping
|
||||
|
|
|
@ -32,6 +32,14 @@ class JitConfig:
|
|||
|
||||
task_sink (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
**kwargs (dict): A dictionary of keyword arguments that the class needs.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.common.jit_config import JitConfig
|
||||
>>>
|
||||
>>> jitconfig = JitConfig(jit_level="O1")
|
||||
>>> net = LeNet5()
|
||||
>>>
|
||||
>>> net.set_jit_config(jitconfig)
|
||||
"""
|
||||
def __init__(self, jit_level="O1", task_sink=True, **kwargs):
|
||||
if jit_level not in ["O0", "O1", "O2"]:
|
||||
|
|
|
@ -73,7 +73,7 @@ class BackupAndRestore(Callback):
|
|||
Load the backup checkpoint file at the beginning of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if os.path.exists(self.backup_file):
|
||||
|
@ -84,10 +84,10 @@ class BackupAndRestore(Callback):
|
|||
|
||||
def on_train_epoch_end(self, run_context):
|
||||
"""
|
||||
Print process cost time at the end of epoch.
|
||||
Backup checkpoint file at the end of train epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
|
@ -98,6 +98,13 @@ class BackupAndRestore(Callback):
|
|||
save_checkpoint(train_net, self.backup_file)
|
||||
|
||||
def on_train_end(self, run_context):
|
||||
"""
|
||||
Deleted checkpoint file at the end of train.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
run_context.original_args()
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch_num = cb_params.cur_epoch_num
|
||||
|
|
|
@ -57,7 +57,7 @@ class TimeMonitor(Callback):
|
|||
Record time at the beginning of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
self.epoch_time = time.time()
|
||||
|
@ -67,7 +67,7 @@ class TimeMonitor(Callback):
|
|||
Print process cost time at the end of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
epoch_seconds = (time.time() - self.epoch_time) * 1000
|
||||
|
|
Loading…
Reference in New Issue