fix history and lambda callback

liutongtong 2022-03-16 12:15:15 +08:00
parent 0373d2f915
commit fa3e2938d2
5 changed files with 15 additions and 12 deletions

View File

@ -2,7 +2,7 @@
Records relevant information from the network outputs into a `History` object.
If the user does not customize the training network or the evaluation network, the recorded content will be the loss value; if the user customizes the training/evaluation network, then when the defined network returns a `Tensor` or `numpy.ndarray`, the mean of the return value is recorded, and when it returns a `tuple` or `list`, the first element is recorded.
When the user does not customize the training network or the evaluation network, the recorded content will be the loss value; when the user customizes the training/evaluation network, if the defined network returns a `Tensor` or `numpy.ndarray`, the mean of the return value is recorded, and if it returns a `tuple` or `list`, the first element is recorded.
.. note::
    Normally used in `mindspore.Model.train`.
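
In other words, the recording rule described by this docstring reduces to the following (an illustrative Python sketch of the documented behavior, not the library source; `net_output` stands for whatever the training/evaluation network returns):

    import numpy as np
    from mindspore import Tensor

    def recorded_value(net_output):
        # tuple/list: the first element is recorded
        if isinstance(net_output, (tuple, list)):
            return net_output[0]
        # Tensor: record the mean of its numpy view
        if isinstance(net_output, Tensor):
            return net_output.asnumpy().mean()
        # numpy.ndarray (or a plain loss scalar): record the mean
        return np.mean(net_output)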

View File

@ -1285,6 +1285,7 @@ class Cell(Cell_):
... out = self.conv(x)
... return out
>>> names = []
>>> n = Net()
>>> for m in n.cells_and_names():
... if m[0]:
... names.append(m[0])
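
The one-line fix here adds the missing instantiation (`n = Net()`) that the loop relies on. For reference, a self-contained version of the repaired example (the `Conv2d` arguments are assumptions for illustration):

    import mindspore.nn as nn

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(3, 64, 3)  # assumed channel/kernel settings
        def construct(self, x):
            out = self.conv(x)
            return out

    names = []
    n = Net()  # the line this commit adds
    for m in n.cells_and_names():
        if m[0]:  # skip the unnamed root cell
            names.append(m[0])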

View File

@ -31,6 +31,9 @@ class History(Callback):
Normally used in `mindspore.Model.train`.
Examples:
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore.train.callback import History
>>> from mindspore import Model, nn
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
@ -42,7 +45,7 @@ class History(Callback):
>>> model.train(2, train_dataset, callbacks=[history_cb])
>>> print(history_cb.epoch)
>>> print(history_cb.history)
[1, 2]
{'epoch': [1, 2]}
{'net_output': [1.607877, 1.6033841]}
"""
def __init__(self):
@ -54,20 +57,20 @@ class History(Callback):
Initialize the `epoch` property at the beginning of training.
Args:
run_context (RunContext): Context of the `mindspore.Model.train/eval`.
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
"""
self.epoch = []
self.epoch = {"epoch": []}
def epoch_end(self, run_context):
"""
Records the first element of network outputs at the end of epoch.
Args:
run_context (RunContext): Context of the `mindspore.Model.train/eval`.
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
"""
cb_params = run_context.original_args()
epoch = cb_params.get("cur_epoch_num", 1)
self.epoch.append(epoch)
self.epoch.get("epoch").append(epoch)
net_output = cb_params.net_outputs
if isinstance(net_output, (tuple, list)):
if isinstance(net_output[0], Tensor) and isinstance(net_output[0].asnumpy(), np.ndarray):
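
This is the shape change that turns the printed `epoch` from `[1, 2]` into `{'epoch': [1, 2]}`: `begin` now seeds a dict, and `epoch_end` appends into its `"epoch"` list. A minimal standalone illustration of the new bookkeeping (names match the diff; everything else is elided):

    epoch = {"epoch": []}         # begin(): a dict instead of a bare list
    epoch.get("epoch").append(1)  # epoch_end(): same as epoch["epoch"].append(1)
    epoch.get("epoch").append(2)
    print(epoch)                  # {'epoch': [1, 2]}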

View File

@ -33,7 +33,10 @@ class LambdaCallback(Callback):
begin (Function): called at the beginning of model train/eval.
end (Function): called at the end of model train/eval.
Example:
Examples:
>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore.train.callback import LambdaCallback
>>> from mindspore import Model, nn
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
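
The hunk ends before the callback itself is constructed; a plausible continuation, sketched only from the `begin`/`end` hooks documented above (the network, loss, optimizer, and printed messages are assumptions, not part of this commit):

    net = nn.Dense(10, 5)
    crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    model = Model(network=net, optimizer=opt, loss_fn=crit)
    lambda_cb = LambdaCallback(
        begin=lambda run_context: print("train begin"),
        end=lambda run_context: print("train end"))
    model.train(2, train_dataset, callbacks=[lambda_cb])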

View File

@ -27,7 +27,7 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, History
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
@ -940,10 +940,6 @@ class Model:
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
if (isinstance(callbacks, Callback) and isinstance(callbacks, History)) or \
(isinstance(callbacks, list) and any(isinstance(cb, History) for cb in callbacks)):
logger.warning("History callback is recommended to be used in training process.")
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset