!31983 add has_trained_epoch in history

Merge pull request !31983 from liutongtong9/fix_hisoty
This commit is contained in:
i-robot 2022-03-29 01:36:22 +00:00 committed by Gitee
commit 8616f26a25
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 23 additions and 2 deletions

View File

@ -7,6 +7,14 @@
.. note:: .. note::
通常使用在 `mindspore.Model.train` 中。 通常使用在 `mindspore.Model.train` 中。
**参数:**
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch如果设置了该参数History将监控该数值之后epoch的网络输出信息。默认值0。
**异常:**
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
.. py:method:: begin(run_context) .. py:method:: begin(run_context)
训练开始时初始化History对象的epoch属性。 训练开始时初始化History对象的epoch属性。

View File

@ -16,6 +16,7 @@
import numpy as np import numpy as np
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from ._callback import Callback from ._callback import Callback
@ -31,6 +32,13 @@ class History(Callback):
Note: Note:
Normally used in `mindspore.Model.train`. Normally used in `mindspore.Model.train`.
Args:
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the
network output information after has_trained_epoch's epoch. Default: 0.
Raises:
ValueError: If has_trained_epoch is not an integer or less than zero.
Examples: Examples:
>>> import numpy as np >>> import numpy as np
>>> import mindspore.dataset as ds >>> import mindspore.dataset as ds
@ -49,9 +57,11 @@ class History(Callback):
{'epoch': [1, 2]} {'epoch': [1, 2]}
{'net_output': [1.607877, 1.6033841]} {'net_output': [1.607877, 1.6033841]}
""" """
def __init__(self): def __init__(self, has_trained_epoch=0):
super(History, self).__init__() super(History, self).__init__()
Validator.check_non_negative_int(has_trained_epoch)
self.history = {} self.history = {}
self._has_trained_epoch = has_trained_epoch
def begin(self, run_context): def begin(self, run_context):
""" """
@ -70,7 +80,10 @@ class History(Callback):
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
""" """
cb_params = run_context.original_args() cb_params = run_context.original_args()
epoch = cb_params.get("cur_epoch_num", 1) if "cur_epoch_num" in cb_params:
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch
else:
epoch = 1
self.epoch.get("epoch").append(epoch) self.epoch.get("epoch").append(epoch)
net_output = cb_params.net_outputs net_output = cb_params.net_outputs
if isinstance(net_output, (tuple, list)): if isinstance(net_output, (tuple, list)):