diff --git a/example/resnet50_imagenet2012_THOR/model/model_thor.py b/example/resnet50_imagenet2012_THOR/model/model_thor.py
index 3106b044530..25e3dd7f823 100644
--- a/example/resnet50_imagenet2012_THOR/model/model_thor.py
+++ b/example/resnet50_imagenet2012_THOR/model/model_thor.py
@@ -29,7 +29,7 @@
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from model.dataset_helper import DatasetHelper
diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc
index 4f21002470d..427cc5e568c 100644
--- a/mindspore/ccsrc/utils/callbacks.cc
+++ b/mindspore/ccsrc/utils/callbacks.cc
@@ -26,9 +26,9 @@
 namespace mindspore {
 namespace callbacks {
 
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc
index f45e0d5955c..3174ec4b151 100644
--- a/mindspore/ccsrc/utils/callbacks_ge.cc
+++ b/mindspore/ccsrc/utils/callbacks_ge.cc
@@ -25,9 +25,9 @@
 namespace mindspore {
 namespace callbacks {
 
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
diff --git a/mindspore/train/callback/__init__.py b/mindspore/train/callback/__init__.py
index 1f81f0de414..4e0dd729fb9 100644
--- a/mindspore/train/callback/__init__.py
+++ b/mindspore/train/callback/__init__.py
@@ -14,7 +14,15 @@
 # ============================================================================
 """Callback related classes and functions."""
 
-from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
+from ._callback import Callback
+from ._callback import CallbackManager as _CallbackManager
+from ._callback import InternalCallbackParam as _InternalCallbackParam
+from ._callback import RunContext
+from ._checkpoint import CheckpointConfig
+from ._checkpoint import CheckpointManager as _CheckpointManager
+from ._checkpoint import ModelCheckpoint
+from ._loss_monitor import LossMonitor
+from ._summary_step import SummaryStep
+from ._time_monitor import TimeMonitor
 
-__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"CheckpointConfig", "RunContext"] +__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"] diff --git a/mindspore/train/callback/_callback.py b/mindspore/train/callback/_callback.py new file mode 100644 index 00000000000..756b9c71183 --- /dev/null +++ b/mindspore/train/callback/_callback.py @@ -0,0 +1,260 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Callback related classes and functions.""" + +from contextlib import ExitStack + +from mindspore import log as logger +from mindspore.train.serialization import _fill_param_into_net +from mindspore.train.summary.summary_record import _cache_summary_tensor_data + +_cur_net = None + +def set_cur_net(net): + """ + Set current net for which we are using to save checkpoint. + + Args: + net (Cell): train network + """ + global _cur_net + _cur_net = net + + +def checkpoint_cb_for_save_op(parameter_list): + """ + The checkpoint callback function for MindSpore. + + Will be executed by checkpoint save op. + + Args: + parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor. + + Returns: + bool, true: means save checkpoint success. + """ + if _cur_net is None: + logger.warning("_cur_net is None. parameters are not updated.") + return False + + logger.info("update parameters in the net.") + _fill_param_into_net(_cur_net, parameter_list) + set_cur_net(None) + return True + + +def summary_cb_for_save_op(summary_list): + """ + The summary callback function for MindSpore. + + Will be executed by summary op. + + Args: + summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor. + + Returns: + bool, true: means save summary success. + """ + ret = _cache_summary_tensor_data(summary_list) + return ret + + +class Callback: + """ + Abstract base class used to build a callback class. Callbacks are context managers + which will be entered and exited when passing into the Model. + You can leverage this mechanism to init and release resources automatically. + + Callback function will execution some operating to the current step or epoch. + + Examples: + >>> class Print_info(Callback): + >>> def step_end(self, run_context): + >>> cb_params = run_context.original_args() + >>> print(cb_params.cur_epoch_num) + >>> print(cb_params.cur_step_num) + >>> + >>> print_cb = Print_info() + >>> model.train(epoch, dataset, callbacks=print_cb) + """ + + def __enter__(self): + """Return the enter target.""" + return self + + def __exit__(self, *err): + """Release resources here if have any.""" + + def begin(self, run_context): + """ + Called once before the network executing. + + Args: + run_context (RunContext): Include some information of the model. + """ + + def epoch_begin(self, run_context): + """ + Called before each epoch beginning. 
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+
+    def epoch_end(self, run_context):
+        """
+        Called after each epoch finishes.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+
+    def step_begin(self, run_context):
+        """
+        Called before each step begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+
+    def step_end(self, run_context):
+        """
+        Called after each step finishes.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+
+    def end(self, run_context):
+        """
+        Called once after network training.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+
+
+class CallbackManager(Callback):
+    """
+    Sequential execution of callback functions.
+
+    Execute Callback functions at certain points.
+
+    Args:
+        callbacks (Optional[list[Callback], Callback]): None, a Callback, or a list of Callbacks.
+    """
+
+    def __init__(self, callbacks):
+        self._callbacks, self._stack = [], None
+        if isinstance(callbacks, Callback):
+            self._callbacks.append(callbacks)
+        elif callbacks is not None:
+            for cb in callbacks:
+                if not isinstance(cb, Callback):
+                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
+                self._callbacks.append(cb)
+
+    def __enter__(self):
+        if self._stack is None:
+            self._stack = ExitStack().__enter__()
+            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
+        return self
+
+    def __exit__(self, *err):
+        return self._stack.__exit__(*err)
+
+    def begin(self, run_context):
+        """Called once before network training."""
+        for cb in self._callbacks:
+            cb.begin(run_context)
+
+    def epoch_begin(self, run_context):
+        """Called before each epoch begins."""
+        for cb in self._callbacks:
+            cb.epoch_begin(run_context)
+
+    def epoch_end(self, run_context):
+        """Called after each epoch finishes."""
+        for cb in self._callbacks:
+            cb.epoch_end(run_context)
+
+    def step_begin(self, run_context):
+        """Called before each step begins."""
+        for cb in self._callbacks:
+            cb.step_begin(run_context)
+
+    def step_end(self, run_context):
+        """Called after each step finishes."""
+        for cb in self._callbacks:
+            cb.step_end(run_context)
+
+    def end(self, run_context):
+        """Called once after network training."""
+        for cb in self._callbacks:
+            cb.end(run_context)
+
+
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+
+class RunContext:
+    """
+    Provides information about the model.
+
+    Provides information about the original request to the model function.
+    Callback objects can stop the loop by calling request_stop() of run_context.
+
+    Args:
+        original_args (dict): Holding the related information of the model.
+    """
+    def __init__(self, original_args):
+        if not isinstance(original_args, dict):
+            raise TypeError("The arg of RunContext should be dict type.")
+        self._original_args = original_args
+        self._stop_requested = False
+
+    def original_args(self):
+        """
+        Get the _original_args object.
+
+        Returns:
+            Dict, an object holding the original arguments of the model.
+        """
+        return self._original_args
+
+    def request_stop(self):
+        """
+        Set the stop request during training.
+
+        Callbacks can use this function to request that iterations stop.
+        model.train() checks whether this has been called or not.
+ """ + self._stop_requested = True + + def get_stop_requested(self): + """ + Returns whether a stop is requested or not. + + Returns: + bool, if true, model.train() stops iterations. + """ + return self._stop_requested diff --git a/mindspore/train/callback/callback.py b/mindspore/train/callback/_checkpoint.py similarity index 58% rename from mindspore/train/callback/callback.py rename to mindspore/train/callback/_checkpoint.py index 822fa3cb668..d185377c83c 100644 --- a/mindspore/train/callback/callback.py +++ b/mindspore/train/callback/_checkpoint.py @@ -12,93 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Callback related classes and functions.""" +"""Checkpoint related classes and functions.""" import os -import stat import shutil +import stat import time -from contextlib import ExitStack -import numpy as np import mindspore.context as context -from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph -from mindspore.train._utils import _make_directory from mindspore import log as logger -from mindspore._checkparam import check_int_non_negative, check_bool -from mindspore.common.tensor import Tensor -from mindspore.train.summary.summary_record import _cache_summary_tensor_data +from mindspore._checkparam import check_bool, check_int_non_negative +from mindspore.train._utils import _make_directory +from mindspore.train.serialization import _exec_save_checkpoint, _save_graph +from ._callback import Callback, set_cur_net _cur_dir = os.getcwd() -_cur_net = None _save_dir = _cur_dir -class _CheckpointManager: - """Manage checkpoint files according to train_config of checkpoint.""" - def __init__(self): - self._ckpoint_filelist = [] - - @property - def ckpoint_filelist(self): - """Get all the related checkpoint files managed here.""" - return self._ckpoint_filelist - - @property - def ckpoint_num(self): - """Get the number of the related checkpoint files managed here.""" - return len(self._ckpoint_filelist) - - def update_ckpoint_filelist(self, directory, prefix): - """Update the checkpoint file list.""" - self._ckpoint_filelist = [] - files = os.listdir(directory) - for filename in files: - if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix): - mid_name = filename[len(prefix):-5] - flag = True - for char in mid_name: - if char.isalpha(): - flag = False - if flag: - self._ckpoint_filelist.append(directory + '/' + filename) - - def remove_ckpoint_file(self, file_name): - """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" - try: - os.chmod(file_name, stat.S_IWRITE) - os.remove(file_name) - self._ckpoint_filelist.remove(file_name) - except OSError: - logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) - except ValueError: - logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) - - def remove_oldest_ckpoint_file(self): - """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" - ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) - self.remove_ckpoint_file(ckpoint_files[0]) - - def keep_one_ckpoint_per_minutes(self, minutes, cur_time): - """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" - movs = [] - oldest_file = '' - oldest_time = cur_time - for ck_file in 
-        for ck_file in self._ckpoint_filelist:
-            modify_time = os.path.getmtime(ck_file)
-            if cur_time - modify_time < 60 * minutes:
-                movs.append(ck_file)
-
-                if modify_time < oldest_time:
-                    oldest_time = modify_time
-                    oldest_file = ck_file
-
-        for mv_file in movs:
-            if mv_file == oldest_file:
-                continue
-            self.remove_ckpoint_file(mv_file)
-
 
 def _check_file_name_prefix(file_name_prefix):
     """
@@ -234,282 +166,6 @@ class CheckpointConfig:
         return checkpoint_policy
 
 
-def _set_cur_net(net):
-    """
-    Set current net for which we are using to save checkpoint.
-
-    Args:
-        net (Cell): train network
-    """
-    global _cur_net
-    _cur_net = net
-
-
-def _checkpoint_cb_for_save_op(parameter_list):
-    """
-    The checkpoint callback function for MindSpore.
-
-    Will be executed by checkpoint save op.
-
-    Args:
-        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
-
-    Returns:
-        bool, true: means save checkpoint success.
-    """
-    if _cur_net is None:
-        logger.warning("_cur_net is None. parameters are not updated.")
-        return False
-
-    logger.info("update parameters in the net.")
-    _fill_param_into_net(_cur_net, parameter_list)
-    _set_cur_net(None)
-    return True
-
-
-def _summary_cb_for_save_op(summary_list):
-    """
-    The summary callback function for MindSpore.
-
-    Will be executed by summary op.
-
-    Args:
-        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
-
-    Returns:
-        bool, true: means save summary success.
-    """
-    ret = _cache_summary_tensor_data(summary_list)
-    return ret
-
-
-class Callback:
-    """
-    Abstract base class used to build a callback class. Callbacks are context managers
-    which will be entered and exited when passing into the Model.
-    You can leverage this mechanism to init and release resources automatically.
-
-    Callback function will execution some operating to the current step or epoch.
-
-    Examples:
-        >>> class Print_info(Callback):
-        >>>     def step_end(self, run_context):
-        >>>         cb_params = run_context.original_args()
-        >>>         print(cb_params.cur_epoch_num)
-        >>>         print(cb_params.cur_step_num)
-        >>>
-        >>> print_cb = Print_info()
-        >>> model.train(epoch, dataset, callbacks=print_cb)
-    """
-
-    def __enter__(self):
-        """Return the enter target."""
-        return self
-
-    def __exit__(self, *err):
-        """Release resources here if have any."""
-
-    def begin(self, run_context):
-        """
-        Called once before the network executing.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_end(self, run_context):
-        """
-        Called after each epoch finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_end(self, run_context):
-        """
-        Called after each step finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def end(self, run_context):
-        """
-        Called once after network training.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-
-class _CallbackManager(Callback):
-    """
-    Sequential execution of callback functions.
-
-    Execute Callback functions at certain points.
-
-    Args:
-        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
-    """
-
-    def __init__(self, callbacks):
-        self._callbacks, self._stack = [], None
-        if isinstance(callbacks, Callback):
-            self._callbacks.append(callbacks)
-        elif callbacks is not None:
-            for cb in callbacks:
-                if not isinstance(cb, Callback):
-                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
-                self._callbacks.append(cb)
-
-    def __enter__(self):
-        if self._stack is None:
-            self._stack = ExitStack().__enter__()
-            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
-        return self
-
-    def __exit__(self, *err):
-        return self._stack.__exit__(*err)
-
-    def begin(self, run_context):
-        """Called once before network training."""
-        for cb in self._callbacks:
-            cb.begin(run_context)
-
-    def epoch_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.epoch_begin(run_context)
-
-    def epoch_end(self, run_context):
-        """Called after each epoch finished."""
-        for cb in self._callbacks:
-            cb.epoch_end(run_context)
-
-    def step_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.step_begin(run_context)
-
-    def step_end(self, run_context):
-        """Called after each step finished."""
-        for cb in self._callbacks:
-            cb.step_end(run_context)
-
-    def end(self, run_context):
-        """Called once after network training."""
-        for cb in self._callbacks:
-            cb.end(run_context)
-
-
-
-class SummaryStep(Callback):
-    """
-    The summary callback class.
-
-    Args:
-        summary (Object): Summary recode object.
-        flush_step (int): Number of interval steps to execute. Default: 10.
-    """
-    def __init__(self, summary, flush_step=10):
-        super(SummaryStep, self).__init__()
-        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
-            raise ValueError("`flush_step` should be int and greater than 0")
-        self._summary = summary
-        self._flush_step = flush_step
-
-    def __enter__(self):
-        self._summary.__enter__()
-        return self
-
-    def __exit__(self, *err):
-        return self._summary.__exit__(*err)
-
-
-    def step_end(self, run_context):
-        """
-        Save summary.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        if cb_params.cur_step_num % self._flush_step == 0:
-            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
-
-    @property
-    def summary_file_name(self):
-        return self._summary.full_file_name
-
-
-class _InternalCallbackParam(dict):
-    """Internal callback object's parameters."""
-
-    def __getattr__(self, key):
-        return self[key]
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-
-class RunContext:
-    """
-    Provides information about the model.
-
-    Run call being made. Provides information about original request to model function.
-    callback objects can stop the loop by calling request_stop() of run_context.
-
-    Args:
-        original_args (dict): Holding the related information of model etc.
-    """
-    def __init__(self, original_args):
-        if not isinstance(original_args, dict):
-            raise TypeError("The arg of RunContext should be dict type.")
-        self._original_args = original_args
-        self._stop_requested = False
-
-    def original_args(self):
-        """
-        Get the _original_args object.
-
-        Returns:
-            Dict, a object holding the original arguments of model.
-        """
-        return self._original_args
-
-    def request_stop(self):
-        """
-        Sets stop requested during training.
-
-        Callbacks can use this function to request stop of iterations.
-        model.train() checks whether this is called or not.
-        """
-        self._stop_requested = True
-
-    def get_stop_requested(self):
-        """
-        Returns whether a stop is requested or not.
-
-        Returns:
-            bool, if true, model.train() stops iterations.
-        """
-        return self._stop_requested
-
 
 class ModelCheckpoint(Callback):
     """
@@ -553,7 +209,7 @@ class ModelCheckpoint(Callback):
         self._config = config
 
         # get existing checkpoint files
-        self._manager = _CheckpointManager()
+        self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
 
         self._graph_saved = False
@@ -633,7 +289,7 @@ class ModelCheckpoint(Callback):
             self._last_triggered_step = cb_params.cur_step_num
 
             if context.get_context("enable_ge"):
-                _set_cur_net(cb_params.train_network)
+                set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
             _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
@@ -648,57 +304,66 @@ class ModelCheckpoint(Callback):
         return self._latest_ckpt_file_name
 
 
-class LossMonitor(Callback):
-    """
-    Monitor the loss in training.
+class CheckpointManager:
+    """Manage checkpoint files according to train_config of checkpoint."""
+    def __init__(self):
+        self._ckpoint_filelist = []
 
-    If the loss is NAN or INF, it will terminate training.
+    @property
+    def ckpoint_filelist(self):
+        """Get all the related checkpoint files managed here."""
+        return self._ckpoint_filelist
 
-    Note:
-        If per_print_times is 0 do not print loss.
+    @property
+    def ckpoint_num(self):
+        """Get the number of the related checkpoint files managed here."""
+        return len(self._ckpoint_filelist)
 
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
+    def update_ckpoint_filelist(self, directory, prefix):
+        """Update the checkpoint file list."""
+        self._ckpoint_filelist = []
+        files = os.listdir(directory)
+        for filename in files:
+            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
+                mid_name = filename[len(prefix):-5]
+                flag = True
+                for char in mid_name:
+                    if char.isalpha():
+                        flag = False
+                if flag:
+                    self._ckpoint_filelist.append(directory + '/' + filename)
 
-    Raises:
-        ValueError: If print_step is not int or less than zero.
- """ - def __init__(self, per_print_times=1): - super(LossMonitor, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0.") - self._per_print_times = per_print_times + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + self._ckpoint_filelist.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) - def step_end(self, run_context): - cb_params = run_context.original_args() - loss = cb_params.net_outputs + def remove_oldest_ckpoint_file(self): + """Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" + ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) + self.remove_ckpoint_file(ckpoint_files[0]) - if isinstance(loss, (tuple, list)): - if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): - loss = loss[0] + def keep_one_ckpoint_per_minutes(self, minutes, cur_time): + """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" + movs = [] + oldest_file = '' + oldest_time = cur_time + for ck_file in self._ckpoint_filelist: + modify_time = os.path.getmtime(ck_file) + if cur_time - modify_time < 60 * minutes: + movs.append(ck_file) - if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): - loss = np.mean(loss.asnumpy()) + if modify_time < oldest_time: + oldest_time = modify_time + oldest_file = ck_file - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 - - if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): - raise ValueError("epoch: {} step: {}. Invalid loss, terminating training." - .format(cb_params.cur_epoch_num, cur_step_in_epoch)) - if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: - print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) - - -class TimeMonitor(Callback): - """Time Monitor.""" - def __init__(self, data_size): - super(TimeMonitor, self).__init__() - self.data_size = data_size - - def epoch_begin(self, run_context): - self.epoch_time = time.time() - - def epoch_end(self, run_context): - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / self.data_size - print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) + for mv_file in movs: + if mv_file == oldest_file: + continue + self.remove_ckpoint_file(mv_file) diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py new file mode 100644 index 00000000000..15a095c5cb1 --- /dev/null +++ b/mindspore/train/callback/_loss_monitor.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LossMonitor Callback class."""
+
+import numpy as np
+from mindspore.common.tensor import Tensor
+
+from ._callback import Callback
+
+
+class LossMonitor(Callback):
+    """
+    Monitor the loss in training.
+
+    If the loss is NAN or INF, it will terminate training.
+
+    Note:
+        If per_print_times is 0, the loss is not printed.
+
+    Args:
+        per_print_times (int): How often the loss is printed, in steps. Default: 1.
+
+    Raises:
+        ValueError: If per_print_times is not an int or is less than zero.
+    """
+
+    def __init__(self, per_print_times=1):
+        super(LossMonitor, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be int and >= 0.")
+        self._per_print_times = per_print_times
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
diff --git a/mindspore/train/callback/_summary_step.py b/mindspore/train/callback/_summary_step.py
new file mode 100644
index 00000000000..0a4fbca80d5
--- /dev/null
+++ b/mindspore/train/callback/_summary_step.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SummaryStep Callback class."""
+
+from ._callback import Callback
+
+
+class SummaryStep(Callback):
+    """
+    The summary callback class.
+
+    Args:
+        summary (Object): Summary record object.
+        flush_step (int): Interval, in steps, at which the summary is recorded. Default: 10.
+ """ + + def __init__(self, summary, flush_step=10): + super(SummaryStep, self).__init__() + if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: + raise ValueError("`flush_step` should be int and greater than 0") + self._summary = summary + self._flush_step = flush_step + + def __enter__(self): + self._summary.__enter__() + return self + + def __exit__(self, *err): + return self._summary.__exit__(*err) + + def step_end(self, run_context): + """ + Save summary. + + Args: + run_context (RunContext): Context of the train running. + """ + cb_params = run_context.original_args() + if cb_params.cur_step_num % self._flush_step == 0: + self._summary.record(cb_params.cur_step_num, cb_params.train_network) + + @property + def summary_file_name(self): + return self._summary.full_file_name diff --git a/mindspore/train/callback/_time_monitor.py b/mindspore/train/callback/_time_monitor.py new file mode 100644 index 00000000000..c810306d24d --- /dev/null +++ b/mindspore/train/callback/_time_monitor.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TimeMonitor Callback class.""" + +import time + +from ._callback import Callback + + +class TimeMonitor(Callback): + """Time Monitor.""" + + def __init__(self, data_size): + super(TimeMonitor, self).__init__() + self.data_size = data_size + + def epoch_begin(self, run_context): + self.epoch_time = time.time() + + def epoch_end(self, run_context): + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / self.data_size + print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 68042d8d0ae..8288e533680 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -19,7 +19,7 @@ from mindspore import log as logger from ..common.tensor import Tensor from ..nn.metrics import get_metrics from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool -from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager +from .callback import _InternalCallbackParam, RunContext, _CallbackManager from .. 
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
diff --git a/tests/st/networks/models/resnet50/src_thor/model_thor.py b/tests/st/networks/models/resnet50/src_thor/model_thor.py
index ee799c4b740..07b9e60bed9 100644
--- a/tests/st/networks/models/resnet50/src_thor/model_thor.py
+++ b/tests/st/networks/models/resnet50/src_thor/model_thor.py
@@ -29,7 +29,7 @@
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 from .dataset_helper import DatasetHelper
diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py
index b0879ebc0ed..c4f6e0aa5b6 100644
--- a/tests/ut/python/utils/test_callback.py
+++ b/tests/ut/python/utils/test_callback.py
@@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
-    _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
-
+from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
+    _CallbackManager, Callback, CheckpointConfig
+from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
+from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
 
 class Net(nn.Cell):
     """Net definition."""
@@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op():
         one_param['name'] = "conv1.weight"
         one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
         parameter_list.append(one_param)
-    _checkpoint_cb_for_save_op(parameter_list)
+    checkpoint_cb_for_save_op(parameter_list)
 
 
 def test_checkpoint_cb_for_save_op_update_net():
@@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
         one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
         parameter_list.append(one_param)
     net = Net()
-    _set_cur_net(net)
-    _checkpoint_cb_for_save_op(parameter_list)
+    set_cur_net(net)
+    checkpoint_cb_for_save_op(parameter_list)
     assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py
index f5046bb1ec3..b312f9f7d12 100644
--- a/tests/ut/python/utils/test_serialize.py
+++ b/tests/ut/python/utils/test_serialize.py
@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.train.callback.callback import _CheckpointManager
+from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
     _exec_save_checkpoint, export, _save_graph
 from ..ut_filter import non_graph_engine
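
Note: the public surface re-exported by `mindspore/train/callback/__init__.py` (`__all__`) is unchanged by this refactor; only the private module layout and the internal helper names (`set_cur_net`, `checkpoint_cb_for_save_op`, `summary_cb_for_save_op`) move. A minimal sketch of how the unchanged user-facing API is exercised after the patch; `StopAtStep`, `model`, `dataset`, and `epoch` below are illustrative names, not part of the patch:

from mindspore.train.callback import Callback, LossMonitor, TimeMonitor

class StopAtStep(Callback):
    """Illustrative callback: ask the training loop to stop after `stop_step` steps."""
    def __init__(self, stop_step):
        super(StopAtStep, self).__init__()
        self._stop_step = stop_step

    def step_end(self, run_context):
        # original_args() returns the internal callback-parameter dict
        # (cur_step_num, cur_epoch_num, batch_num, ...).
        cb_params = run_context.original_args()
        if cb_params.cur_step_num >= self._stop_step:
            run_context.request_stop()

# Hypothetical usage; `model`, `dataset`, and `epoch` are assumed to exist:
# model.train(epoch, dataset, callbacks=[StopAtStep(100), LossMonitor(), TimeMonitor(data_size=100)])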