remove has_trained_epoch

This commit is contained in:
liutongtong 2022-04-11 16:07:25 +08:00
parent cb9fffaeeb
commit cf436e439e
2 changed files with 2 additions and 23 deletions

View File

@ -7,14 +7,6 @@
.. note::
通常使用在 `mindspore.Model.train` 中。
**参数:**
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch如果设置了该参数History将监控该数值之后epoch的网络输出信息。默认值0。
**异常:**
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
.. py:method:: begin(run_context)
训练开始时初始化History对象的epoch属性。

View File

@ -16,7 +16,6 @@
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from ._callback import Callback
@ -32,13 +31,6 @@ class History(Callback):
Note:
Normally used in `mindspore.Model.train`.
Args:
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the
network output information after has_trained_epoch's epoch. Default: 0.
Raises:
ValueError: If has_trained_epoch is not an integer or less than zero.
Examples:
>>> import numpy as np
>>> import mindspore.dataset as ds
@ -57,11 +49,9 @@ class History(Callback):
{'epoch': [1, 2]}
{'net_output': [1.607877, 1.6033841]}
"""
def __init__(self, has_trained_epoch=0):
def __init__(self):
super(History, self).__init__()
Validator.check_non_negative_int(has_trained_epoch)
self.history = {}
self._has_trained_epoch = has_trained_epoch
def begin(self, run_context):
"""
@ -80,10 +70,7 @@ class History(Callback):
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
"""
cb_params = run_context.original_args()
if "cur_epoch_num" in cb_params:
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch
else:
epoch = 1
epoch = cb_params.get("cur_epoch_num", 1)
self.epoch.get("epoch").append(epoch)
net_output = cb_params.net_outputs
if isinstance(net_output, (tuple, list)):