forked from mindspore-Ecosystem/mindspore
callback as context manager
This commit is contained in:
parent
ea37dc76f0
commit
ee438aaf4a
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
from model.dataset_helper import DatasetHelper
|
||||
|
@ -374,7 +374,6 @@ class Model:
|
|||
self._train_network.set_broadcast_flag()
|
||||
|
||||
# build callback list
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
|
@ -385,17 +384,17 @@ class Model:
|
|||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = list_callback
|
||||
cb_params.list_callback = callbacks
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
|
@ -408,7 +407,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
iter_first_order = self._frequency - 1
|
||||
|
@ -473,7 +472,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
|
@ -580,7 +579,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -619,7 +618,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -678,7 +677,6 @@ class Model:
|
|||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
|
@ -691,9 +689,10 @@ class Model:
|
|||
|
||||
self._clear_metrics()
|
||||
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
|
|
|
@ -18,6 +18,7 @@ import os
|
|||
import stat
|
||||
import shutil
|
||||
import time
|
||||
from contextlib import ExitStack
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
|
@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list):
|
|||
return ret
|
||||
|
||||
|
||||
def _build_callbacks(callbacks):
|
||||
"""
|
||||
Contain a list of callback.
|
||||
|
||||
Args:
|
||||
callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
|
||||
|
||||
Returns:
|
||||
List, a list of callback functions.
|
||||
"""
|
||||
if callbacks:
|
||||
if isinstance(callbacks, tuple):
|
||||
raise TypeError("Callbacks cannot be a tuple. Please check it.")
|
||||
if not isinstance(callbacks, list):
|
||||
callbacks = [callbacks]
|
||||
else:
|
||||
callbacks = []
|
||||
|
||||
excute_callbacks = []
|
||||
for cb in callbacks:
|
||||
if cb is None or not isinstance(cb, Callback):
|
||||
raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
|
||||
excute_callbacks.append(cb)
|
||||
|
||||
return _ListCallback(excute_callbacks)
|
||||
|
||||
|
||||
class _ListCallback:
|
||||
"""
|
||||
Sequential execution of callback functions.
|
||||
|
||||
Execute Callback functions at certain points.
|
||||
|
||||
Args:
|
||||
callbacks (list): Callback functions list.
|
||||
"""
|
||||
def __init__(self, callbacks):
|
||||
super(_ListCallback, self).__init__()
|
||||
self._callbacks = callbacks
|
||||
|
||||
def begin(self, run_context):
|
||||
"""Called once before network training."""
|
||||
for cb in self._callbacks:
|
||||
cb.begin(run_context)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
"""Called before each epoch begin."""
|
||||
for cb in self._callbacks:
|
||||
cb.epoch_begin(run_context)
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Called after each epoch finished."""
|
||||
for cb in self._callbacks:
|
||||
cb.epoch_end(run_context)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
"""Called before each epoch begin."""
|
||||
for cb in self._callbacks:
|
||||
cb.step_begin(run_context)
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Called after each step finished."""
|
||||
for cb in self._callbacks:
|
||||
cb.step_end(run_context)
|
||||
|
||||
def end(self, run_context):
|
||||
"""Called once after network training."""
|
||||
for cb in self._callbacks:
|
||||
cb.end(run_context)
|
||||
|
||||
|
||||
class Callback:
|
||||
"""
|
||||
Abstract base class used to build a callback function.
|
||||
Abstract base class used to build a callback class. Callbacks are context managers
|
||||
which will be entered and exited when passing into the Model.
|
||||
You can leverage this mechanism to init and release resources automatically.
|
||||
|
||||
Callback function will execution some operating to the current step or epoch.
|
||||
|
||||
|
@ -369,8 +301,13 @@ class Callback:
|
|||
>>> print_cb = Print_info()
|
||||
>>> model.train(epoch, dataset, callbacks=print_cb)
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Return the enter target."""
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
"""Release resources here if have any."""
|
||||
|
||||
def begin(self, run_context):
|
||||
"""
|
||||
|
@ -421,6 +358,67 @@ class Callback:
|
|||
"""
|
||||
|
||||
|
||||
class _CallbackManager(Callback):
|
||||
"""
|
||||
Sequential execution of callback functions.
|
||||
|
||||
Execute Callback functions at certain points.
|
||||
|
||||
Args:
|
||||
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
|
||||
"""
|
||||
|
||||
def __init__(self, callbacks):
|
||||
self._callbacks, self._stack = [], None
|
||||
if isinstance(callbacks, Callback):
|
||||
self._callbacks.append(callbacks)
|
||||
elif callbacks is not None:
|
||||
for cb in callbacks:
|
||||
if not isinstance(cb, Callback):
|
||||
raise TypeError("%r is not an instance of %r" % (cb, Callback))
|
||||
self._callbacks.append(cb)
|
||||
|
||||
def __enter__(self):
|
||||
if self._stack is None:
|
||||
self._stack = ExitStack().__enter__()
|
||||
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
return self._stack.__exit__(*err)
|
||||
|
||||
def begin(self, run_context):
|
||||
"""Called once before network training."""
|
||||
for cb in self._callbacks:
|
||||
cb.begin(run_context)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
"""Called before each epoch begin."""
|
||||
for cb in self._callbacks:
|
||||
cb.epoch_begin(run_context)
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Called after each epoch finished."""
|
||||
for cb in self._callbacks:
|
||||
cb.epoch_end(run_context)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
"""Called before each epoch begin."""
|
||||
for cb in self._callbacks:
|
||||
cb.step_begin(run_context)
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Called after each step finished."""
|
||||
for cb in self._callbacks:
|
||||
cb.step_end(run_context)
|
||||
|
||||
def end(self, run_context):
|
||||
"""Called once after network training."""
|
||||
for cb in self._callbacks:
|
||||
cb.end(run_context)
|
||||
|
||||
|
||||
|
||||
class SummaryStep(Callback):
|
||||
"""
|
||||
The summary callback class.
|
||||
|
@ -435,6 +433,13 @@ class SummaryStep(Callback):
|
|||
raise ValueError("`flush_step` should be int and greater than 0")
|
||||
self._summary = summary
|
||||
self._flush_step = flush_step
|
||||
def __enter__(self):
|
||||
self._summary.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
return self._summary.__exit__(*err)
|
||||
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore import log as logger
|
|||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
||||
from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
|
@ -332,8 +332,6 @@ class Model:
|
|||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
# build callback list
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
|
@ -344,17 +342,18 @@ class Model:
|
|||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = list_callback
|
||||
cb_params.list_callback = callbacks
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
# build callback list
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
|
@ -367,7 +366,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
|
@ -415,7 +414,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
|
@ -522,7 +521,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -561,7 +560,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -620,7 +619,6 @@ class Model:
|
|||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
|
@ -633,9 +631,10 @@ class Model:
|
|||
|
||||
self._clear_metrics()
|
||||
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
from .dataset_helper import DatasetHelper
|
||||
|
@ -392,7 +392,6 @@ class Model:
|
|||
self._train_network.set_broadcast_flag()
|
||||
|
||||
# build callback list
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
|
@ -403,17 +402,17 @@ class Model:
|
|||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = list_callback
|
||||
cb_params.list_callback = callbacks
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
|
@ -426,7 +425,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
iter_first_order = self._frequency - 1
|
||||
|
@ -490,7 +489,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
|
@ -695,7 +694,6 @@ class Model:
|
|||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
|
@ -708,9 +706,10 @@ class Model:
|
|||
|
||||
self._clear_metrics()
|
||||
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
|
|
|
@ -156,12 +156,19 @@ def get_dataset():
|
|||
|
||||
|
||||
class ImageSummaryCallback:
|
||||
def __init__(self, summaryRecord):
|
||||
self._summaryRecord = summaryRecord
|
||||
|
||||
def __init__(self, summary_record):
|
||||
self._summary_record = summary_record
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
pass
|
||||
|
||||
def record(self, step, train_network=None):
|
||||
self._summaryRecord.record(step, train_network)
|
||||
self._summaryRecord.flush()
|
||||
self._summary_record.record(step, train_network)
|
||||
self._summary_record.flush()
|
||||
|
||||
|
||||
def test_image_summary_train():
|
||||
|
|
|
@ -180,6 +180,12 @@ class CallbackTest:
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
pass
|
||||
|
||||
def record(self, step, *args):
|
||||
print(step, args)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""test callback function."""
|
||||
import os
|
||||
import stat
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
|
|||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
|
||||
_checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
|
||||
_build_callbacks, CheckpointConfig, _set_cur_net
|
||||
_CallbackManager, Callback, CheckpointConfig, _set_cur_net
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode():
|
|||
run_context = RunContext(cb_params)
|
||||
loss_cb = LossMonitor(1)
|
||||
callbacks = [loss_cb]
|
||||
callbacklist = _build_callbacks(callbacks)
|
||||
callbacklist.begin(run_context)
|
||||
callbacklist.epoch_begin(run_context)
|
||||
callbacklist.step_begin(run_context)
|
||||
callbacklist.step_end(run_context)
|
||||
callbacklist.epoch_end(run_context)
|
||||
callbacklist.end(run_context)
|
||||
with _CallbackManager(callbacks) as callbacklist:
|
||||
callbacklist.begin(run_context)
|
||||
callbacklist.epoch_begin(run_context)
|
||||
callbacklist.step_begin(run_context)
|
||||
callbacklist.step_end(run_context)
|
||||
callbacklist.epoch_end(run_context)
|
||||
callbacklist.end(run_context)
|
||||
|
||||
|
||||
def test_loss_monitor_normal_mode():
|
||||
|
@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds():
|
|||
ckpt_cb2.step_end(run_context)
|
||||
|
||||
|
||||
def test_build_callbacks():
|
||||
"""Test_build_callbacks."""
|
||||
def test_CallbackManager():
|
||||
"""TestCallbackManager."""
|
||||
ck_obj = ModelCheckpoint()
|
||||
loss_cb_1 = LossMonitor(1)
|
||||
|
||||
callbacks = [None]
|
||||
with pytest.raises(TypeError):
|
||||
callbacks = _build_callbacks(callbacks)
|
||||
_CallbackManager(callbacks)
|
||||
|
||||
callbacks = ['Error']
|
||||
with pytest.raises(TypeError):
|
||||
callbacks = _build_callbacks(callbacks)
|
||||
_CallbackManager(callbacks)
|
||||
|
||||
callbacks = [ck_obj, loss_cb_1, 'Error', None]
|
||||
with pytest.raises(TypeError):
|
||||
_ = _build_callbacks(callbacks)
|
||||
_CallbackManager(callbacks)
|
||||
|
||||
|
||||
def test_CallbackManager_exit_called():
|
||||
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
|
||||
cb1, cb2 = Callback(), Callback()
|
||||
with _CallbackManager([cb1, cb2]):
|
||||
pass
|
||||
for call_args in mock_exit.call_args_list:
|
||||
assert call_args == mock.call(mock.ANY, None, None, None)
|
||||
assert mock_exit.call_count == 2
|
||||
|
||||
|
||||
def test_CallbackManager_exit_called_when_raises():
|
||||
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
|
||||
cb1, cb2 = Callback(), Callback()
|
||||
with pytest.raises(ValueError):
|
||||
with _CallbackManager([cb1, cb2]):
|
||||
raise ValueError()
|
||||
for call_args in mock_exit.call_args_list:
|
||||
assert call_args == mock.call(*[mock.ANY] * 4)
|
||||
assert mock_exit.call_count == 2
|
||||
|
||||
|
||||
def test_CallbackManager_begin_called():
|
||||
context = dict()
|
||||
with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
|
||||
cb1, cb2 = Callback(), Callback()
|
||||
with _CallbackManager([cb1, cb2]) as cm:
|
||||
cm.begin(context)
|
||||
for call_args in mock_begin.call_args_list:
|
||||
assert call_args == mock.call(context)
|
||||
assert mock_begin.call_count == 2
|
||||
|
||||
|
||||
def test_RunContext():
|
||||
"""Test RunContext."""
|
||||
context_err = 666
|
||||
with pytest.raises(TypeError):
|
||||
_ = RunContext(context_err)
|
||||
RunContext(context_err)
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.member1 = 1
|
||||
|
|
Loading…
Reference in New Issue