fallback lossmonitor has_trained_epoch

This commit is contained in:
changzherui 2022-04-11 22:27:32 +08:00
parent 602c0d2923
commit 865a5de03d
4 changed files with 12 additions and 47 deletions

View File

@ -1,4 +1,4 @@
.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1, has_trained_epoch=0)
.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1)
监控训练的loss。
@ -10,12 +10,10 @@
**参数:**
- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值：1。
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch。如果设置了该参数，LossMonitor将监控该数值之后epoch的loss值。默认值：0。
**异常:**
- **ValueError** - 当 `per_print_times` 不是整数或小于零。
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
.. py:method:: step_end(run_context)

View File

@ -909,12 +909,10 @@ class Cell(Cell_):
NOTE:
This is an experimental interface that is subject to change or deletion.
"""
self._dynamic_shape_inputs = inputs
for ele in self._dynamic_shape_inputs:
if isinstance(ele, (str, int, dict)):
raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\
but got {type(ele)}.")
self._dynamic_shape_inputs = inputs
raise TypeError(f"For element in 'set_inputs', the type must be Tensor, but got {type(ele)}.")
def get_inputs(self):
"""

View File

@ -29,17 +29,13 @@ class LossMonitor(Callback):
Note:
If per_print_times is 0, do not print loss.
Parameter `has_trained_epoch` use for failure recovery scenarios.
Args:
per_print_times (int): How many steps to print once loss. During sink mode, it will print loss in the
nearest step. Default: 1.
has_trained_epoch (int): How many epochs has trained. If this parameter is set, LossMonitor will monitor the
loss after has_trained_epoch's epoch. Default: 0.
Raises:
ValueError: If per_print_times is not an integer or less than zero.
ValueError: If has_trained_epoch is not an integer or less than zero.
Examples:
>>> from mindspore import Model, nn
@ -54,13 +50,11 @@ class LossMonitor(Callback):
>>> model.train(10, dataset, callbacks=loss_monitor)
"""
def __init__(self, per_print_times=1, has_trained_epoch=0):
def __init__(self, per_print_times=1):
super(LossMonitor, self).__init__()
Validator.check_non_negative_int(per_print_times)
Validator.check_non_negative_int(has_trained_epoch)
self._per_print_times = per_print_times
self._last_print_time = 0
self._has_trained_epoch = has_trained_epoch
def step_end(self, run_context):
"""
@ -83,7 +77,7 @@ class LossMonitor(Callback):
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch + self._has_trained_epoch))
cb_params.cur_epoch_num, cur_step_in_epoch))
#In disaster recovery scenario, the cb_params.cur_step_num may be rollback to previous step
# and be less than self._last_print_time, so self._last_print_time need to be updated.
@ -94,5 +88,4 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
self._last_print_time = cb_params.cur_step_num
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + self._has_trained_epoch,
cur_step_in_epoch, loss), flush=True)
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)

View File

@ -146,30 +146,6 @@ def test_loss_monitor_args():
"""
with pytest.raises(ValueError):
LossMonitor(per_print_times=-1)
with pytest.raises(ValueError):
LossMonitor(has_trained_epoch=-100)
def test_loss_monitor_has_trained_epoch():
"""
Feature: callback
Description: Test loss monitor has_trained_epoch args
Expectation: run success
"""
cb_params = _InternalCallbackParam()
run_context = RunContext(cb_params)
loss_cb = LossMonitor(has_trained_epoch=10)
cb_params.cur_epoch_num = 4
cb_params.cur_step_num = 1
cb_params.batch_num = 1
cb_params.net_outputs = Tensor(2.0)
cb_params.epoch_num = 4
loss_cb.begin(run_context)
loss_cb.epoch_begin(run_context)
loss_cb.step_begin(run_context)
loss_cb.step_end(run_context)
loss_cb.epoch_end(run_context)
loss_cb.end(run_context)
def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():
@ -355,7 +331,7 @@ def test_checkpoint_save_ckpt_with_encryption():
ckpt_cb2.step_end(run_context)
def test_CallbackManager():
def test_callbackmanager():
"""
Feature: callback
Description: Test CallbackManager
@ -377,7 +353,7 @@ def test_CallbackManager():
_CallbackManager(callbacks)
def test_CallbackManager_exit_called():
def test_callbackmanager_exit_called():
"""
Feature: callback
Description: Test CallbackManager exit called
@ -392,7 +368,7 @@ def test_CallbackManager_exit_called():
assert mock_exit.call_count == 2
def test_CallbackManager_exit_called_when_raises():
def test_callbackmanager_exit_called_when_raises():
"""
Feature: callback
Description: Test when CallbackManager exit called
@ -408,7 +384,7 @@ def test_CallbackManager_exit_called_when_raises():
assert mock_exit.call_count == 2
def test_CallbackManager_begin_called():
def test_callbackmanager_begin_called():
"""
Feature: callback
Description: Test CallbackManager called begin
@ -424,7 +400,7 @@ def test_CallbackManager_begin_called():
assert mock_begin.call_count == 2
def test_RunContext():
def test_runcontext():
"""
Feature: callback
Description: Test RunContext init
@ -448,7 +424,7 @@ def test_RunContext():
assert should_stop
def test_Checkpoint_Config():
def test_checkpoint_config():
"""
Feature: callback
Description: Test checkpoint config error args