forked from mindspore-Ecosystem/mindspore
!1276 Callbacks as context managers
Merge pull request !1276 from 李鸿章/context_manager
commit 08a496d073
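In short, the change removes the `_build_callbacks` helper and the `_ListCallback` wrapper and treats every `Callback` as a context manager: `Model.train` and `Model.eval` now wrap the user's callbacks in a `_CallbackManager`, whose `__enter__`/`__exit__` delegate to a `contextlib.ExitStack`, so a callback can acquire resources when training starts and release them when it ends, even on error. Below is a minimal, self-contained sketch of the pattern, simplified from the diff that follows; the `SummaryFile` callback and the `train` driver are hypothetical illustrations, not MindSpore API.

```python
from contextlib import ExitStack


class Callback:
    """Base class: every callback is also a context manager."""

    def __enter__(self):
        return self

    def __exit__(self, *err):
        """Release resources here if have any."""

    def step_end(self, run_context):
        pass


class CallbackManager(Callback):
    """Simplified _CallbackManager: enter all callbacks on one ExitStack and fan out hook calls."""

    def __init__(self, callbacks):
        self._callbacks = list(callbacks)
        self._stack = None

    def __enter__(self):
        if self._stack is None:
            self._stack = ExitStack().__enter__()
            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
        return self

    def __exit__(self, *err):
        return self._stack.__exit__(*err)

    def step_end(self, run_context):
        for cb in self._callbacks:
            cb.step_end(run_context)


class SummaryFile(Callback):
    """Hypothetical callback that owns a file handle for its whole lifetime."""

    def __enter__(self):
        self._file = open("summary.log", "w")  # demo file, created in the working directory
        return self

    def __exit__(self, *err):
        self._file.close()

    def step_end(self, run_context):
        self._file.write("step %s\n" % run_context)


def train(steps, callbacks):
    """Hypothetical driver mirroring the with-block Model.train uses in the diff."""
    with CallbackManager(callbacks) as cb_list:
        for step in range(steps):
            cb_list.step_end(step)  # __exit__ still runs if this raises


if __name__ == "__main__":
    train(3, [SummaryFile()])
```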
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
+from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 
 from model.dataset_helper import DatasetHelper
@@ -374,7 +374,6 @@ class Model:
             self._train_network.set_broadcast_flag()
 
         # build callback list
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.train_network = self._train_network
         cb_params.epoch_num = epoch
@@ -385,17 +384,17 @@ class Model:
         cb_params.parallel_mode = self._parallel_mode
         cb_params.device_number = self._device_number
         cb_params.train_dataset = train_dataset
-        cb_params.list_callback = list_callback
+        cb_params.list_callback = callbacks
 
-        if dataset_sink_mode:
-            if context.get_context("mode") == context.PYNATIVE_MODE:
+        with _CallbackManager(callbacks) as list_callback:
+            if not dataset_sink_mode:
+                self._train_process(epoch, train_dataset, list_callback, cb_params)
+            elif context.get_context("mode") == context.PYNATIVE_MODE:
                 logger.warning("The pynative mode cannot support dataset sink mode currently."
                                "So the training process will be performed with dataset not sink.")
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
             else:
                 self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
-        else:
-            self._train_process(epoch, train_dataset, list_callback, cb_params)
 
     def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
         """
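The `with _CallbackManager(callbacks) as list_callback:` block introduced here is what guarantees callback teardown whether training finishes normally or raises; it behaves like the manual try/except/else expansion below. This is a generic, standalone sketch of with-statement semantics, not MindSpore code (`Resource` and `failing_body` are made up for the demo):

```python
import sys


class Resource:
    """Stand-in for _CallbackManager: any context manager behaves this way."""

    def __enter__(self):
        print("enter")
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print("exit (exception type: %s)" % (exc_type.__name__ if exc_type else None))
        return False  # do not swallow exceptions


def run_with(body):
    with Resource():
        body()


def run_desugared(body):
    # Roughly what the with-statement above expands to.
    manager = Resource()
    manager.__enter__()
    try:
        body()
    except BaseException:
        if not manager.__exit__(*sys.exc_info()):
            raise
    else:
        manager.__exit__(None, None, None)


def failing_body():
    raise ValueError("training step blew up")


for runner in (run_with, run_desugared):
    try:
        runner(failing_body)
    except ValueError:
        print("exception propagated, but exit already ran")
```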
@@ -408,7 +407,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         iter_first_order = self._frequency - 1
@@ -473,7 +472,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         dataset_helper, _ = self._exec_preprocess(self._train_network,
@@ -580,7 +579,7 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
 
         Returns:
@@ -619,7 +618,7 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
 
         Returns:
@@ -678,7 +677,6 @@ class Model:
         if not self._metric_fns:
             raise ValueError("metric fn can not be None or empty.")
 
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.eval_network = self._eval_network
         cb_params.valid_dataset = valid_dataset
@@ -691,9 +689,10 @@ class Model:
 
         self._clear_metrics()
 
-        if dataset_sink_mode:
-            return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
-        return self._eval_process(valid_dataset, list_callback, cb_params)
+        with _CallbackManager(callbacks) as list_callback:
+            if dataset_sink_mode:
+                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
+            return self._eval_process(valid_dataset, list_callback, cb_params)
 
     def predict(self, *predict_data):
         """
@@ -18,6 +18,7 @@ import os
 import stat
 import shutil
 import time
+from contextlib import ExitStack
 import numpy as np
 
 import mindspore.context as context
@@ -282,80 +283,11 @@ def _summary_cb_for_save_op(summary_list):
     return ret
 
 
-def _build_callbacks(callbacks):
-    """
-    Contain a list of callback.
-
-    Args:
-        callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
-
-    Returns:
-        List, a list of callback functions.
-    """
-    if callbacks:
-        if isinstance(callbacks, tuple):
-            raise TypeError("Callbacks cannot be a tuple. Please check it.")
-        if not isinstance(callbacks, list):
-            callbacks = [callbacks]
-    else:
-        callbacks = []
-
-    excute_callbacks = []
-    for cb in callbacks:
-        if cb is None or not isinstance(cb, Callback):
-            raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
-        excute_callbacks.append(cb)
-
-    return _ListCallback(excute_callbacks)
-
-
-class _ListCallback:
-    """
-    Sequential execution of callback functions.
-
-    Execute Callback functions at certain points.
-
-    Args:
-        callbacks (list): Callback functions list.
-    """
-    def __init__(self, callbacks):
-        super(_ListCallback, self).__init__()
-        self._callbacks = callbacks
-
-    def begin(self, run_context):
-        """Called once before network training."""
-        for cb in self._callbacks:
-            cb.begin(run_context)
-
-    def epoch_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.epoch_begin(run_context)
-
-    def epoch_end(self, run_context):
-        """Called after each epoch finished."""
-        for cb in self._callbacks:
-            cb.epoch_end(run_context)
-
-    def step_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.step_begin(run_context)
-
-    def step_end(self, run_context):
-        """Called after each step finished."""
-        for cb in self._callbacks:
-            cb.step_end(run_context)
-
-    def end(self, run_context):
-        """Called once after network training."""
-        for cb in self._callbacks:
-            cb.end(run_context)
-
-
 class Callback:
     """
-    Abstract base class used to build a callback function.
+    Abstract base class used to build a callback class. Callbacks are context managers
+    which will be entered and exited when passing into the Model.
+    You can leverage this mechanism to init and release resources automatically.
 
     Callback function will execution some operating to the current step or epoch.
 
@@ -369,8 +301,13 @@ class Callback:
         >>> print_cb = Print_info()
         >>> model.train(epoch, dataset, callbacks=print_cb)
     """
-    def __init__(self):
-        pass
+    def __enter__(self):
+        """Return the enter target."""
+        return self
+
+    def __exit__(self, *err):
+        """Release resources here if have any."""
 
     def begin(self, run_context):
         """
@@ -421,6 +358,67 @@ class Callback:
         """
 
 
+class _CallbackManager(Callback):
+    """
+    Sequential execution of callback functions.
+
+    Execute Callback functions at certain points.
+
+    Args:
+        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
+    """
+
+    def __init__(self, callbacks):
+        self._callbacks, self._stack = [], None
+        if isinstance(callbacks, Callback):
+            self._callbacks.append(callbacks)
+        elif callbacks is not None:
+            for cb in callbacks:
+                if not isinstance(cb, Callback):
+                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
+                self._callbacks.append(cb)
+
+    def __enter__(self):
+        if self._stack is None:
+            self._stack = ExitStack().__enter__()
+            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
+        return self
+
+    def __exit__(self, *err):
+        return self._stack.__exit__(*err)
+
+    def begin(self, run_context):
+        """Called once before network training."""
+        for cb in self._callbacks:
+            cb.begin(run_context)
+
+    def epoch_begin(self, run_context):
+        """Called before each epoch begin."""
+        for cb in self._callbacks:
+            cb.epoch_begin(run_context)
+
+    def epoch_end(self, run_context):
+        """Called after each epoch finished."""
+        for cb in self._callbacks:
+            cb.epoch_end(run_context)
+
+    def step_begin(self, run_context):
+        """Called before each epoch begin."""
+        for cb in self._callbacks:
+            cb.step_begin(run_context)
+
+    def step_end(self, run_context):
+        """Called after each step finished."""
+        for cb in self._callbacks:
+            cb.step_end(run_context)
+
+    def end(self, run_context):
+        """Called once after network training."""
+        for cb in self._callbacks:
+            cb.end(run_context)
+
+
 class SummaryStep(Callback):
     """
     The summary callback class.
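A note on the `ExitStack` delegation in `_CallbackManager.__enter__`/`__exit__` above: `enter_context` registers each callback's `__exit__` on the stack, so callbacks are exited in reverse order of entry, and if a later callback's `__enter__` fails, the ones already entered are still cleaned up. A small standalone demonstration of that behaviour, using only `contextlib` (the `Noisy` class is made up for the demo):

```python
from contextlib import ExitStack


class Noisy:
    """Tiny context manager that logs enter/exit, for illustration only."""

    def __init__(self, name, fail=False):
        self.name, self.fail = name, fail

    def __enter__(self):
        print("enter", self.name)
        if self.fail:
            raise RuntimeError("boom in " + self.name)
        return self

    def __exit__(self, *err):
        print("exit", self.name)


# Normal case: exits run in reverse order of the enters.
with ExitStack() as stack:
    for cm in (Noisy("a"), Noisy("b")):
        stack.enter_context(cm)
# prints: enter a, enter b, exit b, exit a

# Failure case: "a" is already on the stack, so its __exit__ still runs
# even though "b" blew up during __enter__.
try:
    with ExitStack() as stack:
        for cm in (Noisy("a"), Noisy("b", fail=True)):
            stack.enter_context(cm)
except RuntimeError as err:
    print("caught:", err)
# prints: enter a, enter b, exit a, caught: boom in b
```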
@@ -435,6 +433,13 @@ class SummaryStep(Callback):
             raise ValueError("`flush_step` should be int and greater than 0")
         self._summary = summary
         self._flush_step = flush_step
 
+    def __enter__(self):
+        self._summary.__enter__()
+        return self
+
+    def __exit__(self, *err):
+        return self._summary.__exit__(*err)
+
     def step_end(self, run_context):
         """
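`SummaryStep.__enter__`/`__exit__` above simply forward to the wrapped summary object, which is the usual pattern when a callback owns another context manager. A generic sketch of that forwarding (the `Inner`/`Wrapper` names are illustrative, not MindSpore classes):

```python
class Inner:
    """Stands in for SummaryRecord: some resource that is itself a context manager."""

    def __enter__(self):
        print("inner opened")
        return self

    def __exit__(self, *err):
        print("inner closed")
        return False


class Wrapper:
    """Stands in for SummaryStep: delegates enter/exit to the wrapped resource."""

    def __init__(self, inner):
        self._inner = inner

    def __enter__(self):
        self._inner.__enter__()
        return self  # note: returns the wrapper, not the inner resource

    def __exit__(self, *err):
        return self._inner.__exit__(*err)


with Wrapper(Inner()):
    print("working")
# prints: inner opened, working, inner closed
```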
@@ -19,7 +19,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
-from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
+from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -334,8 +334,6 @@ class Model:
         if self._parameter_broadcast:
             self._train_network.set_broadcast_flag()
 
-        # build callback list
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.train_network = self._train_network
         cb_params.epoch_num = epoch
@@ -346,17 +344,18 @@ class Model:
         cb_params.parallel_mode = self._parallel_mode
         cb_params.device_number = self._device_number
         cb_params.train_dataset = train_dataset
-        cb_params.list_callback = list_callback
+        cb_params.list_callback = callbacks
 
-        if dataset_sink_mode:
-            if context.get_context("mode") == context.PYNATIVE_MODE:
+        # build callback list
+        with _CallbackManager(callbacks) as list_callback:
+            if not dataset_sink_mode:
+                self._train_process(epoch, train_dataset, list_callback, cb_params)
+            elif context.get_context("mode") == context.PYNATIVE_MODE:
                 logger.warning("The pynative mode cannot support dataset sink mode currently."
                                "So the training process will be performed with dataset not sink.")
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
             else:
                 self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
-        else:
-            self._train_process(epoch, train_dataset, list_callback, cb_params)
 
     def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
         """
@@ -369,7 +368,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         dataset_helper, train_network = self._exec_preprocess(self._train_network,
@@ -417,7 +416,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         dataset_helper, _ = self._exec_preprocess(self._train_network,
@@ -524,7 +523,7 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
 
         Returns:
@@ -563,7 +562,7 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
 
         Returns:
@@ -622,7 +621,6 @@ class Model:
         if not self._metric_fns:
             raise ValueError("metric fn can not be None or empty.")
 
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.eval_network = self._eval_network
         cb_params.valid_dataset = valid_dataset
@@ -635,9 +633,10 @@ class Model:
 
         self._clear_metrics()
 
-        if dataset_sink_mode:
-            return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
-        return self._eval_process(valid_dataset, list_callback, cb_params)
+        with _CallbackManager(callbacks) as list_callback:
+            if dataset_sink_mode:
+                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
+            return self._eval_process(valid_dataset, list_callback, cb_params)
 
     def predict(self, *predict_data):
         """
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
+from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 
 from .dataset_helper import DatasetHelper
@@ -392,7 +392,6 @@ class Model:
             self._train_network.set_broadcast_flag()
 
         # build callback list
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.train_network = self._train_network
         cb_params.epoch_num = epoch
@@ -403,17 +402,17 @@ class Model:
         cb_params.parallel_mode = self._parallel_mode
         cb_params.device_number = self._device_number
         cb_params.train_dataset = train_dataset
-        cb_params.list_callback = list_callback
+        cb_params.list_callback = callbacks
 
-        if dataset_sink_mode:
-            if context.get_context("mode") == context.PYNATIVE_MODE:
+        with _CallbackManager(callbacks) as list_callback:
+            if not dataset_sink_mode:
+                self._train_process(epoch, train_dataset, list_callback, cb_params)
+            elif context.get_context("mode") == context.PYNATIVE_MODE:
                 logger.warning("The pynative mode cannot support dataset sink mode currently."
                                "So the training process will be performed with dataset not sink.")
                 self._train_process(epoch, train_dataset, list_callback, cb_params)
             else:
                 self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
-        else:
-            self._train_process(epoch, train_dataset, list_callback, cb_params)
 
     def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
         """
@@ -426,7 +425,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         iter_first_order = self._frequency - 1
@@ -490,7 +489,7 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned, and the data and label are passed to the network and loss
                 function respectively.
-            list_callback (_ListCallback): Executor of callback list. Default: None.
+            list_callback (Callback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
         dataset_helper, _ = self._exec_preprocess(self._train_network,
@@ -695,7 +694,6 @@ class Model:
         if not self._metric_fns:
             raise ValueError("metric fn can not be None or empty.")
 
-        list_callback = _build_callbacks(callbacks)
         cb_params = _InternalCallbackParam()
         cb_params.eval_network = self._eval_network
         cb_params.valid_dataset = valid_dataset
@@ -708,9 +706,10 @@ class Model:
 
         self._clear_metrics()
 
-        if dataset_sink_mode:
-            return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
-        return self._eval_process(valid_dataset, list_callback, cb_params)
+        with _CallbackManager(callbacks) as list_callback:
+            if dataset_sink_mode:
+                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
+            return self._eval_process(valid_dataset, list_callback, cb_params)
 
     def predict(self, *predict_data):
         """
@@ -156,12 +156,19 @@ def get_dataset():
 
 
 class ImageSummaryCallback:
-    def __init__(self, summaryRecord):
-        self._summaryRecord = summaryRecord
+    def __init__(self, summary_record):
+        self._summary_record = summary_record
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *err):
+        pass
 
     def record(self, step, train_network=None):
-        self._summaryRecord.record(step, train_network)
-        self._summaryRecord.flush()
+        self._summary_record.record(step, train_network)
+        self._summary_record.flush()
 
 
 def test_image_summary_train():
@@ -180,6 +180,12 @@ class CallbackTest:
     def __init__(self):
         pass
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *err):
+        pass
+
     def record(self, step, *args):
         print(step, args)
 
@@ -15,6 +15,7 @@
 """test callback function."""
 import os
 import stat
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -27,7 +28,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
     _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _build_callbacks, CheckpointConfig, _set_cur_net
+    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
 
 
 class Net(nn.Cell):
@@ -122,13 +123,13 @@ def test_loss_monitor_sink_mode():
     run_context = RunContext(cb_params)
     loss_cb = LossMonitor(1)
     callbacks = [loss_cb]
-    callbacklist = _build_callbacks(callbacks)
-    callbacklist.begin(run_context)
-    callbacklist.epoch_begin(run_context)
-    callbacklist.step_begin(run_context)
-    callbacklist.step_end(run_context)
-    callbacklist.epoch_end(run_context)
-    callbacklist.end(run_context)
+    with _CallbackManager(callbacks) as callbacklist:
+        callbacklist.begin(run_context)
+        callbacklist.epoch_begin(run_context)
+        callbacklist.step_begin(run_context)
+        callbacklist.step_end(run_context)
+        callbacklist.epoch_end(run_context)
+        callbacklist.end(run_context)
 
 
 def test_loss_monitor_normal_mode():
@@ -269,29 +270,61 @@ def test_checkpoint_save_ckpt_seconds():
     ckpt_cb2.step_end(run_context)
 
 
-def test_build_callbacks():
-    """Test_build_callbacks."""
+def test_CallbackManager():
+    """TestCallbackManager."""
     ck_obj = ModelCheckpoint()
     loss_cb_1 = LossMonitor(1)
 
     callbacks = [None]
     with pytest.raises(TypeError):
-        callbacks = _build_callbacks(callbacks)
+        _CallbackManager(callbacks)
 
     callbacks = ['Error']
     with pytest.raises(TypeError):
-        callbacks = _build_callbacks(callbacks)
+        _CallbackManager(callbacks)
 
     callbacks = [ck_obj, loss_cb_1, 'Error', None]
     with pytest.raises(TypeError):
-        _ = _build_callbacks(callbacks)
+        _CallbackManager(callbacks)
+
+
+def test_CallbackManager_exit_called():
+    with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
+        cb1, cb2 = Callback(), Callback()
+        with _CallbackManager([cb1, cb2]):
+            pass
+    for call_args in mock_exit.call_args_list:
+        assert call_args == mock.call(mock.ANY, None, None, None)
+    assert mock_exit.call_count == 2
+
+
+def test_CallbackManager_exit_called_when_raises():
+    with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
+        cb1, cb2 = Callback(), Callback()
+        with pytest.raises(ValueError):
+            with _CallbackManager([cb1, cb2]):
+                raise ValueError()
+    for call_args in mock_exit.call_args_list:
+        assert call_args == mock.call(*[mock.ANY] * 4)
+    assert mock_exit.call_count == 2
+
+
+def test_CallbackManager_begin_called():
+    context = dict()
+    with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
+        cb1, cb2 = Callback(), Callback()
+        with _CallbackManager([cb1, cb2]) as cm:
+            cm.begin(context)
+    for call_args in mock_begin.call_args_list:
+        assert call_args == mock.call(context)
+    assert mock_begin.call_count == 2
 
 
 def test_RunContext():
     """Test RunContext."""
     context_err = 666
     with pytest.raises(TypeError):
-        _ = RunContext(context_err)
+        RunContext(context_err)
 
     cb_params = _InternalCallbackParam()
     cb_params.member1 = 1
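A note on the assertions in the new `mock` tests above: `_CallbackManager` enters callbacks through `contextlib.ExitStack.enter_context`, which resolves `__exit__` on the callback's class and later calls it with the instance passed explicitly, so each recorded call carries four positional arguments (the callback instance plus the three exception values). That is why the clean-exit test expects `mock.call(mock.ANY, None, None, None)` and the exception test expects four `mock.ANY`s. A standalone illustration of that behaviour (the `Thing` class is made up for the demo):

```python
from contextlib import ExitStack
from unittest import mock


class Thing:
    """Minimal context manager used only to illustrate the test assertions."""

    def __enter__(self):
        return self

    def __exit__(self, *err):
        return False


with mock.patch.object(Thing, "__exit__", return_value=None) as mock_exit:
    with ExitStack() as stack:  # _CallbackManager enters callbacks the same way
        stack.enter_context(Thing())

# ExitStack calls Thing.__exit__(instance, exc_type, exc_value, traceback),
# so the recorded call has four positional arguments: the instance plus three Nones.
assert mock_exit.call_count == 1
assert mock_exit.call_args == mock.call(mock.ANY, None, None, None)
print("call recorded as:", mock_exit.call_args)
```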