!30961 add comments for history and lambda callbacks
Merge pull request !30961 from liutongtong9/code_docs_cbcomment
This commit is contained in:
commit
2484bf4812
|
@ -0,0 +1,24 @@
|
|||
.. py:class:: mindspore.train.callback.History()
|
||||
|
||||
将网络输出的相关信息记录到 `History` 对象中。
|
||||
|
||||
如果用户不自定义训练网络或评估网络,记录的内容将为损失值;如果用户自定义了训练网络/评估网络,如果定义的网络返回 `Tensor` 或 `numpy.ndarray`,则记录此返回值均值,如果返回 `tuple` 或 `list`,则记录第一个元素。
|
||||
|
||||
.. note::
|
||||
通常使用在 `mindspore.Model.train` 中。
|
||||
|
||||
.. py:method:: begin(run_context)
|
||||
|
||||
训练开始时初始化History对象的epoch属性。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
|
||||
.. py:method:: epoch_end(run_context)
|
||||
|
||||
epoch结束时记录网络输出的相关信息。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
|
@ -0,0 +1,16 @@
|
|||
.. py:class:: mindspore.train.callback.LambdaCallback()
|
||||
|
||||
用于自定义简单的callback。
|
||||
|
||||
使用匿名函数构建callback,定义的匿名函数将在 `mindspore.Model.{train | eval}` 的对应阶段被调用。
|
||||
|
||||
请注意,callback的每个阶段都需要一个位置参数:`run_context`。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **epoch_begin** (Function) - 每个epoch开始时被调用。
|
||||
- **epoch_end** (Function) - 每个epoch结束时被调用。
|
||||
- **step_begin** (Function) - 每个step开始时被调用。
|
||||
- **step_end** (Function) - 每个step结束时被调用。
|
||||
- **begin** (Function) - 模型训练、评估开始时被调用。
|
||||
- **end** (Function) - 模型训练、评估结束时被调用。
|
|
@ -20,10 +20,12 @@ from ._callback import Callback
|
|||
|
||||
class History(Callback):
|
||||
"""
|
||||
Records the first element of network outputs into a `History` object.
|
||||
Records the network outputs information into a `History` object.
|
||||
|
||||
The first element of network outputs is the loss value if not
|
||||
custimizing the train network or eval network.
|
||||
The network outputs information will be the loss value if not custimizing the train network or eval network;
|
||||
if the custimized network returns a `Tensor` or `numpy.ndarray`, the mean value of network output
|
||||
will be recorded, if the custimized network returns a `tuple` or `list`, the first element of network
|
||||
outputs will be recorded.
|
||||
|
||||
Note:
|
||||
Normally used in `mindspore.Model.train`.
|
||||
|
|
|
@ -26,12 +26,12 @@ class LambdaCallback(Callback):
|
|||
Note that each stage of callbacks expects one positional arguments: `run_context`.
|
||||
|
||||
Args:
|
||||
epoch_begin: called at the beginning of every epoch.
|
||||
epoch_end: called at the end of every epoch.
|
||||
step_begin: called at the beginning of every batch.
|
||||
step_end: called at the end of every batch.
|
||||
begin: called at the beginning of model train/eval.
|
||||
end: called at the end of model train/eval.
|
||||
epoch_begin (Function): called at the beginning of every epoch.
|
||||
epoch_end (Function): called at the end of every epoch.
|
||||
step_begin (Function): called at the beginning of every batch.
|
||||
step_end (Function): called at the end of every batch.
|
||||
begin (Function): called at the beginning of model train/eval.
|
||||
end (Function): called at the end of model train/eval.
|
||||
|
||||
Example:
|
||||
>>> from mindspore import Model, nn
|
||||
|
|
|
@ -940,7 +940,8 @@ class Model:
|
|||
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
|
||||
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
|
||||
|
||||
if callbacks and (isinstance(callbacks, History) or any(isinstance(cb, History) for cb in callbacks)):
|
||||
if (isinstance(callbacks, Callback) and isinstance(callbacks, History)) or \
|
||||
(isinstance(callbacks, list) and any(isinstance(cb, History) for cb in callbacks)):
|
||||
logger.warning("History callback is recommended to be used in training process.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
|
|
Loading…
Reference in New Issue