fallback lossmonitor has_trained_epoch

This commit is contained in:
changzherui 2022-04-11 22:27:32 +08:00
parent 602c0d2923
commit 865a5de03d
4 changed files with 12 additions and 47 deletions

View File

@ -1,4 +1,4 @@
.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1, has_trained_epoch=0) .. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1)
监控训练的loss。 监控训练的loss。
@ -10,12 +10,10 @@
**参数:** **参数:**
- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值1。 - **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值1。
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch如果设置了该参数LossMonitor将监控该数值之后epoch的loss值。默认值0。
**异常:** **异常:**
- **ValueError** - 当 `per_print_times` 不是整数或小于零。 - **ValueError** - 当 `per_print_times` 不是整数或小于零。
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
.. py:method:: step_end(run_context) .. py:method:: step_end(run_context)

View File

@ -909,12 +909,10 @@ class Cell(Cell_):
NOTE: NOTE:
This is an experimental interface that is subject to change or deletion. This is an experimental interface that is subject to change or deletion.
""" """
self._dynamic_shape_inputs = inputs
for ele in self._dynamic_shape_inputs: for ele in self._dynamic_shape_inputs:
if isinstance(ele, (str, int, dict)): if isinstance(ele, (str, int, dict)):
raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\ raise TypeError(f"For element in 'set_inputs', the type must be Tensor, but got {type(ele)}.")
but got {type(ele)}.")
self._dynamic_shape_inputs = inputs
def get_inputs(self): def get_inputs(self):
""" """

View File

@ -29,17 +29,13 @@ class LossMonitor(Callback):
Note: Note:
If per_print_times is 0, do not print loss. If per_print_times is 0, do not print loss.
Parameter `has_trained_epoch` use for failure recovery scenarios.
Args: Args:
per_print_times (int): How many steps to print once loss. During sink mode, it will print loss in the per_print_times (int): How many steps to print once loss. During sink mode, it will print loss in the
nearest step. Default: 1. nearest step. Default: 1.
has_trained_epoch (int): How many epochs has trained. If this parameter is set, LossMonitor will monitor the
loss after has_trained_epoch's epoch. Default: 0.
Raises: Raises:
ValueError: If per_print_times is not an integer or less than zero. ValueError: If per_print_times is not an integer or less than zero.
ValueError: If has_trained_epoch is not an integer or less than zero.
Examples: Examples:
>>> from mindspore import Model, nn >>> from mindspore import Model, nn
@ -54,13 +50,11 @@ class LossMonitor(Callback):
>>> model.train(10, dataset, callbacks=loss_monitor) >>> model.train(10, dataset, callbacks=loss_monitor)
""" """
def __init__(self, per_print_times=1, has_trained_epoch=0): def __init__(self, per_print_times=1):
super(LossMonitor, self).__init__() super(LossMonitor, self).__init__()
Validator.check_non_negative_int(per_print_times) Validator.check_non_negative_int(per_print_times)
Validator.check_non_negative_int(has_trained_epoch)
self._per_print_times = per_print_times self._per_print_times = per_print_times
self._last_print_time = 0 self._last_print_time = 0
self._has_trained_epoch = has_trained_epoch
def step_end(self, run_context): def step_end(self, run_context):
""" """
@ -83,7 +77,7 @@ class LossMonitor(Callback):
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch + self._has_trained_epoch)) cb_params.cur_epoch_num, cur_step_in_epoch))
#In disaster recovery scenario, the cb_params.cur_step_num may be rollback to previous step #In disaster recovery scenario, the cb_params.cur_step_num may be rollback to previous step
# and be less than self._last_print_time, so self._last_print_time need to be updated. # and be less than self._last_print_time, so self._last_print_time need to be updated.
@ -94,5 +88,4 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times: if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
self._last_print_time = cb_params.cur_step_num self._last_print_time = cb_params.cur_step_num
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + self._has_trained_epoch, print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
cur_step_in_epoch, loss), flush=True)

View File

@ -146,30 +146,6 @@ def test_loss_monitor_args():
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
LossMonitor(per_print_times=-1) LossMonitor(per_print_times=-1)
with pytest.raises(ValueError):
LossMonitor(has_trained_epoch=-100)
def test_loss_monitor_has_trained_epoch():
"""
Feature: callback
Description: Test loss monitor has_trained_epoch args
Expectation: run success
"""
cb_params = _InternalCallbackParam()
run_context = RunContext(cb_params)
loss_cb = LossMonitor(has_trained_epoch=10)
cb_params.cur_epoch_num = 4
cb_params.cur_step_num = 1
cb_params.batch_num = 1
cb_params.net_outputs = Tensor(2.0)
cb_params.epoch_num = 4
loss_cb.begin(run_context)
loss_cb.epoch_begin(run_context)
loss_cb.step_begin(run_context)
loss_cb.step_end(run_context)
loss_cb.epoch_end(run_context)
loss_cb.end(run_context)
def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():
@ -355,7 +331,7 @@ def test_checkpoint_save_ckpt_with_encryption():
ckpt_cb2.step_end(run_context) ckpt_cb2.step_end(run_context)
def test_CallbackManager(): def test_callbackmanager():
""" """
Feature: callback Feature: callback
Description: Test CallbackManager Description: Test CallbackManager
@ -377,7 +353,7 @@ def test_CallbackManager():
_CallbackManager(callbacks) _CallbackManager(callbacks)
def test_CallbackManager_exit_called(): def test_callbackmanager_exit_called():
""" """
Feature: callback Feature: callback
Description: Test CallbackManager exit called Description: Test CallbackManager exit called
@ -392,7 +368,7 @@ def test_CallbackManager_exit_called():
assert mock_exit.call_count == 2 assert mock_exit.call_count == 2
def test_CallbackManager_exit_called_when_raises(): def test_callbackmanager_exit_called_when_raises():
""" """
Feature: callback Feature: callback
Description: Test when CallbackManager exit called Description: Test when CallbackManager exit called
@ -408,7 +384,7 @@ def test_CallbackManager_exit_called_when_raises():
assert mock_exit.call_count == 2 assert mock_exit.call_count == 2
def test_CallbackManager_begin_called(): def test_callbackmanager_begin_called():
""" """
Feature: callback Feature: callback
Description: Test CallbackManager called begin Description: Test CallbackManager called begin
@ -424,7 +400,7 @@ def test_CallbackManager_begin_called():
assert mock_begin.call_count == 2 assert mock_begin.call_count == 2
def test_RunContext(): def test_runcontext():
""" """
Feature: callback Feature: callback
Description: Test RunContext init Description: Test RunContext init
@ -448,7 +424,7 @@ def test_RunContext():
assert should_stop assert should_stop
def test_Checkpoint_Config(): def test_checkpoint_config():
""" """
Feature: callback Feature: callback
Description: Test checkpoint config error args Description: Test checkpoint config error args