forked from OSSInnovation/mindspore
!2236 Refactor the callback module in an encapsulated way
Merge pull request !2236 from 李鸿章/callback
commit 0c3d96a98b
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode

 from model.dataset_helper import DatasetHelper
@@ -26,9 +26,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
@@ -25,9 +25,9 @@
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
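The C++ side locates the Python callbacks purely by these module and function names, so both renamed constants must match the Python refactor exactly. As a rough Python-side equivalent of what the lookup amounts to (a sketch; the real call is made from the C++ embedding layer, not from code like this):

import importlib

# Resolve the callback module and functions by the same names the C++
# constants now carry; a stale name on either side would break checkpoint
# or summary saving at runtime.
mod = importlib.import_module("mindspore.train.callback._callback")
checkpoint_fn = getattr(mod, "checkpoint_cb_for_save_op")
summary_fn = getattr(mod, "summary_cb_for_save_op")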
@@ -14,7 +14,15 @@
 # ============================================================================
 """Callback related classes and functions."""

-from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
+from ._callback import Callback
+from ._callback import CallbackManager as _CallbackManager
+from ._callback import InternalCallbackParam as _InternalCallbackParam
+from ._callback import RunContext
+from ._checkpoint import CheckpointConfig
+from ._checkpoint import CheckpointManager as _CheckpointManager
+from ._checkpoint import ModelCheckpoint
+from ._loss_monitor import LossMonitor
+from ._summary_step import SummaryStep
+from ._time_monitor import TimeMonitor

-__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
-           "SummaryStep", "CheckpointConfig", "RunContext"]
+__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
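With this hunk the package keeps its public surface stable while the implementations move into private modules: the underscore-prefixed names that internal callers already import become aliases created in __init__.py. A small sketch of the resulting import surface (assuming the layout above):

# Public API, unchanged for downstream code:
from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint

# Internal aliases re-exported by __init__.py ...
from mindspore.train.callback import _CallbackManager
# ... resolve to the encapsulated implementation:
from mindspore.train.callback._callback import CallbackManager
assert _CallbackManager is CallbackManager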
@@ -0,0 +1,260 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""

from contextlib import ExitStack

from mindspore import log as logger
from mindspore.train.serialization import _fill_param_into_net
from mindspore.train.summary.summary_record import _cache_summary_tensor_data

_cur_net = None


def set_cur_net(net):
    """
    Set the current net used when saving a checkpoint.

    Args:
        net (Cell): train network
    """
    global _cur_net
    _cur_net = net


def checkpoint_cb_for_save_op(parameter_list):
    """
    The checkpoint callback function for MindSpore.

    Will be executed by the checkpoint save op.

    Args:
        parameter_list (list): Format is like [{"name": name, "data": value}] and value type is Tensor.

    Returns:
        bool, True means the checkpoint was saved successfully.
    """
    if _cur_net is None:
        logger.warning("_cur_net is None. parameters are not updated.")
        return False

    logger.info("update parameters in the net.")
    _fill_param_into_net(_cur_net, parameter_list)
    set_cur_net(None)
    return True


def summary_cb_for_save_op(summary_list):
    """
    The summary callback function for MindSpore.

    Will be executed by the summary op.

    Args:
        summary_list (list): Format is like [{"name": tag_name, "data": tensor}, ...] and value is Scalar/Tensor.

    Returns:
        bool, True means the summary was saved successfully.
    """
    ret = _cache_summary_tensor_data(summary_list)
    return ret


class Callback:
    """
    Abstract base class used to build a callback class. Callbacks are context managers
    which will be entered and exited when passed into the Model.
    You can leverage this mechanism to initialize and release resources automatically.

    Callback functions perform operations on the current step or epoch.

    Examples:
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> model.train(epoch, dataset, callbacks=print_cb)
    """

    def __enter__(self):
        """Return the enter target."""
        return self

    def __exit__(self, *err):
        """Release resources here if have any."""

    def begin(self, run_context):
        """
        Called once before the network starts executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch begins.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finishes.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step begins.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finishes.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """


class CallbackManager(Callback):
    """
    Sequential execution of callback functions.

    Execute Callback functions at certain points.

    Args:
        callbacks (Optional[list[Callback], Callback]): None, a Callback, or a list of Callbacks.
    """

    def __init__(self, callbacks):
        self._callbacks, self._stack = [], None
        if isinstance(callbacks, Callback):
            self._callbacks.append(callbacks)
        elif callbacks is not None:
            for cb in callbacks:
                if not isinstance(cb, Callback):
                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
                self._callbacks.append(cb)

    def __enter__(self):
        if self._stack is None:
            self._stack = ExitStack().__enter__()
            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
        return self

    def __exit__(self, *err):
        return self._stack.__exit__(*err)

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begins."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finishes."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called before each step begins."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finishes."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)


class InternalCallbackParam(dict):
    """Internal callback object's parameters."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value


class RunContext:
    """
    Provides information about the model.

    Provides information about the original request to the model function.
    Callback objects can stop the loop by calling request_stop() of run_context.

    Args:
        original_args (dict): Holding the related information of the model.
    """
    def __init__(self, original_args):
        if not isinstance(original_args, dict):
            raise TypeError("The arg of RunContext should be dict type.")
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        """
        Get the _original_args object.

        Returns:
            dict, an object holding the original arguments of the model.
        """
        return self._original_args

    def request_stop(self):
        """
        Set stop requested during training.

        Callbacks can use this function to request stop of iterations.
        model.train() checks whether this is called or not.
        """
        self._stop_requested = True

    def get_stop_requested(self):
        """
        Return whether a stop has been requested.

        Returns:
            bool, if True, model.train() stops iterations.
        """
        return self._stop_requested
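Taken together, the pieces above form the training loop's callback protocol: InternalCallbackParam carries loop state as an attribute-accessible dict, RunContext wraps it and exposes the stop flag, and CallbackManager fans each hook out to every registered callback. A minimal sketch of how a driver might wire them up (the loop itself is illustrative, not MindSpore's actual Model code):

from mindspore.train.callback import Callback, RunContext
from mindspore.train.callback._callback import CallbackManager, InternalCallbackParam

class PrintStep(Callback):
    def step_end(self, run_context):
        print("step", run_context.original_args().cur_step_num)

cb_params = InternalCallbackParam()   # a dict with attribute-style access
cb_params.cur_step_num = 0
run_context = RunContext(cb_params)

with CallbackManager([PrintStep()]) as manager:   # enters each callback's context
    manager.begin(run_context)
    for _ in range(3):
        cb_params.cur_step_num += 1
        manager.step_end(run_context)
        if run_context.get_stop_requested():
            break
    manager.end(run_context)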
@@ -12,93 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Callback related classes and functions."""
+"""Checkpoint related classes and functions."""

 import os
-import stat
 import shutil
+import stat
 import time
-from contextlib import ExitStack
-import numpy as np

 import mindspore.context as context
-from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
-from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative, check_bool
-from mindspore.common.tensor import Tensor
-from mindspore.train.summary.summary_record import _cache_summary_tensor_data
+from mindspore._checkparam import check_bool, check_int_non_negative
+from mindspore.train._utils import _make_directory
+from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
+
+from ._callback import Callback, set_cur_net

 _cur_dir = os.getcwd()
-_cur_net = None
 _save_dir = _cur_dir


-class _CheckpointManager:
-    """Manage checkpoint files according to train_config of checkpoint."""
-    def __init__(self):
-        self._ckpoint_filelist = []
-
-    @property
-    def ckpoint_filelist(self):
-        """Get all the related checkpoint files managed here."""
-        return self._ckpoint_filelist
-
-    @property
-    def ckpoint_num(self):
-        """Get the number of the related checkpoint files managed here."""
-        return len(self._ckpoint_filelist)
-
-    def update_ckpoint_filelist(self, directory, prefix):
-        """Update the checkpoint file list."""
-        self._ckpoint_filelist = []
-        files = os.listdir(directory)
-        for filename in files:
-            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
-                mid_name = filename[len(prefix):-5]
-                flag = True
-                for char in mid_name:
-                    if char.isalpha():
-                        flag = False
-                if flag:
-                    self._ckpoint_filelist.append(directory + '/' + filename)
-
-    def remove_ckpoint_file(self, file_name):
-        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
-        try:
-            os.chmod(file_name, stat.S_IWRITE)
-            os.remove(file_name)
-            self._ckpoint_filelist.remove(file_name)
-        except OSError:
-            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
-        except ValueError:
-            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
-
-    def remove_oldest_ckpoint_file(self):
-        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
-        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
-        self.remove_ckpoint_file(ckpoint_files[0])
-
-    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
-        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
-        movs = []
-        oldest_file = ''
-        oldest_time = cur_time
-        for ck_file in self._ckpoint_filelist:
-            modify_time = os.path.getmtime(ck_file)
-            if cur_time - modify_time < 60 * minutes:
-                movs.append(ck_file)
-
-                if modify_time < oldest_time:
-                    oldest_time = modify_time
-                    oldest_file = ck_file
-
-        for mv_file in movs:
-            if mv_file == oldest_file:
-                continue
-            self.remove_ckpoint_file(mv_file)
-
-
 def _check_file_name_prefix(file_name_prefix):
     """
@@ -234,282 +166,6 @@ class CheckpointConfig:
         return checkpoint_policy


-def _set_cur_net(net):
-    """
-    Set current net for which we are using to save checkpoint.
-
-    Args:
-        net (Cell): train network
-    """
-    global _cur_net
-    _cur_net = net
-
-
-def _checkpoint_cb_for_save_op(parameter_list):
-    """
-    The checkpoint callback function for MindSpore.
-
-    Will be executed by checkpoint save op.
-
-    Args:
-        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
-
-    Returns:
-        bool, true: means save checkpoint success.
-    """
-    if _cur_net is None:
-        logger.warning("_cur_net is None. parameters are not updated.")
-        return False
-
-    logger.info("update parameters in the net.")
-    _fill_param_into_net(_cur_net, parameter_list)
-    _set_cur_net(None)
-    return True
-
-
-def _summary_cb_for_save_op(summary_list):
-    """
-    The summary callback function for MindSpore.
-
-    Will be executed by summary op.
-
-    Args:
-        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
-
-    Returns:
-        bool, true: means save summary success.
-    """
-    ret = _cache_summary_tensor_data(summary_list)
-    return ret
-
-
-class Callback:
-    """
-    Abstract base class used to build a callback class. Callbacks are context managers
-    which will be entered and exited when passing into the Model.
-    You can leverage this mechanism to init and release resources automatically.
-
-    Callback function will execution some operating to the current step or epoch.
-
-    Examples:
-        >>> class Print_info(Callback):
-        >>>     def step_end(self, run_context):
-        >>>         cb_params = run_context.original_args()
-        >>>         print(cb_params.cur_epoch_num)
-        >>>         print(cb_params.cur_step_num)
-        >>>
-        >>> print_cb = Print_info()
-        >>> model.train(epoch, dataset, callbacks=print_cb)
-    """
-
-    def __enter__(self):
-        """Return the enter target."""
-        return self
-
-    def __exit__(self, *err):
-        """Release resources here if have any."""
-
-    def begin(self, run_context):
-        """
-        Called once before the network executing.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_end(self, run_context):
-        """
-        Called after each epoch finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_end(self, run_context):
-        """
-        Called after each step finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def end(self, run_context):
-        """
-        Called once after network training.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-
-class _CallbackManager(Callback):
-    """
-    Sequential execution of callback functions.
-
-    Execute Callback functions at certain points.
-
-    Args:
-        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
-    """
-
-    def __init__(self, callbacks):
-        self._callbacks, self._stack = [], None
-        if isinstance(callbacks, Callback):
-            self._callbacks.append(callbacks)
-        elif callbacks is not None:
-            for cb in callbacks:
-                if not isinstance(cb, Callback):
-                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
-                self._callbacks.append(cb)
-
-    def __enter__(self):
-        if self._stack is None:
-            self._stack = ExitStack().__enter__()
-            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
-        return self
-
-    def __exit__(self, *err):
-        return self._stack.__exit__(*err)
-
-    def begin(self, run_context):
-        """Called once before network training."""
-        for cb in self._callbacks:
-            cb.begin(run_context)
-
-    def epoch_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.epoch_begin(run_context)
-
-    def epoch_end(self, run_context):
-        """Called after each epoch finished."""
-        for cb in self._callbacks:
-            cb.epoch_end(run_context)
-
-    def step_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.step_begin(run_context)
-
-    def step_end(self, run_context):
-        """Called after each step finished."""
-        for cb in self._callbacks:
-            cb.step_end(run_context)
-
-    def end(self, run_context):
-        """Called once after network training."""
-        for cb in self._callbacks:
-            cb.end(run_context)
-
-
-class SummaryStep(Callback):
-    """
-    The summary callback class.
-
-    Args:
-        summary (Object): Summary recode object.
-        flush_step (int): Number of interval steps to execute. Default: 10.
-    """
-    def __init__(self, summary, flush_step=10):
-        super(SummaryStep, self).__init__()
-        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
-            raise ValueError("`flush_step` should be int and greater than 0")
-        self._summary = summary
-        self._flush_step = flush_step
-
-    def __enter__(self):
-        self._summary.__enter__()
-        return self
-
-    def __exit__(self, *err):
-        return self._summary.__exit__(*err)
-
-    def step_end(self, run_context):
-        """
-        Save summary.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        if cb_params.cur_step_num % self._flush_step == 0:
-            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
-
-    @property
-    def summary_file_name(self):
-        return self._summary.full_file_name
-
-
-class _InternalCallbackParam(dict):
-    """Internal callback object's parameters."""
-
-    def __getattr__(self, key):
-        return self[key]
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-
-class RunContext:
-    """
-    Provides information about the model.
-
-    Run call being made. Provides information about original request to model function.
-    callback objects can stop the loop by calling request_stop() of run_context.
-
-    Args:
-        original_args (dict): Holding the related information of model etc.
-    """
-    def __init__(self, original_args):
-        if not isinstance(original_args, dict):
-            raise TypeError("The arg of RunContext should be dict type.")
-        self._original_args = original_args
-        self._stop_requested = False
-
-    def original_args(self):
-        """
-        Get the _original_args object.
-
-        Returns:
-            Dict, a object holding the original arguments of model.
-        """
-        return self._original_args
-
-    def request_stop(self):
-        """
-        Sets stop requested during training.
-
-        Callbacks can use this function to request stop of iterations.
-        model.train() checks whether this is called or not.
-        """
-        self._stop_requested = True
-
-    def get_stop_requested(self):
-        """
-        Returns whether a stop is requested or not.
-
-        Returns:
-            bool, if true, model.train() stops iterations.
-        """
-        return self._stop_requested
-
-
 class ModelCheckpoint(Callback):
     """
@@ -553,7 +209,7 @@ class ModelCheckpoint(Callback):
         self._config = config

         # get existing checkpoint files
-        self._manager = _CheckpointManager()
+        self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._graph_saved = False
@@ -633,7 +289,7 @@ class ModelCheckpoint(Callback):
             self._last_triggered_step = cb_params.cur_step_num

             if context.get_context("enable_ge"):
-                _set_cur_net(cb_params.train_network)
+                set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()

             _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
@@ -648,57 +304,66 @@ class ModelCheckpoint(Callback):
         return self._latest_ckpt_file_name


-class LossMonitor(Callback):
-    """
-    Monitor the loss in training.
-
-    If the loss is NAN or INF, it will terminate training.
-
-    Note:
-        If per_print_times is 0 do not print loss.
-
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
-
-    Raises:
-        ValueError: If print_step is not int or less than zero.
-    """
-    def __init__(self, per_print_times=1):
-        super(LossMonitor, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0.")
-        self._per_print_times = per_print_times
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        loss = cb_params.net_outputs
-
-        if isinstance(loss, (tuple, list)):
-            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
-                loss = loss[0]
-
-        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
-            loss = np.mean(loss.asnumpy())
-
-        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-
-        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
-            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
-                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
-        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
-            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
-
-
-class TimeMonitor(Callback):
-    """Time Monitor."""
-    def __init__(self, data_size):
-        super(TimeMonitor, self).__init__()
-        self.data_size = data_size
-
-    def epoch_begin(self, run_context):
-        self.epoch_time = time.time()
-
-    def epoch_end(self, run_context):
-        epoch_mseconds = (time.time() - self.epoch_time) * 1000
-        per_step_mseconds = epoch_mseconds / self.data_size
-        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
+class CheckpointManager:
+    """Manage checkpoint files according to train_config of checkpoint."""
+    def __init__(self):
+        self._ckpoint_filelist = []
+
+    @property
+    def ckpoint_filelist(self):
+        """Get all the related checkpoint files managed here."""
+        return self._ckpoint_filelist
+
+    @property
+    def ckpoint_num(self):
+        """Get the number of the related checkpoint files managed here."""
+        return len(self._ckpoint_filelist)
+
+    def update_ckpoint_filelist(self, directory, prefix):
+        """Update the checkpoint file list."""
+        self._ckpoint_filelist = []
+        files = os.listdir(directory)
+        for filename in files:
+            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
+                mid_name = filename[len(prefix):-5]
+                flag = True
+                for char in mid_name:
+                    if char.isalpha():
+                        flag = False
+                if flag:
+                    self._ckpoint_filelist.append(directory + '/' + filename)
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+            self._ckpoint_filelist.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def remove_oldest_ckpoint_file(self):
+        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
+        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
+        self.remove_ckpoint_file(ckpoint_files[0])
+
+    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
+        """Keep only one checkpoint file per `minutes` window; remove the other files generated in [last_time, cur_time]."""
+        movs = []
+        oldest_file = ''
+        oldest_time = cur_time
+        for ck_file in self._ckpoint_filelist:
+            modify_time = os.path.getmtime(ck_file)
+            if cur_time - modify_time < 60 * minutes:
+                movs.append(ck_file)
+
+                if modify_time < oldest_time:
+                    oldest_time = modify_time
+                    oldest_file = ck_file
+
+        for mv_file in movs:
+            if mv_file == oldest_file:
+                continue
+            self.remove_ckpoint_file(mv_file)
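In its new location the manager's retention logic is unchanged: keep_one_ckpoint_per_minutes scans every tracked file whose mtime falls in the trailing `minutes` window and removes all of them except the oldest one, so exactly one checkpoint survives per window. A hedged usage sketch (the directory and prefix are illustrative, not from the diff):

import time
from mindspore.train.callback._checkpoint import CheckpointManager

# Track files matching ./outputs/ckpt_net*.ckpt, then keep at most one
# checkpoint from the last 10 minutes.
manager = CheckpointManager()
manager.update_ckpoint_filelist("./outputs", "ckpt_net")
manager.keep_one_ckpoint_per_minutes(10, time.time())
print(manager.ckpoint_num, "checkpoint files remain")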
@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LossMonitor Callback class."""

import numpy as np
from mindspore.common.tensor import Tensor

from ._callback import Callback


class LossMonitor(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, it will terminate training.

    Note:
        If per_print_times is 0, the loss will not be printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.

    Raises:
        ValueError: If per_print_times is not an int or is less than zero.
    """

    def __init__(self, per_print_times=1):
        super(LossMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
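A typical way to attach the monitor after this split (net, loss_fn, optimizer, and dataset are placeholders, not part of the diff):

from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model

# Print the running loss every 100 steps; per_print_times=0 disables printing,
# and a NaN/Inf loss raises ValueError and stops training.
loss_cb = LossMonitor(per_print_times=100)
model = Model(net, loss_fn, optimizer)
model.train(10, dataset, callbacks=[loss_cb])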
@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryStep Callback class."""

from ._callback import Callback


class SummaryStep(Callback):
    """
    The summary callback class.

    Args:
        summary (Object): Summary record object.
        flush_step (int): Number of interval steps to execute. Default: 10.
    """

    def __init__(self, summary, flush_step=10):
        super(SummaryStep, self).__init__()
        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
            raise ValueError("`flush_step` should be int and greater than 0")
        self._summary = summary
        self._flush_step = flush_step

    def __enter__(self):
        self._summary.__enter__()
        return self

    def __exit__(self, *err):
        return self._summary.__exit__(*err)

    def step_end(self, run_context):
        """
        Save summary.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self._flush_step == 0:
            self._summary.record(cb_params.cur_step_num, cb_params.train_network)

    @property
    def summary_file_name(self):
        return self._summary.full_file_name
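Since step_end records only when cur_step_num is a multiple of flush_step, the callback samples summaries every flush_step steps rather than on every step, and it forwards its context management to the wrapped summary object. A sketch under the same placeholder assumptions (the SummaryRecord path and model/dataset are illustrative):

from mindspore.train.callback import SummaryStep
from mindspore.train.summary import SummaryRecord

# Record summary tensors every 50 steps into ./summary_dir.
with SummaryRecord("./summary_dir") as summary_record:
    summary_cb = SummaryStep(summary_record, flush_step=50)
    model.train(10, dataset, callbacks=[summary_cb])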
@@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TimeMonitor Callback class."""

import time

from ._callback import Callback


class TimeMonitor(Callback):
    """Time Monitor."""

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
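data_size is expected to be the number of steps per epoch, so epoch_end can report both the epoch wall time and the derived per-step time, each in milliseconds. A usage sketch with the same placeholder model and dataset as above:

from mindspore.train.callback import TimeMonitor

# Use the dataset's batch count so that per-step time = epoch time / data_size.
time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
model.train(10, dataset, callbacks=[time_cb])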
@@ -19,7 +19,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
-from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode

 from .dataset_helper import DatasetHelper
@@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
-    _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
-
+from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
+    _CallbackManager, Callback, CheckpointConfig
+from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
+from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist

 class Net(nn.Cell):
     """Net definition."""
|
|||
one_param['name'] = "conv1.weight"
|
||||
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
|
||||
parameter_list.append(one_param)
|
||||
_checkpoint_cb_for_save_op(parameter_list)
|
||||
checkpoint_cb_for_save_op(parameter_list)
|
||||
|
||||
|
||||
def test_checkpoint_cb_for_save_op_update_net():
|
||||
|
@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
|
|||
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
|
||||
parameter_list.append(one_param)
|
||||
net = Net()
|
||||
_set_cur_net(net)
|
||||
_checkpoint_cb_for_save_op(parameter_list)
|
||||
set_cur_net(net)
|
||||
checkpoint_cb_for_save_op(parameter_list)
|
||||
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
|
||||
|
||||
|
||||
|
|
|
@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.train.callback.callback import _CheckpointManager
+from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
     _exec_save_checkpoint, export, _save_graph
 from ..ut_filter import non_graph_engine