forked from mindspore-Ecosystem/mindspore
fallback lossmonitor has_trained_epoch
This commit is contained in:
parent
602c0d2923
commit
865a5de03d
|
@ -1,4 +1,4 @@
|
|||
.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1, has_trained_epoch=0)
|
||||
.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1)
|
||||
|
||||
监控训练的loss。
|
||||
|
||||
|
@ -10,12 +10,10 @@
|
|||
**参数:**
|
||||
|
||||
- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值:1。
|
||||
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,LossMonitor将监控该数值之后epoch的loss值。默认值:0。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** - 当 `per_print_times` 不是整数或小于零。
|
||||
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。
|
||||
|
||||
.. py:method:: step_end(run_context)
|
||||
|
||||
|
|
|
@ -909,12 +909,10 @@ class Cell(Cell_):
|
|||
NOTE:
|
||||
This is an experimental interface that is subject to change or deletion.
|
||||
"""
|
||||
|
||||
self._dynamic_shape_inputs = inputs
|
||||
for ele in self._dynamic_shape_inputs:
|
||||
if isinstance(ele, (str, int, dict)):
|
||||
raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\
|
||||
but got {type(ele)}.")
|
||||
self._dynamic_shape_inputs = inputs
|
||||
raise TypeError(f"For element in 'set_inputs', the type must be Tensor, but got {type(ele)}.")
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
|
|
|
@ -29,17 +29,13 @@ class LossMonitor(Callback):
|
|||
|
||||
Note:
|
||||
If per_print_times is 0, do not print loss.
|
||||
Parameter `has_trained_epoch` is used for failure recovery scenarios.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print the loss once every this many steps. During sink mode, it will print loss in the
|
||||
nearest step. Default: 1.
|
||||
has_trained_epoch (int): How many epochs have already been trained. If this parameter is set, LossMonitor will monitor the
|
||||
loss of the epochs after `has_trained_epoch`. Default: 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If per_print_times is not an integer or less than zero.
|
||||
ValueError: If has_trained_epoch is not an integer or less than zero.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Model, nn
|
||||
|
@ -54,13 +50,11 @@ class LossMonitor(Callback):
|
|||
>>> model.train(10, dataset, callbacks=loss_monitor)
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1, has_trained_epoch=0):
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossMonitor, self).__init__()
|
||||
Validator.check_non_negative_int(per_print_times)
|
||||
Validator.check_non_negative_int(has_trained_epoch)
|
||||
self._per_print_times = per_print_times
|
||||
self._last_print_time = 0
|
||||
self._has_trained_epoch = has_trained_epoch
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
|
@ -83,7 +77,7 @@ class LossMonitor(Callback):
|
|||
|
||||
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
|
||||
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch + self._has_trained_epoch))
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch))
|
||||
|
||||
# In a disaster recovery scenario, cb_params.cur_step_num may be rolled back to a previous step
|
||||
# and be less than self._last_print_time, so self._last_print_time needs to be updated.
|
||||
|
@ -94,5 +88,4 @@ class LossMonitor(Callback):
|
|||
|
||||
if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
|
||||
self._last_print_time = cb_params.cur_step_num
|
||||
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + self._has_trained_epoch,
|
||||
cur_step_in_epoch, loss), flush=True)
|
||||
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
|
||||
|
|
|
@ -146,30 +146,6 @@ def test_loss_monitor_args():
|
|||
"""
|
||||
with pytest.raises(ValueError):
|
||||
LossMonitor(per_print_times=-1)
|
||||
with pytest.raises(ValueError):
|
||||
LossMonitor(has_trained_epoch=-100)
|
||||
|
||||
|
||||
def test_loss_monitor_has_trained_epoch():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test loss monitor has_trained_epoch args
|
||||
Expectation: run success
|
||||
"""
|
||||
cb_params = _InternalCallbackParam()
|
||||
run_context = RunContext(cb_params)
|
||||
loss_cb = LossMonitor(has_trained_epoch=10)
|
||||
cb_params.cur_epoch_num = 4
|
||||
cb_params.cur_step_num = 1
|
||||
cb_params.batch_num = 1
|
||||
cb_params.net_outputs = Tensor(2.0)
|
||||
cb_params.epoch_num = 4
|
||||
loss_cb.begin(run_context)
|
||||
loss_cb.epoch_begin(run_context)
|
||||
loss_cb.step_begin(run_context)
|
||||
loss_cb.step_end(run_context)
|
||||
loss_cb.epoch_end(run_context)
|
||||
loss_cb.end(run_context)
|
||||
|
||||
|
||||
def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():
|
||||
|
@ -355,7 +331,7 @@ def test_checkpoint_save_ckpt_with_encryption():
|
|||
ckpt_cb2.step_end(run_context)
|
||||
|
||||
|
||||
def test_CallbackManager():
|
||||
def test_callbackmanager():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test CallbackManager
|
||||
|
@ -377,7 +353,7 @@ def test_CallbackManager():
|
|||
_CallbackManager(callbacks)
|
||||
|
||||
|
||||
def test_CallbackManager_exit_called():
|
||||
def test_callbackmanager_exit_called():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test CallbackManager exit called
|
||||
|
@ -392,7 +368,7 @@ def test_CallbackManager_exit_called():
|
|||
assert mock_exit.call_count == 2
|
||||
|
||||
|
||||
def test_CallbackManager_exit_called_when_raises():
|
||||
def test_callbackmanager_exit_called_when_raises():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test when CallbackManager exit called
|
||||
|
@ -408,7 +384,7 @@ def test_CallbackManager_exit_called_when_raises():
|
|||
assert mock_exit.call_count == 2
|
||||
|
||||
|
||||
def test_CallbackManager_begin_called():
|
||||
def test_callbackmanager_begin_called():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test CallbackManager called begin
|
||||
|
@ -424,7 +400,7 @@ def test_CallbackManager_begin_called():
|
|||
assert mock_begin.call_count == 2
|
||||
|
||||
|
||||
def test_RunContext():
|
||||
def test_runcontext():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test RunContext init
|
||||
|
@ -448,7 +424,7 @@ def test_RunContext():
|
|||
assert should_stop
|
||||
|
||||
|
||||
def test_Checkpoint_Config():
|
||||
def test_checkpoint_config():
|
||||
"""
|
||||
Feature: callback
|
||||
Description: Test checkpoint config error args
|
||||
|
|
Loading…
Reference in New Issue