From 4c0d12fd635fa37c8c0bca069432ecd6034c5d87 Mon Sep 17 00:00:00 2001 From: Li Hongzhang Date: Thu, 18 Jun 2020 20:22:07 +0800 Subject: [PATCH] enhance callback module and strongly check callbacks is list or not --- mindspore/train/callback/__init__.py | 2 ++ mindspore/train/callback/_callback.py | 17 +++++++++++++---- tests/ut/python/utils/test_callback.py | 9 ++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/mindspore/train/callback/__init__.py b/mindspore/train/callback/__init__.py index fd2e4760b3a..6ef171cc875 100644 --- a/mindspore/train/callback/__init__.py +++ b/mindspore/train/callback/__init__.py @@ -18,6 +18,8 @@ from ._callback import Callback from ._callback import CallbackManager as _CallbackManager from ._callback import InternalCallbackParam as _InternalCallbackParam from ._callback import RunContext +from ._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op +from ._callback import set_cur_net as _set_cur_net from ._checkpoint import CheckpointConfig from ._checkpoint import CheckpointManager as _CheckpointManager from ._checkpoint import ModelCheckpoint diff --git a/mindspore/train/callback/_callback.py b/mindspore/train/callback/_callback.py index 756b9c71183..c75e0996937 100644 --- a/mindspore/train/callback/_callback.py +++ b/mindspore/train/callback/_callback.py @@ -160,16 +160,25 @@ class CallbackManager(Callback): self._callbacks, self._stack = [], None if isinstance(callbacks, Callback): self._callbacks.append(callbacks) - elif callbacks is not None: + elif isinstance(callbacks, list): for cb in callbacks: if not isinstance(cb, Callback): - raise TypeError("%r is not an instance of %r" % (cb, Callback)) + raise TypeError("The 'callbacks' contains not-a-Callback item.") self._callbacks.append(cb) + elif callbacks is not None: + raise TypeError("The 'callbacks' is not a Callback or a list of Callback.") def __enter__(self): if self._stack is None: - self._stack = ExitStack().__enter__() - self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks] + callbacks, self._stack = [], ExitStack().__enter__() + for callback in self._callbacks: + target = self._stack.enter_context(callback) + if not isinstance(target, Callback): + logger.warning("Please return 'self' or a Callback as the enter target.") + callbacks.append(callback) + else: + callbacks.append(target) + self._callbacks = callbacks return self def __exit__(self, *err): diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index e4ecfe696a3..a5f2a3323fd 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Momentum from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ - _CallbackManager, Callback, CheckpointConfig -from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op + _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist class Net(nn.Cell): @@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op(): one_param['name'] = "conv1.weight" one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32) parameter_list.append(one_param) - checkpoint_cb_for_save_op(parameter_list) + _checkpoint_cb_for_save_op(parameter_list) def test_checkpoint_cb_for_save_op_update_net(): @@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net(): one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32) parameter_list.append(one_param) net = Net() - set_cur_net(net) - checkpoint_cb_for_save_op(parameter_list) + _set_cur_net(net) + _checkpoint_cb_for_save_op(parameter_list) assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1