diff --git a/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst b/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst index 6c8f0c975a3..30891ffba0c 100644 --- a/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst +++ b/docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst @@ -1,4 +1,4 @@ -.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1, has_trained_epoch=0) +.. py:class:: mindspore.train.callback.LossMonitor(per_print_times=1) 监控训练的loss。 @@ -10,12 +10,10 @@ **参数:** - **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值:1。 - - **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,LossMonitor将监控该数值之后epoch的loss值。默认值:0。 **异常:** - **ValueError** - 当 `per_print_times` 不是整数或小于零。 - - **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。 .. py:method:: step_end(run_context) diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index c102bb75f92..edfebcc83e3 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -909,12 +909,10 @@ class Cell(Cell_): NOTE: This is an experimental interface that is subject to change or deletion. """ - + self._dynamic_shape_inputs = inputs for ele in self._dynamic_shape_inputs: if isinstance(ele, (str, int, dict)): - raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\ - but got {type(ele)}.") - self._dynamic_shape_inputs = inputs + raise TypeError(f"For element in 'set_inputs', the type must be Tensor, but got {type(ele)}.") def get_inputs(self): """ diff --git a/mindspore/python/mindspore/train/callback/_loss_monitor.py b/mindspore/python/mindspore/train/callback/_loss_monitor.py index 4c769cbeb2d..abc74e242e4 100644 --- a/mindspore/python/mindspore/train/callback/_loss_monitor.py +++ b/mindspore/python/mindspore/train/callback/_loss_monitor.py @@ -29,17 +29,13 @@ class LossMonitor(Callback): Note: If per_print_times is 0, do not print loss. - Parameter `has_trained_epoch` use for failure recovery scenarios. Args: per_print_times (int): How many steps to print once loss. During sink mode, it will print loss in the nearest step. Default: 1. - has_trained_epoch (int): How many epochs has trained. If this parameter is set, LossMonitor will monitor the - loss after has_trained_epoch's epoch. Default: 0. Raises: ValueError: If per_print_times is not an integer or less than zero. - ValueError: If has_trained_epoch is not an integer or less than zero. Examples: >>> from mindspore import Model, nn @@ -54,13 +50,11 @@ class LossMonitor(Callback): >>> model.train(10, dataset, callbacks=loss_monitor) """ - def __init__(self, per_print_times=1, has_trained_epoch=0): + def __init__(self, per_print_times=1): super(LossMonitor, self).__init__() Validator.check_non_negative_int(per_print_times) - Validator.check_non_negative_int(has_trained_epoch) self._per_print_times = per_print_times self._last_print_time = 0 - self._has_trained_epoch = has_trained_epoch def step_end(self, run_context): """ @@ -83,7 +77,7 @@ class LossMonitor(Callback): if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( - cb_params.cur_epoch_num, cur_step_in_epoch + self._has_trained_epoch)) + cb_params.cur_epoch_num, cur_step_in_epoch)) #In disaster recovery scenario, the cb_params.cur_step_num may be rollback to previous step # and be less than self._last_print_time, so self._last_print_time need to be updated. @@ -94,5 +88,4 @@ class LossMonitor(Callback): if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times: self._last_print_time = cb_params.cur_step_num - print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num + self._has_trained_epoch, - cur_step_in_epoch, loss), flush=True) + print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index 1d9d96c8f6e..9a06cba5d71 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -146,30 +146,6 @@ def test_loss_monitor_args(): """ with pytest.raises(ValueError): LossMonitor(per_print_times=-1) - with pytest.raises(ValueError): - LossMonitor(has_trained_epoch=-100) - - -def test_loss_monitor_has_trained_epoch(): - """ - Feature: callback - Description: Test loss monitor has_trained_epoch args - Expectation: run success - """ - cb_params = _InternalCallbackParam() - run_context = RunContext(cb_params) - loss_cb = LossMonitor(has_trained_epoch=10) - cb_params.cur_epoch_num = 4 - cb_params.cur_step_num = 1 - cb_params.batch_num = 1 - cb_params.net_outputs = Tensor(2.0) - cb_params.epoch_num = 4 - loss_cb.begin(run_context) - loss_cb.epoch_begin(run_context) - loss_cb.step_begin(run_context) - loss_cb.step_end(run_context) - loss_cb.epoch_end(run_context) - loss_cb.end(run_context) def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): @@ -355,7 +331,7 @@ def test_checkpoint_save_ckpt_with_encryption(): ckpt_cb2.step_end(run_context) -def test_CallbackManager(): +def test_callbackmanager(): """ Feature: callback Description: Test CallbackManager @@ -377,7 +353,7 @@ def test_CallbackManager(): _CallbackManager(callbacks) -def test_CallbackManager_exit_called(): +def test_callbackmanager_exit_called(): """ Feature: callback Description: Test CallbackManager exit called @@ -392,7 +368,7 @@ def test_CallbackManager_exit_called(): assert mock_exit.call_count == 2 -def test_CallbackManager_exit_called_when_raises(): +def test_callbackmanager_exit_called_when_raises(): """ Feature: callback Description: Test when CallbackManager exit called @@ -408,7 +384,7 @@ def test_CallbackManager_exit_called_when_raises(): assert mock_exit.call_count == 2 -def test_CallbackManager_begin_called(): +def test_callbackmanager_begin_called(): """ Feature: callback Description: Test CallbackManager called begin @@ -424,7 +400,7 @@ def test_CallbackManager_begin_called(): assert mock_begin.call_count == 2 -def test_RunContext(): +def test_runcontext(): """ Feature: callback Description: Test RunContext init @@ -448,7 +424,7 @@ def test_RunContext(): assert should_stop -def test_Checkpoint_Config(): +def test_checkpoint_config(): """ Feature: callback Description: Test checkpoint config error args