!30961 add comments for history and lambda callbacks

Merge pull request !30961 from liutongtong9/code_docs_cbcomment
This commit is contained in:
i-robot 2022-03-09 09:48:41 +00:00 committed by Gitee
commit 2484bf4812
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 53 additions and 10 deletions

View File

@ -0,0 +1,24 @@
.. py:class:: mindspore.train.callback.History()
将网络输出的相关信息记录到 `History` 对象中。
如果用户不自定义训练网络或评估网络,记录的内容将为损失值;如果用户自定义了训练网络/评估网络,如果定义的网络返回 `Tensor``numpy.ndarray`,则记录此返回值均值,如果返回 `tuple``list`,则记录第一个元素。
.. note::
通常使用在 `mindspore.Model.train` 中。
.. py:method:: begin(run_context)
训练开始时初始化History对象的epoch属性。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: epoch_end(run_context)
epoch结束时记录网络输出的相关信息。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。

View File

@ -0,0 +1,16 @@
.. py:class:: mindspore.train.callback.LambdaCallback()
用于自定义简单的callback。
使用匿名函数构建callback定义的匿名函数将在 `mindspore.Model.{train | eval}` 的对应阶段被调用。
请注意callback的每个阶段都需要一个位置参数`run_context`
**参数:**
- **epoch_begin** (Function) - 每个epoch开始时被调用。
- **epoch_end** (Function) - 每个epoch结束时被调用。
- **step_begin** (Function) - 每个step开始时被调用。
- **step_end** (Function) - 每个step结束时被调用。
- **begin** (Function) - 模型训练、评估开始时被调用。
- **end** (Function) - 模型训练、评估结束时被调用。

View File

@ -20,10 +20,12 @@ from ._callback import Callback
class History(Callback):
"""
Records the first element of network outputs into a `History` object.
Records the network outputs information into a `History` object.
The first element of network outputs is the loss value if not
custimizing the train network or eval network.
The network outputs information will be the loss value if not custimizing the train network or eval network;
if the custimized network returns a `Tensor` or `numpy.ndarray`, the mean value of network output
will be recorded, if the custimized network returns a `tuple` or `list`, the first element of network
outputs will be recorded.
Note:
Normally used in `mindspore.Model.train`.

View File

@ -26,12 +26,12 @@ class LambdaCallback(Callback):
Note that each stage of callbacks expects one positional arguments: `run_context`.
Args:
epoch_begin: called at the beginning of every epoch.
epoch_end: called at the end of every epoch.
step_begin: called at the beginning of every batch.
step_end: called at the end of every batch.
begin: called at the beginning of model train/eval.
end: called at the end of model train/eval.
epoch_begin (Function): called at the beginning of every epoch.
epoch_end (Function): called at the end of every epoch.
step_begin (Function): called at the beginning of every batch.
step_end (Function): called at the end of every batch.
begin (Function): called at the beginning of model train/eval.
end (Function): called at the end of model train/eval.
Example:
>>> from mindspore import Model, nn

View File

@ -940,7 +940,8 @@ class Model:
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
if callbacks and (isinstance(callbacks, History) or any(isinstance(cb, History) for cb in callbacks)):
if (isinstance(callbacks, Callback) and isinstance(callbacks, History)) or \
(isinstance(callbacks, list) and any(isinstance(cb, History) for cb in callbacks)):
logger.warning("History callback is recommended to be used in training process.")
cb_params = _InternalCallbackParam()