!2236 Refactor the callback module in an encapsulated way

Merge pull request !2236 from 李鸿章/callback
This commit is contained in:
mindspore-ci-bot 2020-06-18 12:09:19 +08:00 committed by Gitee
commit 0c3d96a98b
13 changed files with 504 additions and 418 deletions

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from model.dataset_helper import DatasetHelper

View File

@ -26,9 +26,9 @@
namespace mindspore {
namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
const char kSummary[] = "Summary";
const char kCheckPoint[] = "Save";
const int ONE_SHAPE = 1;

View File

@ -25,9 +25,9 @@
namespace mindspore {
namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
const char kSummary[] = "Summary";
const char kCheckPoint[] = "Save";
const int ONE_SHAPE = 1;

View File

@ -14,7 +14,15 @@
# ============================================================================
"""Callback related classes and functions."""
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
from ._callback import Callback
from ._callback import CallbackManager as _CallbackManager
from ._callback import InternalCallbackParam as _InternalCallbackParam
from ._callback import RunContext
from ._checkpoint import CheckpointConfig
from ._checkpoint import CheckpointManager as _CheckpointManager
from ._checkpoint import ModelCheckpoint
from ._loss_monitor import LossMonitor
from ._summary_step import SummaryStep
from ._time_monitor import TimeMonitor
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryStep", "CheckpointConfig", "RunContext"]
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]

View File

@ -0,0 +1,260 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""
from contextlib import ExitStack
from mindspore import log as logger
from mindspore.train.serialization import _fill_param_into_net
from mindspore.train.summary.summary_record import _cache_summary_tensor_data
_cur_net = None
def set_cur_net(net):
"""
Set current net for which we are using to save checkpoint.
Args:
net (Cell): train network
"""
global _cur_net
_cur_net = net
def checkpoint_cb_for_save_op(parameter_list):
"""
The checkpoint callback function for MindSpore.
Will be executed by checkpoint save op.
Args:
parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
Returns:
bool, true: means save checkpoint success.
"""
if _cur_net is None:
logger.warning("_cur_net is None. parameters are not updated.")
return False
logger.info("update parameters in the net.")
_fill_param_into_net(_cur_net, parameter_list)
set_cur_net(None)
return True
def summary_cb_for_save_op(summary_list):
"""
The summary callback function for MindSpore.
Will be executed by summary op.
Args:
summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
Returns:
bool, true: means save summary success.
"""
ret = _cache_summary_tensor_data(summary_list)
return ret
class Callback:
"""
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can leverage this mechanism to init and release resources automatically.
Callback function will execution some operating to the current step or epoch.
Examples:
>>> class Print_info(Callback):
>>> def step_end(self, run_context):
>>> cb_params = run_context.original_args()
>>> print(cb_params.cur_epoch_num)
>>> print(cb_params.cur_step_num)
>>>
>>> print_cb = Print_info()
>>> model.train(epoch, dataset, callbacks=print_cb)
"""
def __enter__(self):
"""Return the enter target."""
return self
def __exit__(self, *err):
"""Release resources here if have any."""
def begin(self, run_context):
"""
Called once before the network executing.
Args:
run_context (RunContext): Include some information of the model.
"""
def epoch_begin(self, run_context):
"""
Called before each epoch beginning.
Args:
run_context (RunContext): Include some information of the model.
"""
def epoch_end(self, run_context):
"""
Called after each epoch finished.
Args:
run_context (RunContext): Include some information of the model.
"""
def step_begin(self, run_context):
"""
Called before each epoch beginning.
Args:
run_context (RunContext): Include some information of the model.
"""
def step_end(self, run_context):
"""
Called after each step finished.
Args:
run_context (RunContext): Include some information of the model.
"""
def end(self, run_context):
"""
Called once after network training.
Args:
run_context (RunContext): Include some information of the model.
"""
class CallbackManager(Callback):
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
"""
def __init__(self, callbacks):
self._callbacks, self._stack = [], None
if isinstance(callbacks, Callback):
self._callbacks.append(callbacks)
elif callbacks is not None:
for cb in callbacks:
if not isinstance(cb, Callback):
raise TypeError("%r is not an instance of %r" % (cb, Callback))
self._callbacks.append(cb)
def __enter__(self):
if self._stack is None:
self._stack = ExitStack().__enter__()
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
return self
def __exit__(self, *err):
return self._stack.__exit__(*err)
def begin(self, run_context):
"""Called once before network training."""
for cb in self._callbacks:
cb.begin(run_context)
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.epoch_begin(run_context)
def epoch_end(self, run_context):
"""Called after each epoch finished."""
for cb in self._callbacks:
cb.epoch_end(run_context)
def step_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.step_begin(run_context)
def step_end(self, run_context):
"""Called after each step finished."""
for cb in self._callbacks:
cb.step_end(run_context)
def end(self, run_context):
"""Called once after network training."""
for cb in self._callbacks:
cb.end(run_context)
class InternalCallbackParam(dict):
"""Internal callback object's parameters."""
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
self[key] = value
class RunContext:
"""
Provides information about the model.
Run call being made. Provides information about original request to model function.
callback objects can stop the loop by calling request_stop() of run_context.
Args:
original_args (dict): Holding the related information of model etc.
"""
def __init__(self, original_args):
if not isinstance(original_args, dict):
raise TypeError("The arg of RunContext should be dict type.")
self._original_args = original_args
self._stop_requested = False
def original_args(self):
"""
Get the _original_args object.
Returns:
Dict, a object holding the original arguments of model.
"""
return self._original_args
def request_stop(self):
"""
Sets stop requested during training.
Callbacks can use this function to request stop of iterations.
model.train() checks whether this is called or not.
"""
self._stop_requested = True
def get_stop_requested(self):
"""
Returns whether a stop is requested or not.
Returns:
bool, if true, model.train() stops iterations.
"""
return self._stop_requested

View File

@ -12,93 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""
"""Checkpoint related classes and functions."""
import os
import stat
import shutil
import stat
import time
from contextlib import ExitStack
import numpy as np
import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
from mindspore.train._utils import _make_directory
from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative, check_bool
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import _cache_summary_tensor_data
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net
_cur_dir = os.getcwd()
_cur_net = None
_save_dir = _cur_dir
class _CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint."""
def __init__(self):
self._ckpoint_filelist = []
@property
def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here."""
return self._ckpoint_filelist
@property
def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here."""
return len(self._ckpoint_filelist)
def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list."""
self._ckpoint_filelist = []
files = os.listdir(directory)
for filename in files:
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
mid_name = filename[len(prefix):-5]
flag = True
for char in mid_name:
if char.isalpha():
flag = False
if flag:
self._ckpoint_filelist.append(directory + '/' + filename)
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
self._ckpoint_filelist.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
self.remove_ckpoint_file(ckpoint_files[0])
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
movs = []
oldest_file = ''
oldest_time = cur_time
for ck_file in self._ckpoint_filelist:
modify_time = os.path.getmtime(ck_file)
if cur_time - modify_time < 60 * minutes:
movs.append(ck_file)
if modify_time < oldest_time:
oldest_time = modify_time
oldest_file = ck_file
for mv_file in movs:
if mv_file == oldest_file:
continue
self.remove_ckpoint_file(mv_file)
def _check_file_name_prefix(file_name_prefix):
"""
@ -234,282 +166,6 @@ class CheckpointConfig:
return checkpoint_policy
def _set_cur_net(net):
"""
Set current net for which we are using to save checkpoint.
Args:
net (Cell): train network
"""
global _cur_net
_cur_net = net
def _checkpoint_cb_for_save_op(parameter_list):
"""
The checkpoint callback function for MindSpore.
Will be executed by checkpoint save op.
Args:
parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
Returns:
bool, true: means save checkpoint success.
"""
if _cur_net is None:
logger.warning("_cur_net is None. parameters are not updated.")
return False
logger.info("update parameters in the net.")
_fill_param_into_net(_cur_net, parameter_list)
_set_cur_net(None)
return True
def _summary_cb_for_save_op(summary_list):
"""
The summary callback function for MindSpore.
Will be executed by summary op.
Args:
summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
Returns:
bool, true: means save summary success.
"""
ret = _cache_summary_tensor_data(summary_list)
return ret
class Callback:
"""
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can leverage this mechanism to init and release resources automatically.
Callback function will execution some operating to the current step or epoch.
Examples:
>>> class Print_info(Callback):
>>> def step_end(self, run_context):
>>> cb_params = run_context.original_args()
>>> print(cb_params.cur_epoch_num)
>>> print(cb_params.cur_step_num)
>>>
>>> print_cb = Print_info()
>>> model.train(epoch, dataset, callbacks=print_cb)
"""
def __enter__(self):
"""Return the enter target."""
return self
def __exit__(self, *err):
"""Release resources here if have any."""
def begin(self, run_context):
"""
Called once before the network executing.
Args:
run_context (RunContext): Include some information of the model.
"""
def epoch_begin(self, run_context):
"""
Called before each epoch beginning.
Args:
run_context (RunContext): Include some information of the model.
"""
def epoch_end(self, run_context):
"""
Called after each epoch finished.
Args:
run_context (RunContext): Include some information of the model.
"""
def step_begin(self, run_context):
"""
Called before each epoch beginning.
Args:
run_context (RunContext): Include some information of the model.
"""
def step_end(self, run_context):
"""
Called after each step finished.
Args:
run_context (RunContext): Include some information of the model.
"""
def end(self, run_context):
"""
Called once after network training.
Args:
run_context (RunContext): Include some information of the model.
"""
class _CallbackManager(Callback):
"""
Sequential execution of callback functions.
Execute Callback functions at certain points.
Args:
callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
"""
def __init__(self, callbacks):
self._callbacks, self._stack = [], None
if isinstance(callbacks, Callback):
self._callbacks.append(callbacks)
elif callbacks is not None:
for cb in callbacks:
if not isinstance(cb, Callback):
raise TypeError("%r is not an instance of %r" % (cb, Callback))
self._callbacks.append(cb)
def __enter__(self):
if self._stack is None:
self._stack = ExitStack().__enter__()
self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
return self
def __exit__(self, *err):
return self._stack.__exit__(*err)
def begin(self, run_context):
"""Called once before network training."""
for cb in self._callbacks:
cb.begin(run_context)
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.epoch_begin(run_context)
def epoch_end(self, run_context):
"""Called after each epoch finished."""
for cb in self._callbacks:
cb.epoch_end(run_context)
def step_begin(self, run_context):
"""Called before each epoch begin."""
for cb in self._callbacks:
cb.step_begin(run_context)
def step_end(self, run_context):
"""Called after each step finished."""
for cb in self._callbacks:
cb.step_end(run_context)
def end(self, run_context):
"""Called once after network training."""
for cb in self._callbacks:
cb.end(run_context)
class SummaryStep(Callback):
"""
The summary callback class.
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def __init__(self, summary, flush_step=10):
super(SummaryStep, self).__init__()
if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
raise ValueError("`flush_step` should be int and greater than 0")
self._summary = summary
self._flush_step = flush_step
def __enter__(self):
self._summary.__enter__()
return self
def __exit__(self, *err):
return self._summary.__exit__(*err)
def step_end(self, run_context):
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self._flush_step == 0:
self._summary.record(cb_params.cur_step_num, cb_params.train_network)
@property
def summary_file_name(self):
return self._summary.full_file_name
class _InternalCallbackParam(dict):
"""Internal callback object's parameters."""
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
self[key] = value
class RunContext:
"""
Provides information about the model.
Run call being made. Provides information about original request to model function.
callback objects can stop the loop by calling request_stop() of run_context.
Args:
original_args (dict): Holding the related information of model etc.
"""
def __init__(self, original_args):
if not isinstance(original_args, dict):
raise TypeError("The arg of RunContext should be dict type.")
self._original_args = original_args
self._stop_requested = False
def original_args(self):
"""
Get the _original_args object.
Returns:
Dict, a object holding the original arguments of model.
"""
return self._original_args
def request_stop(self):
"""
Sets stop requested during training.
Callbacks can use this function to request stop of iterations.
model.train() checks whether this is called or not.
"""
self._stop_requested = True
def get_stop_requested(self):
"""
Returns whether a stop is requested or not.
Returns:
bool, if true, model.train() stops iterations.
"""
return self._stop_requested
class ModelCheckpoint(Callback):
"""
@ -553,7 +209,7 @@ class ModelCheckpoint(Callback):
self._config = config
# get existing checkpoint files
self._manager = _CheckpointManager()
self._manager = CheckpointManager()
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
self._graph_saved = False
@ -633,7 +289,7 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = cb_params.cur_step_num
if context.get_context("enable_ge"):
_set_cur_net(cb_params.train_network)
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
@ -648,57 +304,66 @@ class ModelCheckpoint(Callback):
return self._latest_ckpt_file_name
class LossMonitor(Callback):
"""
Monitor the loss in training.
class CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint."""
def __init__(self):
self._ckpoint_filelist = []
If the loss is NAN or INF, it will terminate training.
@property
def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here."""
return self._ckpoint_filelist
Note:
If per_print_times is 0 do not print loss.
@property
def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here."""
return len(self._ckpoint_filelist)
Args:
per_print_times (int): Print loss every times. Default: 1.
def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list."""
self._ckpoint_filelist = []
files = os.listdir(directory)
for filename in files:
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
mid_name = filename[len(prefix):-5]
flag = True
for char in mid_name:
if char.isalpha():
flag = False
if flag:
self._ckpoint_filelist.append(directory + '/' + filename)
Raises:
ValueError: If print_step is not int or less than zero.
"""
def __init__(self, per_print_times=1):
super(LossMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
self._ckpoint_filelist.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs
def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
self.remove_ckpoint_file(ckpoint_files[0])
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
movs = []
oldest_file = ''
oldest_time = cur_time
for ck_file in self._ckpoint_filelist:
modify_time = os.path.getmtime(ck_file)
if cur_time - modify_time < 60 * minutes:
movs.append(ck_file)
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
if modify_time < oldest_time:
oldest_time = modify_time
oldest_file = ck_file
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
.format(cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
class TimeMonitor(Callback):
"""Time Monitor."""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
for mv_file in movs:
if mv_file == oldest_file:
continue
self.remove_ckpoint_file(mv_file)

View File

@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LossMonitor Callback class."""
import numpy as np
from mindspore.common.tensor import Tensor
from ._callback import Callback
class LossMonitor(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, it will terminate training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
Raises:
ValueError: If print_step is not int or less than zero.
"""
def __init__(self, per_print_times=1):
super(LossMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryStep Callback class."""
from ._callback import Callback
class SummaryStep(Callback):
"""
The summary callback class.
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def __init__(self, summary, flush_step=10):
super(SummaryStep, self).__init__()
if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
raise ValueError("`flush_step` should be int and greater than 0")
self._summary = summary
self._flush_step = flush_step
def __enter__(self):
self._summary.__enter__()
return self
def __exit__(self, *err):
return self._summary.__exit__(*err)
def step_end(self, run_context):
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self._flush_step == 0:
self._summary.record(cb_params.cur_step_num, cb_params.train_network)
@property
def summary_file_name(self):
return self._summary.full_file_name

View File

@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TimeMonitor Callback class."""
import time
from ._callback import Callback
class TimeMonitor(Callback):
"""Time Monitor."""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)

View File

@ -19,7 +19,7 @@ from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper

View File

@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
_checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
_CallbackManager, Callback, CheckpointConfig, _set_cur_net
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
_CallbackManager, Callback, CheckpointConfig
from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
class Net(nn.Cell):
"""Net definition."""
@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op():
one_param['name'] = "conv1.weight"
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
parameter_list.append(one_param)
_checkpoint_cb_for_save_op(parameter_list)
checkpoint_cb_for_save_op(parameter_list)
def test_checkpoint_cb_for_save_op_update_net():
@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
parameter_list.append(one_param)
net = Net()
_set_cur_net(net)
_checkpoint_cb_for_save_op(parameter_list)
set_cur_net(net)
checkpoint_cb_for_save_op(parameter_list)
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1

View File

@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback.callback import _CheckpointManager
from mindspore.train.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
_exec_save_checkpoint, export, _save_graph
from ..ut_filter import non_graph_engine