forked from mindspore-Ecosystem/mindspore
enhance callback module and strongly check callbacks is list or not
This commit is contained in:
parent
932b7649e7
commit
4c0d12fd63
|
@ -18,6 +18,8 @@ from ._callback import Callback
|
|||
from ._callback import CallbackManager as _CallbackManager
|
||||
from ._callback import InternalCallbackParam as _InternalCallbackParam
|
||||
from ._callback import RunContext
|
||||
from ._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op
|
||||
from ._callback import set_cur_net as _set_cur_net
|
||||
from ._checkpoint import CheckpointConfig
|
||||
from ._checkpoint import CheckpointManager as _CheckpointManager
|
||||
from ._checkpoint import ModelCheckpoint
|
||||
|
|
|
@ -160,16 +160,25 @@ class CallbackManager(Callback):
|
|||
self._callbacks, self._stack = [], None
|
||||
if isinstance(callbacks, Callback):
|
||||
self._callbacks.append(callbacks)
|
||||
elif callbacks is not None:
|
||||
elif isinstance(callbacks, list):
|
||||
for cb in callbacks:
|
||||
if not isinstance(cb, Callback):
|
||||
raise TypeError("%r is not an instance of %r" % (cb, Callback))
|
||||
raise TypeError("The 'callbacks' contains not-a-Callback item.")
|
||||
self._callbacks.append(cb)
|
||||
elif callbacks is not None:
|
||||
raise TypeError("The 'callbacks' is not a Callback or a list of Callback.")
|
||||
|
||||
def __enter__(self):
|
||||
if self._stack is None:
|
||||
self._stack = ExitStack().__enter__()
|
||||
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
|
||||
callbacks, self._stack = [], ExitStack().__enter__()
|
||||
for callback in self._callbacks:
|
||||
target = self._stack.enter_context(callback)
|
||||
if not isinstance(target, Callback):
|
||||
logger.warning("Please return 'self' or a Callback as the enter target.")
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks.append(target)
|
||||
self._callbacks = callbacks
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
|
|
|
@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
|
||||
_CallbackManager, Callback, CheckpointConfig
|
||||
from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
|
||||
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op():
|
|||
one_param['name'] = "conv1.weight"
|
||||
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
|
||||
parameter_list.append(one_param)
|
||||
checkpoint_cb_for_save_op(parameter_list)
|
||||
_checkpoint_cb_for_save_op(parameter_list)
|
||||
|
||||
|
||||
def test_checkpoint_cb_for_save_op_update_net():
|
||||
|
@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net():
|
|||
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
|
||||
parameter_list.append(one_param)
|
||||
net = Net()
|
||||
set_cur_net(net)
|
||||
checkpoint_cb_for_save_op(parameter_list)
|
||||
_set_cur_net(net)
|
||||
_checkpoint_cb_for_save_op(parameter_list)
|
||||
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue