forked from mindspore-Ecosystem/mindspore
remove has_trained_epoch
This commit is contained in:
parent
cb9fffaeeb
commit
cf436e439e
|
@ -7,14 +7,6 @@
|
|||
.. note::
|
||||
通常使用在 `mindspore.Model.train` 中。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,History将监控该数值之后epoch的网络输出信息。默认值:0。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
|
||||
|
||||
.. py:method:: begin(run_context)
|
||||
|
||||
训练开始时初始化History对象的epoch属性。
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator
|
||||
from ._callback import Callback
|
||||
|
||||
|
||||
|
@ -32,13 +31,6 @@ class History(Callback):
|
|||
Note:
|
||||
Normally used in `mindspore.Model.train`.
|
||||
|
||||
Args:
|
||||
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the
|
||||
network output information after has_trained_epoch's epoch. Default: 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If has_trained_epoch is not an integer or less than zero.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -57,11 +49,9 @@ class History(Callback):
|
|||
{'epoch': [1, 2]}
|
||||
{'net_output': [1.607877, 1.6033841]}
|
||||
"""
|
||||
def __init__(self, has_trained_epoch=0):
|
||||
def __init__(self):
|
||||
super(History, self).__init__()
|
||||
Validator.check_non_negative_int(has_trained_epoch)
|
||||
self.history = {}
|
||||
self._has_trained_epoch = has_trained_epoch
|
||||
|
||||
def begin(self, run_context):
|
||||
"""
|
||||
|
@ -80,10 +70,7 @@ class History(Callback):
|
|||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
if "cur_epoch_num" in cb_params:
|
||||
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch
|
||||
else:
|
||||
epoch = 1
|
||||
epoch = cb_params.get("cur_epoch_num", 1)
|
||||
self.epoch.get("epoch").append(epoch)
|
||||
net_output = cb_params.net_outputs
|
||||
if isinstance(net_output, (tuple, list)):
|
||||
|
|
Loading…
Reference in New Issue