!31983 add has_trained_epoch in history
Merge pull request !31983 from liutongtong9/fix_hisoty
This commit is contained in:
commit
8616f26a25
|
@ -7,6 +7,14 @@
|
||||||
.. note::
|
.. note::
|
||||||
通常使用在 `mindspore.Model.train` 中。
|
通常使用在 `mindspore.Model.train` 中。
|
||||||
|
|
||||||
|
**参数:**
|
||||||
|
|
||||||
|
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,History将监控该数值之后epoch的网络输出信息。默认值:0。
|
||||||
|
|
||||||
|
**异常:**
|
||||||
|
|
||||||
|
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
|
||||||
|
|
||||||
.. py:method:: begin(run_context)
|
.. py:method:: begin(run_context)
|
||||||
|
|
||||||
训练开始时初始化History对象的epoch属性。
|
训练开始时初始化History对象的epoch属性。
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore._checkparam import Validator
|
||||||
from ._callback import Callback
|
from ._callback import Callback
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,6 +32,13 @@ class History(Callback):
|
||||||
Note:
|
Note:
|
||||||
Normally used in `mindspore.Model.train`.
|
Normally used in `mindspore.Model.train`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the
|
||||||
|
network output information after has_trained_epoch's epoch. Default: 0.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If has_trained_epoch is not an integer or less than zero.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import numpy as np
|
>>> import numpy as np
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataset as ds
|
||||||
|
@ -49,9 +57,11 @@ class History(Callback):
|
||||||
{'epoch': [1, 2]}
|
{'epoch': [1, 2]}
|
||||||
{'net_output': [1.607877, 1.6033841]}
|
{'net_output': [1.607877, 1.6033841]}
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self, has_trained_epoch=0):
|
||||||
super(History, self).__init__()
|
super(History, self).__init__()
|
||||||
|
Validator.check_non_negative_int(has_trained_epoch)
|
||||||
self.history = {}
|
self.history = {}
|
||||||
|
self._has_trained_epoch = has_trained_epoch
|
||||||
|
|
||||||
def begin(self, run_context):
|
def begin(self, run_context):
|
||||||
"""
|
"""
|
||||||
|
@ -70,7 +80,10 @@ class History(Callback):
|
||||||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
|
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
|
||||||
"""
|
"""
|
||||||
cb_params = run_context.original_args()
|
cb_params = run_context.original_args()
|
||||||
epoch = cb_params.get("cur_epoch_num", 1)
|
if "cur_epoch_num" in cb_params:
|
||||||
|
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch
|
||||||
|
else:
|
||||||
|
epoch = 1
|
||||||
self.epoch.get("epoch").append(epoch)
|
self.epoch.get("epoch").append(epoch)
|
||||||
net_output = cb_params.net_outputs
|
net_output = cb_params.net_outputs
|
||||||
if isinstance(net_output, (tuple, list)):
|
if isinstance(net_output, (tuple, list)):
|
||||||
|
|
Loading…
Reference in New Issue