fix hisotry and lambda callback
This commit is contained in:
parent
0373d2f915
commit
fa3e2938d2
|
@ -2,7 +2,7 @@
|
|||
|
||||
将网络输出的相关信息记录到 `History` 对象中。
|
||||
|
||||
如果用户不自定义训练网络或评估网络,记录的内容将为损失值;如果用户自定义了训练网络/评估网络,如果定义的网络返回 `Tensor` 或 `numpy.ndarray`,则记录此返回值均值,如果返回 `tuple` 或 `list`,则记录第一个元素。
|
||||
用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 `Tensor` 或 `numpy.ndarray`,则记录此返回值均值,如果返回 `tuple` 或 `list`,则记录第一个元素。
|
||||
|
||||
.. note::
|
||||
通常使用在 `mindspore.Model.train` 中。
|
||||
|
|
|
@ -1285,6 +1285,7 @@ class Cell(Cell_):
|
|||
... out = self.conv(x)
|
||||
... return out
|
||||
>>> names = []
|
||||
>>> n = Net()
|
||||
>>> for m in n.cells_and_names():
|
||||
... if m[0]:
|
||||
... names.append(m[0])
|
||||
|
|
|
@ -31,6 +31,9 @@ class History(Callback):
|
|||
Normally used in `mindspore.Model.train`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> from mindspore.train.callback import History
|
||||
>>> from mindspore import Model, nn
|
||||
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
|
||||
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
|
||||
|
@ -42,7 +45,7 @@ class History(Callback):
|
|||
>>> model.train(2, train_dataset, callbacks=[history_cb])
|
||||
>>> print(history_cb.epoch)
|
||||
>>> print(history_cb.history)
|
||||
[1, 2]
|
||||
{'epoch': [1, 2]}
|
||||
{'net_output': [1.607877, 1.6033841]}
|
||||
"""
|
||||
def __init__(self):
|
||||
|
@ -54,20 +57,20 @@ class History(Callback):
|
|||
Initialize the `epoch` property at the begin of training.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the `mindspore.Model.train/eval`.
|
||||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
|
||||
"""
|
||||
self.epoch = []
|
||||
self.epoch = {"epoch": []}
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""
|
||||
Records the first element of network outputs at the end of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the `mindspore.Model.train/eval`.
|
||||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
epoch = cb_params.get("cur_epoch_num", 1)
|
||||
self.epoch.append(epoch)
|
||||
self.epoch.get("epoch").append(epoch)
|
||||
net_output = cb_params.net_outputs
|
||||
if isinstance(net_output, (tuple, list)):
|
||||
if isinstance(net_output[0], Tensor) and isinstance(net_output[0].asnumpy(), np.ndarray):
|
||||
|
|
|
@ -33,7 +33,10 @@ class LambdaCallback(Callback):
|
|||
begin (Function): called at the beginning of model train/eval.
|
||||
end (Function): called at the end of model train/eval.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> from mindspore.train.callback import History
|
||||
>>> from mindspore import Model, nn
|
||||
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
|
||||
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
|
||||
|
|
|
@ -27,7 +27,7 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
|
|||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from .._checkparam import check_input_data, check_output_data, Validator
|
||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, History
|
||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
|
||||
|
@ -940,10 +940,6 @@ class Model:
|
|||
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
|
||||
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
|
||||
|
||||
if (isinstance(callbacks, Callback) and isinstance(callbacks, History)) or \
|
||||
(isinstance(callbacks, list) and any(isinstance(cb, History) for cb in callbacks)):
|
||||
logger.warning("History callback is recommended to be used in training process.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
|
|
Loading…
Reference in New Issue