!35380 add EarlyStopping and ReduceLROnPlateau callbacks
Merge pull request !35380 from liutongtong9/add_cbs
Commit: fd50ac5dbf
@@ -79,11 +79,13 @@ mindspore

    mindspore.Callback
    mindspore.CheckpointConfig
+   mindspore.EarlyStopping
    mindspore.History
    mindspore.LambdaCallback
    mindspore.LearningRateScheduler
    mindspore.LossMonitor
    mindspore.ModelCheckpoint
+   mindspore.ReduceLROnPlateau
    mindspore.RunContext
    mindspore.TimeMonitor

@@ -47,19 +47,19 @@ mindspore.Callback

.. py:method:: step_begin(run_context)

-   Called before each step begins.
+   Called before each step begins. Compatible with the `on_train_step_begin` and `on_eval_step_begin` methods.

    **Parameters:**

-   - **run_context** (RunContext) - Contains some basic information of the model. Compatible with the `on_train_step_begin` and `on_eval_step_begin` methods.
+   - **run_context** (RunContext) - Contains some basic information of the model.

.. py:method:: step_end(run_context)

-   Called after each step is finished.
+   Called after each step is finished. Compatible with the `on_train_step_end` and `on_eval_step_end` methods.

    **Parameters:**

-   - **run_context** (RunContext) - Contains some basic information of the model. Compatible with the `on_train_step_end` and `on_eval_step_end` methods.
+   - **run_context** (RunContext) - Contains some basic information of the model.

.. py:method:: on_train_begin(run_context)

@@ -69,7 +69,7 @@ mindspore.Callback

    - **run_context** (RunContext) - Contains some basic information of the model.

-.. py:method::on_train_end(run_context)
+.. py:method:: on_train_end(run_context)

    Called at the end of network training.

@@ -133,7 +133,7 @@ mindspore.Callback

    - **run_context** (RunContext) - Contains some basic information of the model.

-.. py:method:: on_evalepoch_end(run_context)
+.. py:method:: on_eval_epoch_end(run_context)

    Called at the end of an evaluation epoch.
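To make the compatibility note above concrete, here is a minimal sketch (not part of this change) of a legacy callback that only overrides `step_end`; assuming the compatibility works as documented, it keeps firing once the training loop dispatches the newer per-step hooks.

```python
# Sketch only: assumes step_end is still dispatched via the
# on_train_step_end / on_eval_step_end compatibility path described above.
from mindspore.train.callback import Callback

class StepCounter(Callback):
    """Counts executed steps using the legacy step_end hook."""

    def __init__(self):
        super(StepCounter, self).__init__()
        self.steps = 0

    def step_end(self, run_context):
        # Incremented for every training (and evaluation) step.
        self.steps += 1
```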
@ -0,0 +1,47 @@
|
|||
mindspore.EarlyStopping
|
||||
================================
|
||||
|
||||
.. py:class:: mindspore.EarlyStopping(monitor='eval_loss', min_delta=0, patience=0, verbose=False, mode='auto', baseline=None, restore_best_weights=False)
|
||||
|
||||
当监控的指标停止改进时停止训练。
|
||||
|
||||
假设 `monitor` 是"accuracy",那么,`mode` 将为"max",因为训练的目标是准确率的提高,`model.fit()` 边训练边验证场景下,将记录 `monitor` 的变化,当在 `patience` 个epoch范围内指标效果变好的程度没有超过 `min_delta` 时,将调用 `run_context.request_stop()` 方法来终止训练。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **monitor** (str) - 监控指标。如果是边训练边推理场景,合法的monitor配置值可以为"loss", "eval_loss"以及实例化 `Model` 时传入的metric名称;如果在训练时不做推理,合法的monitor配置值为"loss"。当monitor为"loss"时,如果训练网络有多个输出,默认取第一个值为训练损失值。默认值:"eval_loss"。
|
||||
- **min_delta** (float) - `monitor` 指标变化的最小阈值,超过此阈值才视为 `monitor` 的变化。默认值:0。
|
||||
- **patience** (int) - `moniter` 相对历史最优值变好超过 `min_delta` 视为当前epoch的模型效果有所改善,`patience` 为等待的无改善epoch的数量。默认值:0。
|
||||
- **verbose** (bool) - 是否打印相关信息。默认值:False。
|
||||
- **mode** (str) - `{'auto', 'min', 'max'}`中的一种,'min'模式下将在指标不再减小时执行早停,'max'模式下将在指标不再增大时执行早停,'auto'模式将根据当前 `monitor` 指标的特点自动设置。默认值:"auto"。
|
||||
- **baseline** (float) - 模型效果的基线,当前 `moniter` 相对历史最优值变好且好于 `baseline` 时,内部的等待epoch计数器被清零。默认值:0。
|
||||
- **restore_best_weights** (bool) - 是否自动保存最优模型的权重。默认值:False。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 当 `mode` 不在`{'auto', 'min', 'max'}`中。
|
||||
- **ValueError** - 当传入的 `monitor` 返回值不是标量。
|
||||
|
||||
.. py:method:: on_train_begin(run_context)
|
||||
|
||||
训练开始时初始化相关的变量。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
训练过程中,若监控指标在等待 `patience` 个epoch后仍没有改善,则停止训练。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
|
||||
.. py:method:: on_train_end(run_context)
|
||||
|
||||
打印是第几个epoch执行早停。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
|
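A minimal, self-contained usage sketch (editor's illustration, not part of the patch): the toy dataset and network below are assumptions made purely so the example runs end to end, modelled on the unit tests added later in this PR.

```python
import numpy as np
import mindspore.dataset as ds
from mindspore import nn, Model
from mindspore.train.callback import EarlyStopping

def generate(num=1024):
    """Yield (feature, one-hot label) pairs for a toy 5-class problem."""
    for _ in range(num):
        x = np.random.uniform(-5.0, 5.0)
        y = np.zeros(shape=(5,), dtype=np.float32)
        y[min(int(abs(x)), 4)] = 1
        yield np.array([x], dtype=np.float32), y

train_set = ds.GeneratorDataset(list(generate()), column_names=['data', 'label']).batch(64)
eval_set = ds.GeneratorDataset(list(generate(256)), column_names=['data', 'label']).batch(64)

net = nn.Dense(1, 5)
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"acc"})

# Stop as soon as eval_loss has not improved for 2 consecutive epochs.
cb = EarlyStopping(monitor="eval_loss", patience=2, verbose=True)
model.fit(10, train_set, eval_set, callbacks=cb)
```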
@@ -0,0 +1,44 @@
mindspore.ReduceLROnPlateau
================================

.. py:class:: mindspore.ReduceLROnPlateau(monitor='eval_loss', factor=0.1, patience=10, verbose=False, mode='auto', min_delta=1e-4, cooldown=0, min_lr=0)

    Reduce the learning rate when `monitor` has stopped improving.

    Models often benefit from changing the learning rate. This callback monitors the training process; when the metric has not improved by more than `min_delta` within `patience` epochs, the learning rate is reduced according to `factor`.

    .. note::
        Learning rate grouping is not supported yet.

    **Parameters:**

    - **monitor** (str) - The metric to be monitored. In the train-while-evaluate scenario, valid values of monitor are "loss", "eval_loss" and the metric names passed when instantiating `Model`; if no evaluation is performed during training, the only valid value is "loss". When monitor is "loss" and the train network has multiple outputs, the first output is taken as the training loss by default. Default: "eval_loss".
    - **factor** (float) - Factor by which the learning rate is changed, in the range (0, 1). Default: 0.1.
    - **patience** (int) - A `monitor` value that is better than the historical best value by more than `min_delta` counts as an improvement for the current epoch; `patience` is the number of epochs with no improvement to wait for. Default: 10.
    - **verbose** (bool) - Whether to print related information. Default: False.
    - **mode** (str) - One of `{'auto', 'min', 'max'}`. In 'min' mode, the learning rate is changed when the metric stops decreasing; in 'max' mode, when the metric stops increasing; in 'auto' mode the direction is set automatically according to the current `monitor` metric. Default: "auto".
    - **min_delta** (float) - Minimum threshold of change of the `monitor` value; only a change larger than this threshold is counted as a change of `monitor`. Default: 1e-4.
    - **cooldown** (int) - Number of epochs in which no action is taken after the learning rate has been reduced. Default: 0.
    - **min_lr** (float) - Lower bound of the learning rate. Default: 0.

    **Exceptions:**

    - **ValueError** - If `mode` is not in `{'auto', 'min', 'max'}`.
    - **ValueError** - In the grouped or dynamic learning rate scenario, if the obtained learning rate is not of Parameter type.
    - **ValueError** - If the value returned by `monitor` is not a scalar.

    .. py:method:: on_train_begin(run_context)

        Initialize related variables at the beginning of training.

        **Parameters:**

        - **run_context** (RunContext) - Context information of the model. For details, refer to :class:`mindspore.RunContext`.

    .. py:method:: on_train_epoch_end(run_context)

        During training, change the learning rate if the monitored metric has not improved after waiting for `patience` epochs.

        **Parameters:**

        - **run_context** (RunContext) - Context information of the model. For details, refer to :class:`mindspore.RunContext`.
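A short usage sketch (editor's illustration, reusing the hypothetical `model`, `train_set` and `eval_set` defined in the EarlyStopping example above):

```python
from mindspore.train.callback import ReduceLROnPlateau

# Halve the learning rate whenever eval_loss has not improved for 3 epochs,
# then wait 2 epochs (cooldown) before monitoring again; never go below 1e-4.
cb = ReduceLROnPlateau(monitor="eval_loss", factor=0.5, patience=3,
                       cooldown=2, min_lr=1e-4, verbose=True)
model.fit(20, train_set, eval_set, callbacks=cb)
```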
@@ -30,26 +30,26 @@ DataType
--------

.. class:: mindspore.dtype

    Create a data type object of MindSpore.

    The actual path of ``dtype`` is ``/mindspore/common/dtype.py``.
    Run the following command to import the package:

    .. code-block::

        from mindspore import dtype as mstype

    * **Numeric Type**

      Currently, MindSpore supports ``Int`` type, ``Uint`` type, ``Float`` type and ``Complex`` type.
      The following table lists the details.

      ============================================== =============================
      Definition                                     Description
      ============================================== =============================
      ``mindspore.int8`` , ``mindspore.byte``        8-bit integer
      ``mindspore.int16`` , ``mindspore.short``      16-bit integer
      ``mindspore.int32`` , ``mindspore.intc``       32-bit integer
      ``mindspore.int64`` , ``mindspore.intp``       64-bit integer
      ``mindspore.uint8`` , ``mindspore.ubyte``      unsigned 8-bit integer
@@ -62,11 +62,11 @@ DataType
      ``mindspore.complex64``                        64-bit complex number
      ``mindspore.complex128``                       128-bit complex number
      ============================================== =============================

    * **Other Type**

      For other defined types, see the following table.

      ============================ =================
      Type                         Description
      ============================ =================
@@ -85,14 +85,14 @@ DataType
      ``symbolic_key``             The value of a variable is used as a key of the variable in ``env_type`` .
      ``env_type``                 Used to store the gradient of the free variable of a function, where the key is the ``symbolic_key`` of the free variable's node and the value is the gradient.
      ============================ =================

    * **Tree Topology**

      The relationships of the above types are as follows:

      .. code-block::


          └─────── number
          │        ├─── bool_
          │        ├─── int_
@@ -194,11 +194,13 @@ Callback

    mindspore.Callback
    mindspore.CheckpointConfig
+   mindspore.EarlyStopping
    mindspore.History
    mindspore.LambdaCallback
    mindspore.LearningRateScheduler
    mindspore.LossMonitor
    mindspore.ModelCheckpoint
+   mindspore.ReduceLROnPlateau
    mindspore.RunContext
    mindspore.TimeMonitor

@@ -26,7 +26,8 @@ from .serialization import save_checkpoint, load_checkpoint, load_param_into_net
    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status,\
    restore_group_info_list
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, CheckpointConfig, \
-    RunContext, LearningRateScheduler, SummaryLandscape, FederatedLearningManager, History, LambdaCallback
+    RunContext, LearningRateScheduler, SummaryLandscape, FederatedLearningManager, History, LambdaCallback, \
+    ReduceLROnPlateau, EarlyStopping
from .summary import SummaryRecord
from .train_thor import ConvertNetUtils, ConvertModelUtils
@@ -31,7 +31,10 @@ from ._landscape import SummaryLandscape
from ._fl_manager import FederatedLearningManager
from ._history import History
from ._lambda_callback import LambdaCallback
+from ._early_stop import EarlyStopping
+from ._reduce_lr_on_plateau import ReduceLROnPlateau

__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
           "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
-           "FederatedLearningManager", "History", "LambdaCallback"]
+           "FederatedLearningManager", "History", "LambdaCallback", "ReduceLROnPlateau", "EarlyStopping"]
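With these re-exports in place, both callbacks should be importable from the documented public namespaces; a quick sanity check (editor's sketch, assuming a build that contains this change):

```python
# Both import paths below are expected to resolve to the same classes.
from mindspore.train.callback import EarlyStopping, ReduceLROnPlateau
from mindspore import EarlyStopping as ES, ReduceLROnPlateau as RLR

print(ES is EarlyStopping, RLR is ReduceLROnPlateau)  # True True
```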
@@ -0,0 +1,232 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EarlyStopping Callback class."""

import copy
import numpy as np

from mindspore import ops, nn
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from mindspore.train.serialization import load_param_into_net

from mindspore import log as logger
from mindspore.ops import ReduceOp
from mindspore.communication import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._callback import Callback


_smaller_better_metrics = ['hausdorff_distance', 'mae', 'mse', 'loss', 'perplexity',
                           'mean_surface_distance', 'root_mean_square_distance', 'eval_loss']


class EarlyStopping(Callback):
    """
    Stop training when a monitored metric has stopped improving.

    Assuming `monitor` is "accuracy", `mode` would be "max" since the
    goal of training is to maximize the accuracy. The `model.fit()` training
    loop will check at the end of each epoch whether the accuracy is no longer
    increasing, considering the `min_delta` and `patience` if applicable.
    Once it is found to be no longer increasing, `run_context.request_stop()`
    will be called and the training terminates.

    Args:
        monitor (str): quantity to be monitored. If evaluation is performed on
            the end of train epochs, the valid monitors can be "loss",
            "eval_loss" or metric names passed when instantiate the `Model`;
            otherwise the valid monitor is "loss".
            When monitor is "loss", if train network has multiple outputs,
            the first element will be returned as training loss.
            Default: "eval_loss".
        patience (int): `monitor` value is better than history best value over
            `min_delta` is seen as improvement, `patience` is number of epochs
            with no improvement after which the
            training process will be stopped. Default: 0.
        verbose (bool): If False: quiet, if True: print related information.
            Default: False.
        mode (str): one of `{'auto', 'min', 'max'}`. In "min" mode,
            training will be stopped when the
            quantity monitored has stopped decreasing; in "max" mode it will be
            stopped when the quantity monitored has stopped increasing; in "auto"
            mode, the direction is automatically inferred from the name of the
            monitored quantity. Default: "auto".
        min_delta (float): threshold for measuring the new optimum, to only focus on
            significant changes. Default: 0.
        baseline (float): Baseline value for the monitor. When the monitor value shows
            improvement over the history best value and the baseline, the internal
            wait counter will be set to zero. Default: None.
        restore_best_weights (bool): Whether to restore model weights from
            the epoch with the best value of the monitored quantity.
            If False, the model weights obtained at the last step of
            training are used. Default: False.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> from mindspore.train.callback import EarlyStopping
        >>> from mindspore import Model, nn
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"acc"})
        >>> data_path = './MNIST_Data'
        >>> dataset = create_dataset(data_path)
        >>> cb = EarlyStopping(monitor="acc", patience=3, verbose=True)
        >>> model.fit(10, dataset, callbacks=cb)
    """

    def __init__(self, monitor='eval_loss', min_delta=0, patience=0,
                 verbose=False, mode='auto', baseline=None, restore_best_weights=False):
        super(EarlyStopping, self).__init__()
        self.monitor = Validator.check_value_type('monitor', monitor, str)
        min_delta = Validator.check_value_type("min_delta", min_delta, [float, int])
        self.min_delta = abs(min_delta)
        self.patience = Validator.check_non_negative_int(patience)
        self.verbose = Validator.check_bool(verbose)
        self.mode = Validator.check_value_type('mode', mode, str)
        self.baseline = Validator.check_value_type("baseline", baseline, [float, int]) if baseline else None
        self.restore_best_weights = Validator.check_bool(restore_best_weights)

        self.wait = 0
        self.stopped_epoch = 0
        self.best_weights_param_dict = None
        self._reduce = ValueReduce()

        if self.mode not in ['auto', 'min', 'max']:
            raise ValueError("mode should be 'auto', 'min' or 'max', but got %s." % self.mode)
        if self.mode == 'min' or (self.mode == 'auto' and self.monitor in _smaller_better_metrics):
            self.is_improvement = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.Inf
        else:
            self.is_improvement = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.Inf

    def on_train_begin(self, run_context):
        """
        Initialize variables at the begin of training.

        Args:
            run_context (RunContext): Context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.mode == 'min' or \
            (self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf
        self.best_weights_param_dict = None

    def on_train_epoch_end(self, run_context):
        """
        Monitors the training process and if no improvement is seen for a 'patience' number
        of epochs, the training process will be stopped.

        Args:
            run_context (RunContext): Context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        cb_params = run_context.original_args()

        cur_epoch = cb_params.get("cur_epoch_num")
        current_value = self._get_monitor_value(cb_params)

        # In distributed scenarios, average the monitor value across devices.
        parallel_mode = auto_parallel_context().get_parallel_mode()
        rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
        current = current_value if rank_size == 1 else \
            self._reduce(Tensor(current_value.astype(np.float32))) / rank_size

        if current is None:
            return
        if current.shape != ():
            raise ValueError("EarlyStopping only supports scalar monitor now.")

        if self.restore_best_weights and self.best_weights_param_dict is None:
            self.best_weights_param_dict = copy.deepcopy(cb_params.train_network.parameters_dict())
        self.wait += 1
        if self.is_improvement(current, self.best):
            self.best = current
            if self.restore_best_weights:
                self.best_weights_param_dict = copy.deepcopy(cb_params.train_network.parameters_dict())
            if self.baseline is None or self.is_improvement(current, self.baseline):
                self.wait = 0

        if self.wait >= self.patience:
            self.stopped_epoch = cur_epoch
            run_context.request_stop()
            if self.restore_best_weights and self.best_weights_param_dict is not None:
                if self.verbose:
                    print('Restoring model weights from the end of the best epoch.')
                load_param_into_net(cb_params.train_network, self.best_weights_param_dict)

    def on_train_end(self, run_context):
        """
        If verbose is True, print the stopped epoch.

        Args:
            run_context (RunContext): Context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        if self.stopped_epoch > 0 and self.verbose:
            print('Epoch %05d: early stopping' % (self.stopped_epoch))

    def _get_monitor_value(self, cb_params):
        """
        Get the monitor value at the end of epoch during training.

        If `mindspore.train.callback.EarlyStopping` is used with `model.train`, there is no evaluation
        during training, so only monitor="loss" is valid; if it is used with `model.fit`, evaluation is
        performed at the end of each epoch and the valid monitors are "loss", "eval_loss" and the
        metrics passed to `Model`.

        Args:
            cb_params (dict): A dictionary stores context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        monitor_value = None
        if self.monitor == "loss":
            loss = cb_params.get("net_outputs")
            if isinstance(loss, (tuple, list)):
                if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                    monitor_value = loss[0]
            if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
                monitor_value = float(np.mean(loss.asnumpy()))
            if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
                logger.warning("Invalid %s.", self.monitor)
                monitor_value = loss
        else:
            monitor_candidates = cb_params.get("eval_results")
            if monitor_candidates:
                monitor_value = monitor_candidates.get(self.monitor)
            if not monitor_value:
                logger.warning('Early stopping is conditioned on %s '
                               'which is not available. Available choices are: %s',
                               self.monitor, ["loss"] + list(monitor_candidates.keys()))
        return np.array(monitor_value) if monitor_value else None


class ValueReduce(nn.Cell):
    """
    Reduces the tensor data across all devices, all devices will get the same final result.
    For more details, please refer to :class:`mindspore.ops.AllReduce`.
    """
    def __init__(self):
        super(ValueReduce, self).__init__()
        self.allreduce = ops.AllReduce(ReduceOp.SUM)

    def construct(self, x):
        return self.allreduce(x).asnumpy()
@@ -0,0 +1,218 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceLROnPlateau Callback class."""

import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore._checkparam import Validator, Rel
from mindspore import log as logger
from mindspore.ops import functional as F
from mindspore import nn, ops
from mindspore.ops import ReduceOp
from mindspore.communication import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._callback import Callback


_smaller_better_metrics = ['hausdorff_distance', 'mae', 'mse', 'loss', 'perplexity',
                           'mean_surface_distance', 'root_mean_square_distance', 'eval_loss']


class ReduceLROnPlateau(Callback):
    """
    Reduce learning rate when the monitor has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This callback monitors the training
    process and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Note:
        Learning rate grouping is not supported now.

    Args:
        monitor (str): quantity to be monitored. If evaluation is performed on
            the end of train epochs, the valid monitors can be "loss",
            "eval_loss" or metric names passed when instantiate the `Model`;
            otherwise the valid monitor is "loss".
            When monitor is "loss", if train network has multiple outputs,
            the first element will be returned as training loss.
            Default: "eval_loss".
        factor (float): factor by which the learning rate will be reduced.
            `new_lr = lr * factor`. Default: 0.1.
        patience (int): `monitor` value is better than history best value over
            `min_delta` is seen as improvement, `patience` is number of epochs
            with no improvement after which learning rate
            will be reduced. Default: 10.
        verbose (bool): If False: quiet, if True: print related information.
            Default: False.
        mode (str): one of `{'auto', 'min', 'max'}`. In "min" mode,
            the learning rate will be reduced when the
            quantity monitored has stopped decreasing; in "max" mode it will be
            reduced when the quantity monitored has stopped increasing; in "auto"
            mode, the direction is automatically inferred from the name of the
            monitored quantity. Default: "auto".
        min_delta (float): threshold for measuring the new optimum, to only focus on
            significant changes. Default: 1e-4.
        cooldown (int): number of epochs to wait before resuming normal operation after
            lr has been reduced. Default: 0.
        min_lr (float): lower bound on the learning rate. Default: 0.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> from mindspore.train.callback import ReduceLROnPlateau
        >>> from mindspore import Model, nn
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"acc"})
        >>> data_path = './MNIST_Data'
        >>> dataset = create_dataset(data_path)
        >>> cb = ReduceLROnPlateau(monitor="acc", patience=3, verbose=True)
        >>> model.fit(10, dataset, callbacks=cb)
    """
    def __init__(self, monitor='eval_loss', factor=0.1, patience=10, verbose=False,
                 mode='auto', min_delta=1e-4, cooldown=0, min_lr=0):
        super(ReduceLROnPlateau, self).__init__()
        self.monitor = Validator.check_value_type('monitor', monitor, str)
        self.factor = Validator.check_float_range(factor, 0.0, 1.0, Rel.INC_NEITHER)
        self.patience = Validator.check_non_negative_int(patience)
        self.verbose = Validator.check_bool(verbose)
        self.mode = Validator.check_value_type('mode', mode, str)
        min_delta = Validator.check_value_type("min_delta", min_delta, [float, int])
        self.min_delta = abs(min_delta)
        self.cooldown = Validator.check_non_negative_int(cooldown)
        self.min_lr = Validator.check_value_type("min_lr", min_lr, [float, int])

        self.cooldown_counter = 0
        self.wait = 0
        self._reduce = ValueReduce()

        if self.mode not in ['auto', 'min', 'max']:
            raise ValueError("mode should be 'auto', 'min' or 'max', but got %s." % self.mode)
        if self.mode == 'min' or (self.mode == 'auto' and self.monitor in _smaller_better_metrics):
            self.is_improvement = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.Inf
        else:
            self.is_improvement = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.Inf

    def on_train_begin(self, run_context):
        """
        Initialize variables at the begin of training.

        Args:
            run_context (RunContext): Context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        self.cooldown_counter = 0
        self.wait = 0
        self.best = np.Inf if self.mode == 'min' or \
            (self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf

    def on_train_epoch_end(self, run_context):
        """
        Monitors the training process and if no improvement is seen for a 'patience' number
        of epochs, the learning rate is reduced.

        Args:
            run_context (RunContext): Context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        cb_params = run_context.original_args()
        cur_lr = cb_params.optimizer.learning_rate
        if not isinstance(cur_lr, Parameter):
            raise ValueError("ReduceLROnPlateau does not support dynamic learning rate and group learning rate now.")

        current_monitor_value = self._get_monitor_value(cb_params)

        # In distributed scenarios, average the monitor value across devices.
        parallel_mode = auto_parallel_context().get_parallel_mode()
        rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
        reduce_monitor_value = current_monitor_value if rank_size == 1 else \
            self._reduce(Tensor(current_monitor_value.astype(np.float32))) / rank_size

        if reduce_monitor_value is None:
            return
        if reduce_monitor_value.shape != ():
            raise ValueError("ReduceLROnPlateau only supports scalar monitor now.")

        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0

        if self.is_improvement(reduce_monitor_value, self.best):
            self.best = reduce_monitor_value
            self.wait = 0
        elif self.cooldown_counter <= 0:
            self.wait += 1
            if self.wait >= self.patience:
                if cur_lr > Tensor(self.min_lr):
                    new_lr = max(cur_lr * self.factor, self.min_lr)
                    F.assign(cb_params.optimizer.learning_rate, Tensor(new_lr))
                    if self.verbose:
                        print('Epoch %05d: ReduceLROnPlateau reducing learning rate to %s.'
                              % (cb_params.cur_epoch_num, new_lr))
                self.cooldown_counter = self.cooldown
                self.wait = 0

    def _get_monitor_value(self, cb_params):
        """
        Get the monitor value at the end of epoch during training.

        If `mindspore.train.callback.ReduceLROnPlateau` is used with `model.train`, there is no evaluation
        during training, so only monitor="loss" is valid; if it is used with `model.fit`, evaluation is
        performed at the end of each epoch and the valid monitors are "loss", "eval_loss" and the
        metrics passed to `Model`.

        Args:
            cb_params (dict): A dictionary stores context information of the model. For more details,
                please refer to :class:`mindspore.RunContext`.
        """
        monitor_value = None
        if self.monitor == "loss":
            loss = cb_params.get("net_outputs")
            if isinstance(loss, (tuple, list)):
                if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                    monitor_value = loss[0]
            if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
                monitor_value = float(np.mean(loss.asnumpy()))
            if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
                logger.warning("Invalid %s.", self.monitor)
                monitor_value = loss
        else:
            monitor_candidates = cb_params.get("eval_results")
            if monitor_candidates:
                monitor_value = monitor_candidates.get(self.monitor)
            if not monitor_value:
                logger.warning('Learning rate reduction is conditioned on %s '
                               'which is not available. Available choices are: %s',
                               self.monitor, ["loss"] + list(monitor_candidates.keys()))
        return np.array(monitor_value) if monitor_value else None


class ValueReduce(nn.Cell):
    """
    Reduces the tensor data across all devices, all devices will get the same final result.
    For more details, please refer to :class:`mindspore.ops.AllReduce`.
    """
    def __init__(self):
        super(ValueReduce, self).__init__()
        self.allreduce = ops.AllReduce(ReduceOp.SUM)

    def construct(self, x):
        return self.allreduce(x).asnumpy()
@@ -18,6 +18,7 @@ from functools import wraps

import os
import math
+import copy
import numpy as np

from mindspore import log as logger
|
|||
from .callback._checkpoint import ModelCheckpoint
|
||||
from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
|
||||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from ..nn.metrics import get_metrics, get_metric_fn
|
||||
from .._checkparam import check_input_data, check_output_data, Validator
|
||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
|
||||
from .callback import __all__ as internal_cb_names
|
||||
|
@@ -683,8 +684,9 @@ class Model:
            need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
            if need_exec_callback_epoch_end:
                list_callback.on_train_epoch_end(run_context)
-           if "metrics" in cb_params:
+           if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics")
+               cb_params.pop("eval_results")

            should_stop = run_context.get_stop_requested()
            if should_stop:
@@ -880,8 +882,9 @@ class Model:
                self._flush_from_cache(cb_params)

            list_callback.on_train_epoch_end(run_context)
-           if "metrics" in cb_params:
+           if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics")
+               cb_params.pop("eval_results")
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break
@@ -1114,8 +1117,8 @@ class Model:
        _device_number_check(self._parallel_mode, self._device_number)

        if not isinstance(valid_frequency, (int, list)):
-           raise ValueError(f"For 'Model.fit', the type of 'valid_frequency' must be a list or a integer, but got "
-                            f"type {type(valid_frequency)}.")
+           raise TypeError(f"For 'Model.fit', the type of 'valid_frequency' must be a list or a integer, but got "
+                           f"type {type(valid_frequency)}.")

        if valid_dataset and not self._metric_fns:
            raise ValueError("For 'Model.fit', if valid_dataset is not None, the model argument 'metrics' can not be"
@@ -1212,10 +1215,10 @@ class Model:

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
-               return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
-           return self._eval_process(valid_dataset, list_callback, cb_params)
+               return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params, add_eval_loss=True)
+           return self._eval_process(valid_dataset, list_callback, cb_params, add_eval_loss=True)

-   def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
+   def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
        """
        Evaluation. The data would be passed to network through dataset channel.
@@ -1243,15 +1246,22 @@ class Model:
                cb_params.net_outputs = outputs
                list_callback.on_eval_step_end(run_context)
                self._update_metrics(outputs)
+               if add_eval_loss:
+                   eval_loss_fn = get_metric_fn("loss")
+                   eval_loss_fn.update(outputs[self._eval_indexes[0]])

            list_callback.on_eval_epoch_end(run_context)
            metrics = self._get_metrics()
            cb_params.metrics = metrics
+           if add_eval_loss:
+               eval_loss = eval_loss_fn.eval()
+               cb_params.eval_results = copy.deepcopy(metrics)
+               cb_params.eval_results.update({"eval_loss": eval_loss})
            list_callback.on_eval_end(run_context)

            return metrics

-   def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
+   def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
        """
        Evaluation. The data would be passed to network directly.
@@ -1278,11 +1288,18 @@ class Model:
                cb_params.net_outputs = outputs
                list_callback.on_eval_step_end(run_context)
                self._update_metrics(outputs)
+               if add_eval_loss:
+                   eval_loss_fn = get_metric_fn("loss")
+                   eval_loss_fn.update(outputs[self._eval_indexes[0]])

            list_callback.on_eval_epoch_end(run_context)
            valid_dataset.reset()
            metrics = self._get_metrics()
            cb_params.metrics = metrics
+           if add_eval_loss:
+               eval_loss = eval_loss_fn.eval()
+               cb_params.eval_results = copy.deepcopy(metrics)
+               cb_params.eval_results.update({"eval_loss": eval_loss})
            list_callback.on_eval_end(run_context)
            return metrics
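The net effect of the `add_eval_loss` plumbing above is that, during `model.fit`, callbacks see the evaluation metrics plus an extra "eval_loss" entry under `cb_params.eval_results`. A minimal sketch of a custom callback consuming it (editor's illustration, not part of the patch):

```python
from mindspore.train.callback import Callback

class PrintEvalResults(Callback):
    """Prints the merged evaluation results produced by Model.fit."""

    def on_train_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        results = cb_params.get("eval_results")  # e.g. {"acc": 0.91, "eval_loss": 0.32}
        if results:
            print("epoch %d: %s" % (cb_params.cur_epoch_num, results))
```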
@@ -0,0 +1,217 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Test EarlyStopping and ReduceLROnPlateau."""

import copy
import pytest
import numpy as np

from mindspore import nn, Model
from mindspore import dataset as ds
from mindspore.nn.optim import Momentum
from mindspore.common.initializer import Normal
from mindspore.train.callback import RunContext, _InternalCallbackParam, \
    _CallbackManager, ReduceLROnPlateau, EarlyStopping


def get_data(num, w=4.0, b=5.0):
    for _ in range(num):
        x = np.random.uniform(-5.0, 5.0)
        value = (x * x - x * w + b + np.random.normal(0, 1)) // 12
        target_onehot = np.zeros(shape=(5,))
        target_onehot[int(value)] = 1
        yield np.array([x]).astype(np.float32), target_onehot.astype(np.float32)


def create_dataset(num_data, batch_size=512, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data


def define_model(metrics):
    net = nn.Dense(1, 5, Normal(0.02))
    net_loss = nn.SoftmaxCrossEntropyWithLogits()
    net_opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
    model = Model(net, loss_fn=net_loss, optimizer=net_opt, metrics=metrics)
    return model


def test_reduce_lr_on_plateau_monitor_and_factor():
    """
    Feature: `monitor` and `factor`.
    Description: check invalid params.
    Expectation: an invalid `factor` raises ValueError; an unknown `monitor` only logs a warning.
    """
    ReduceLROnPlateau(monitor="unknown_str", patience=0, verbose=True)
    with pytest.raises(ValueError):
        ReduceLROnPlateau(factor=1.2, patience=0, verbose=True)


def test_reduce_lr_on_plateau_min_delta():
    """
    Feature: `min_delta`.
    Description: test whether the learning rate reduces correctly.
    Expectation: the second model reduces the LR after the first epoch because of the large `min_delta`.
    """
    ds_train = create_dataset(1024, 512)
    ds_eval = create_dataset(512, 256)
    model = define_model({"acc", "mae"})
    callbacks = [ReduceLROnPlateau(monitor='eval_loss', factor=0.1, min_delta=0, patience=1, cooldown=5)]
    model.fit(2, ds_train, ds_eval, callbacks=callbacks)

    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    model = define_model({"acc", "mae"})
    callbacks = [ReduceLROnPlateau(monitor='eval_loss', factor=0.1, min_delta=10, patience=1, cooldown=5)]
    model.fit(2, ds_train, ds_eval, callbacks=callbacks)


def test_reduce_lr_on_plateau_patience_and_cooldown():
    """
    Feature: `patience` and `cooldown`.
    Description: test whether the learning rate reduces correctly.
    Expectation: output learning rates match the expected lrs.
    """
    net = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
    cb_params = _InternalCallbackParam()
    run_context = RunContext(cb_params)

    cases = [
        {"losses": [1.0, 1.1, 1.2], "patience": 2, "cooldown": 0, "lrs": [1.0, 1.0, 0.1]},
        {"losses": [1.0, 1.1, 0.9, 1.0, 1.1], "patience": 2, "cooldown": 0, "lrs": [1.0, 1.0, 1.0, 1.0, 0.1]},
        {"losses": [1.0, 1.1, 1.0, 1.0, 1.1], "patience": 2, "cooldown": 0, "lrs": [1.0, 1.0, 0.1, 0.1, 0.01]},
        {"losses": [1.0, 1.1, 1.0, 1.0, 1.1, 1.2], "patience": 2, "cooldown": 1,
         "lrs": [1.0, 1.0, 0.1, 0.1, 0.01, 0.01]},
        {"losses": [1.0, 1.1, 1.0, 1.0, 1.1, 1.2], "patience": 2, "cooldown": 2,
         "lrs": [1.0, 1.0, 0.1, 0.1, 0.1, 0.01]}
    ]

    for current_case in cases:
        cb_params.optimizer = Momentum(net.trainable_params(), learning_rate=1.0, momentum=0.9)

        losses, patience, cooldown, lrs_results = current_case["losses"], current_case["patience"], \
                                                  current_case["cooldown"], current_case["lrs"]

        eval_results = [{'eval_loss': losses[i]} for i in range(len(losses))]
        callbacks = [ReduceLROnPlateau(monitor='eval_loss', patience=patience, cooldown=cooldown)]
        lrs = []
        with _CallbackManager(callbacks) as callbacklist:
            for i, result in enumerate(eval_results):
                callbacklist.on_train_epoch_begin(run_context)
                cb_params.eval_results = result
                cb_params.cur_epoch_num = i + 1
                callbacklist.on_train_epoch_end(run_context)
                cur_lr = cb_params.optimizer.learning_rate.asnumpy()
                lrs.append(copy.deepcopy(cur_lr))
        assert np.allclose(lrs, lrs_results, atol=1e-7)


def test_earlystopping_monitor_set():
    """
    Feature: `monitor` and `mode` in EarlyStopping.
    Description: test EarlyStopping with different monitor and mode settings.
    Expectation: training runs normally; an invalid `mode` raises ValueError.
    """
    cases = [
        ('max', 'accuracy'),
        ('min', 'eval_loss'),
        ('auto', 'accuracy'),
        ('auto', 'loss'),
    ]
    for mode, monitor in cases:
        ds_train = create_dataset(1024, 512)
        ds_eval = create_dataset(512, 256)
        model = define_model({"acc", "mae"})
        callbacks = [EarlyStopping(patience=0, monitor=monitor, mode=mode, verbose=True)]
        model.fit(2, ds_train, ds_eval, callbacks=callbacks)

    with pytest.raises(ValueError):
        EarlyStopping(patience=0, monitor="Unknown", mode="Unknown", verbose=True)


def test_earlystopping_with_baseline():
    """
    Feature: `baseline` in EarlyStopping.
    Description: test whether the stopped epoch is correct.
    Expectation: the stopped epoch matches the expected stop_epoch.
    """
    cases = [
        {"baseline": 0.3, "accuracy": [0.6, 0.5, 0.7, 0.5, 0.6], "patience": 2, "stop_epoch": 5},
        {"baseline": 0.55, "accuracy": [0.6, 0.3, 0.5, 0.5], "patience": 2, "stop_epoch": 3},
        {"baseline": 0.6, "accuracy": [0.5, 0.4, 0.7, 0.6, 0.5, 0.6], "patience": 3, "stop_epoch": 6},
    ]
    for _, current_case in enumerate(cases):
        baseline, acc, patience, stop_epoch = current_case["baseline"], current_case["accuracy"], \
                                              current_case["patience"], current_case["stop_epoch"]

        eval_results = [{'accuracy': acc[i]} for i in range(len(acc))]
        callbacks = [EarlyStopping(monitor='accuracy', patience=patience, baseline=baseline, verbose=True)]

        cb_params = _InternalCallbackParam()
        run_context = RunContext(cb_params)

        with _CallbackManager(callbacks) as callbacklist:
            for i, result in enumerate(eval_results):
                callbacklist.on_train_epoch_begin(run_context)
                cb_params.eval_results = result
                cb_params.cur_epoch_num = i + 1
                callbacklist.on_train_epoch_end(run_context)
                if run_context.get_stop_requested():
                    break
            callbacklist.on_train_end(run_context)
        cur_epoch = cb_params.cur_epoch_num
        assert cur_epoch == stop_epoch


def test_earlystopping_final_weights_when_restoring_model_weights():
    """
    Feature: `restore_best_weights` in EarlyStopping.
    Description: test whether the saved model weights are correct.
    Expectation: Given the monitor varies as `losses`, the training process is
        expected to stop at the 3rd epoch and restore the weights of the 2nd epoch.
    """
    callbacks = EarlyStopping(patience=1, monitor="eval_loss", verbose=True, restore_best_weights=True)
    ds_train = create_dataset(1024, 512)
    model_train = define_model(metrics={"acc"})

    losses = [1.0, 0.8, 1.2, 1.3, 1.4]
    eval_results = [{'eval_loss': losses[i]} for i in range(len(losses))]

    cb_params = _InternalCallbackParam()
    cb_params.train_network = model_train.train_network
    with _CallbackManager(callbacks) as list_callback:
        run_context = RunContext(cb_params)
        list_callback.on_train_begin(run_context)
        for i in range(5):
            list_callback.on_train_epoch_begin(run_context)
            cb_params.cur_epoch_num = i + 1
            for d in ds_train.create_dict_iterator():
                cb_params.train_network(d["data"], d["label"])
            if cb_params.cur_epoch_num == 2:
                best_net_param_dict = copy.deepcopy(cb_params.train_network.parameters_dict())
            cb_params.eval_results = eval_results[i]
            list_callback.on_train_epoch_end(run_context)
            end_net_param_dict = copy.deepcopy(cb_params.train_network.parameters_dict())
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break
        list_callback.on_train_end(run_context)

    for key in ["weight", "bias"]:
        assert (best_net_param_dict[key].asnumpy() == end_net_param_dict[key].asnumpy()).all()