From 538765694e2da05cd18040ef5501e45b7f627878 Mon Sep 17 00:00:00 2001 From: liutongtong Date: Sat, 26 Mar 2022 12:29:39 +0800 Subject: [PATCH] add args in history --- .../train/mindspore.train.callback.History.rst | 8 ++++++++ .../python/mindspore/train/callback/_history.py | 17 +++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/api/api_python/train/mindspore.train.callback.History.rst b/docs/api/api_python/train/mindspore.train.callback.History.rst index 0c4133349b3..724ff705441 100644 --- a/docs/api/api_python/train/mindspore.train.callback.History.rst +++ b/docs/api/api_python/train/mindspore.train.callback.History.rst @@ -7,6 +7,14 @@ .. note:: 通常使用在 `mindspore.Model.train` 中。 + **参数:** + + - **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,History将监控该数值之后epoch的网络输出信息。默认值:0。 + + **异常:** + + - **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。 + .. py:method:: begin(run_context) 训练开始时初始化History对象的epoch属性。 diff --git a/mindspore/python/mindspore/train/callback/_history.py b/mindspore/python/mindspore/train/callback/_history.py index aa30f371dbe..c84679604c4 100644 --- a/mindspore/python/mindspore/train/callback/_history.py +++ b/mindspore/python/mindspore/train/callback/_history.py @@ -16,6 +16,7 @@ import numpy as np from mindspore.common.tensor import Tensor +from mindspore._checkparam import Validator from ._callback import Callback @@ -31,6 +32,13 @@ class History(Callback): Note: Normally used in `mindspore.Model.train`. + Args: + has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the + network output information after has_trained_epoch's epoch. Default: 0. + + Raises: + ValueError: If has_trained_epoch is not an integer or less than zero. + Examples: >>> import numpy as np >>> import mindspore.dataset as ds @@ -49,9 +57,11 @@ class History(Callback): {'epoch': [1, 2]} {'net_output': [1.607877, 1.6033841]} """ - def __init__(self): + def __init__(self, has_trained_epoch=0): super(History, self).__init__() + Validator.check_non_negative_int(has_trained_epoch) self.history = {} + self._has_trained_epoch = has_trained_epoch def begin(self, run_context): """ @@ -70,7 +80,10 @@ class History(Callback): run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. """ cb_params = run_context.original_args() - epoch = cb_params.get("cur_epoch_num", 1) + if "cur_epoch_num" in cb_params: + epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch + else: + epoch = 1 self.epoch.get("epoch").append(epoch) net_output = cb_params.net_outputs if isinstance(net_output, (tuple, list)):