callback module is encapsulated
This commit is contained in:
parent fc74606211
commit ecc459158e
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 
 from model.dataset_helper import DatasetHelper
@@ -26,9 +26,9 @@
 
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
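These constants name the Python module and function that the C++ runtime resolves when the Save/Summary op fires, so they must track the Python-side rename from `mindspore.train.callback.callback` (with private `_checkpoint_cb_for_save_op` / `_summary_cb_for_save_op`) to `mindspore.train.callback._callback` (with `checkpoint_cb_for_save_op` / `summary_cb_for_save_op`). A minimal sketch of the equivalent lookup on the Python side, assuming only that the module and function named in the constants above exist:

    # Sketch: resolve the callback the same way the C++ side does,
    # by module path and function name (taken from the constants above).
    import importlib

    mod = importlib.import_module("mindspore.train.callback._callback")
    checkpoint_cb = getattr(mod, "checkpoint_cb_for_save_op")
    # checkpoint_cb(parameter_list) is what the Save op would then invoke.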
@@ -25,9 +25,9 @@
 
 namespace mindspore {
 namespace callbacks {
-const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
-const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
-const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
+const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
+const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
+const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
 const char kSummary[] = "Summary";
 const char kCheckPoint[] = "Save";
 const int ONE_SHAPE = 1;
@@ -14,7 +14,15 @@
 # ============================================================================
 """Callback related classes and functions."""
 
-from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
+from ._callback import Callback
+from ._callback import CallbackManager as _CallbackManager
+from ._callback import InternalCallbackParam as _InternalCallbackParam
+from ._callback import RunContext
+from ._checkpoint import CheckpointConfig
+from ._checkpoint import CheckpointManager as _CheckpointManager
+from ._checkpoint import ModelCheckpoint
+from ._loss_monitor import LossMonitor
+from ._summary_step import SummaryStep
+from ._time_monitor import TimeMonitor
 
-__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
-           "SummaryStep", "CheckpointConfig", "RunContext"]
+__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
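The package `__init__` keeps the public surface unchanged while the implementation moves into private modules: `__all__` still exports the same seven names, and the internal helpers are re-exported under their old underscore aliases. A quick sanity sketch, assuming an installed build of this revision:

    # User-facing imports are unaffected by the reorganization.
    from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint

    # Internal callers keep working through the aliases in __init__.py:
    # CallbackManager is re-exported as _CallbackManager, and
    # InternalCallbackParam as _InternalCallbackParam.
    from mindspore.train.callback import _CallbackManager, _InternalCallbackParam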
@@ -0,0 +1,260 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Callback related classes and functions."""
+
+from contextlib import ExitStack
+
+from mindspore import log as logger
+from mindspore.train.serialization import _fill_param_into_net
+from mindspore.train.summary.summary_record import _cache_summary_tensor_data
+
+_cur_net = None
+
+
+def set_cur_net(net):
+    """
+    Set the current network, which is used when saving a checkpoint.
+
+    Args:
+        net (Cell): train network
+    """
+    global _cur_net
+    _cur_net = net
+
+
+def checkpoint_cb_for_save_op(parameter_list):
+    """
+    The checkpoint callback function for MindSpore.
+
+    Will be executed by the checkpoint save op.
+
+    Args:
+        parameter_list (list): Format is like [{"name": name, "data": value}, ...] and the value type is Tensor.
+
+    Returns:
+        bool, True means the checkpoint was saved successfully.
+    """
+    if _cur_net is None:
+        logger.warning("_cur_net is None. parameters are not updated.")
+        return False
+
+    logger.info("update parameters in the net.")
+    _fill_param_into_net(_cur_net, parameter_list)
+    set_cur_net(None)
+    return True
+
+
+def summary_cb_for_save_op(summary_list):
+    """
+    The summary callback function for MindSpore.
+
+    Will be executed by the summary op.
+
+    Args:
+        summary_list (list): Format is like [{"name": tag_name, "data": tensor}, ...] and the value is Scalar/Tensor.
+
+    Returns:
+        bool, True means the summary was saved successfully.
+    """
+    ret = _cache_summary_tensor_data(summary_list)
+    return ret
+
+
+class Callback:
+    """
+    Abstract base class used to build a callback class. Callbacks are context managers
+    which will be entered and exited when passed into the Model.
+    You can leverage this mechanism to initialize and release resources automatically.
+
+    Each callback function performs some operation for the current step or epoch.
+
+    Examples:
+        >>> class Print_info(Callback):
+        >>>     def step_end(self, run_context):
+        >>>         cb_params = run_context.original_args()
+        >>>         print(cb_params.cur_epoch_num)
+        >>>         print(cb_params.cur_step_num)
+        >>>
+        >>> print_cb = Print_info()
+        >>> model.train(epoch, dataset, callbacks=print_cb)
+    """
+
+    def __enter__(self):
+        """Return the enter target."""
+        return self
+
+    def __exit__(self, *err):
+        """Release resources here if there are any."""
+
+    def begin(self, run_context):
+        """
+        Called once before the network starts executing.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+    def epoch_begin(self, run_context):
+        """
+        Called before each epoch begins.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+    def epoch_end(self, run_context):
+        """
+        Called after each epoch has finished.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+    def step_begin(self, run_context):
+        """
+        Called before each step begins.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+    def step_end(self, run_context):
+        """
+        Called after each step has finished.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+    def end(self, run_context):
+        """
+        Called once after network training.
+
+        Args:
+            run_context (RunContext): Includes some information of the model.
+        """
+
+
+class CallbackManager(Callback):
+    """
+    Sequential execution of callback functions.
+
+    Execute Callback functions at certain points.
+
+    Args:
+        callbacks (Optional[list[Callback], Callback]): None, a single Callback, or a list of Callbacks.
+    """
+
+    def __init__(self, callbacks):
+        self._callbacks, self._stack = [], None
+        if isinstance(callbacks, Callback):
+            self._callbacks.append(callbacks)
+        elif callbacks is not None:
+            for cb in callbacks:
+                if not isinstance(cb, Callback):
+                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
+                self._callbacks.append(cb)
+
+    def __enter__(self):
+        if self._stack is None:
+            self._stack = ExitStack().__enter__()
+            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
+        return self
+
+    def __exit__(self, *err):
+        return self._stack.__exit__(*err)
+
+    def begin(self, run_context):
+        """Called once before network training."""
+        for cb in self._callbacks:
+            cb.begin(run_context)
+
+    def epoch_begin(self, run_context):
+        """Called before each epoch begins."""
+        for cb in self._callbacks:
+            cb.epoch_begin(run_context)
+
+    def epoch_end(self, run_context):
+        """Called after each epoch has finished."""
+        for cb in self._callbacks:
+            cb.epoch_end(run_context)
+
+    def step_begin(self, run_context):
+        """Called before each step begins."""
+        for cb in self._callbacks:
+            cb.step_begin(run_context)
+
+    def step_end(self, run_context):
+        """Called after each step has finished."""
+        for cb in self._callbacks:
+            cb.step_end(run_context)
+
+    def end(self, run_context):
+        """Called once after network training."""
+        for cb in self._callbacks:
+            cb.end(run_context)
+
+
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+
+class RunContext:
+    """
+    Provides information about the model.
+
+    Provides information about the original request to the model function.
+    Callback objects can stop the loop by calling request_stop() of run_context.
+
+    Args:
+        original_args (dict): Holds the related information of the model.
+    """
+    def __init__(self, original_args):
+        if not isinstance(original_args, dict):
+            raise TypeError("The arg of RunContext should be dict type.")
+        self._original_args = original_args
+        self._stop_requested = False
+
+    def original_args(self):
+        """
+        Get the _original_args object.
+
+        Returns:
+            Dict, an object holding the original arguments of the model.
+        """
+        return self._original_args
+
+    def request_stop(self):
+        """
+        Set the stop-requested flag during training.
+
+        Callbacks can use this function to request that iteration stop.
+        model.train() checks whether this has been called.
+        """
+        self._stop_requested = True
+
+    def get_stop_requested(self):
+        """
+        Return whether a stop has been requested.
+
+        Returns:
+            bool, if True, model.train() stops iterating.
+        """
+        return self._stop_requested
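Taken together, the new `_callback` module is the training loop's control surface: `Callback` subclasses hook the begin/epoch/step/end events, `CallbackManager` enters them all through a single `ExitStack`, and `RunContext.request_stop()` lets any callback end training early. A minimal sketch of an early-stopping callback built only from the pieces above (the loss threshold and the hand-driven loop are illustrative, not part of this commit):

    # Sketch: early stopping via RunContext, using the classes defined above.
    from mindspore.train.callback import Callback, RunContext, \
        _CallbackManager, _InternalCallbackParam

    class StopOnSmallLoss(Callback):
        """Request a stop once the reported loss drops below a threshold."""
        def __init__(self, threshold=0.01):  # illustrative threshold
            self._threshold = threshold

        def step_end(self, run_context):
            cb_params = run_context.original_args()
            if float(cb_params.net_outputs) < self._threshold:
                run_context.request_stop()  # model.train() polls get_stop_requested()

    # Driving it by hand, in miniature, the way the training loop would:
    cb_params = _InternalCallbackParam()
    cb_params.net_outputs = 0.005       # attribute-style access on the dict subclass
    run_context = RunContext(cb_params)
    with _CallbackManager(StopOnSmallLoss()) as cb_manager:
        cb_manager.step_end(run_context)
    assert run_context.get_stop_requested()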
@@ -12,93 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Callback related classes and functions."""
+"""Checkpoint related classes and functions."""
 
 import os
-import stat
 import shutil
+import stat
 import time
-from contextlib import ExitStack
-
-import numpy as np
 
 import mindspore.context as context
-from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
-from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative, check_bool
-from mindspore.common.tensor import Tensor
-from mindspore.train.summary.summary_record import _cache_summary_tensor_data
+from mindspore._checkparam import check_bool, check_int_non_negative
+from mindspore.train._utils import _make_directory
+from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
+
+from ._callback import Callback, set_cur_net
 
 _cur_dir = os.getcwd()
-_cur_net = None
 _save_dir = _cur_dir
 
 
-class _CheckpointManager:
-    """Manage checkpoint files according to train_config of checkpoint."""
-    def __init__(self):
-        self._ckpoint_filelist = []
-
-    @property
-    def ckpoint_filelist(self):
-        """Get all the related checkpoint files managed here."""
-        return self._ckpoint_filelist
-
-    @property
-    def ckpoint_num(self):
-        """Get the number of the related checkpoint files managed here."""
-        return len(self._ckpoint_filelist)
-
-    def update_ckpoint_filelist(self, directory, prefix):
-        """Update the checkpoint file list."""
-        self._ckpoint_filelist = []
-        files = os.listdir(directory)
-        for filename in files:
-            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
-                mid_name = filename[len(prefix):-5]
-                flag = True
-                for char in mid_name:
-                    if char.isalpha():
-                        flag = False
-                if flag:
-                    self._ckpoint_filelist.append(directory + '/' + filename)
-
-    def remove_ckpoint_file(self, file_name):
-        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
-        try:
-            os.chmod(file_name, stat.S_IWRITE)
-            os.remove(file_name)
-            self._ckpoint_filelist.remove(file_name)
-        except OSError:
-            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
-        except ValueError:
-            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
-
-    def remove_oldest_ckpoint_file(self):
-        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
-        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
-        self.remove_ckpoint_file(ckpoint_files[0])
-
-    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
-        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
-        movs = []
-        oldest_file = ''
-        oldest_time = cur_time
-        for ck_file in self._ckpoint_filelist:
-            modify_time = os.path.getmtime(ck_file)
-            if cur_time - modify_time < 60 * minutes:
-                movs.append(ck_file)
-
-                if modify_time < oldest_time:
-                    oldest_time = modify_time
-                    oldest_file = ck_file
-
-        for mv_file in movs:
-            if mv_file == oldest_file:
-                continue
-            self.remove_ckpoint_file(mv_file)
-
-
 def _check_file_name_prefix(file_name_prefix):
     """
@@ -234,282 +166,6 @@ class CheckpointConfig:
         return checkpoint_policy
 
 
-def _set_cur_net(net):
-    """
-    Set current net for which we are using to save checkpoint.
-
-    Args:
-        net (Cell): train network
-    """
-    global _cur_net
-    _cur_net = net
-
-
-def _checkpoint_cb_for_save_op(parameter_list):
-    """
-    The checkpoint callback function for MindSpore.
-
-    Will be executed by checkpoint save op.
-
-    Args:
-        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
-
-    Returns:
-        bool, true: means save checkpoint success.
-    """
-    if _cur_net is None:
-        logger.warning("_cur_net is None. parameters are not updated.")
-        return False
-
-    logger.info("update parameters in the net.")
-    _fill_param_into_net(_cur_net, parameter_list)
-    _set_cur_net(None)
-    return True
-
-
-def _summary_cb_for_save_op(summary_list):
-    """
-    The summary callback function for MindSpore.
-
-    Will be executed by summary op.
-
-    Args:
-        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
-
-    Returns:
-        bool, true: means save summary success.
-    """
-    ret = _cache_summary_tensor_data(summary_list)
-    return ret
-
-
-class Callback:
-    """
-    Abstract base class used to build a callback class. Callbacks are context managers
-    which will be entered and exited when passing into the Model.
-    You can leverage this mechanism to init and release resources automatically.
-
-    Callback function will execution some operating to the current step or epoch.
-
-    Examples:
-        >>> class Print_info(Callback):
-        >>>     def step_end(self, run_context):
-        >>>         cb_params = run_context.original_args()
-        >>>         print(cb_params.cur_epoch_num)
-        >>>         print(cb_params.cur_step_num)
-        >>>
-        >>> print_cb = Print_info()
-        >>> model.train(epoch, dataset, callbacks=print_cb)
-    """
-
-    def __enter__(self):
-        """Return the enter target."""
-        return self
-
-    def __exit__(self, *err):
-        """Release resources here if have any."""
-
-    def begin(self, run_context):
-        """
-        Called once before the network executing.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def epoch_end(self, run_context):
-        """
-        Called after each epoch finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_begin(self, run_context):
-        """
-        Called before each epoch beginning.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def step_end(self, run_context):
-        """
-        Called after each step finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-    def end(self, run_context):
-        """
-        Called once after network training.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-
-
-class _CallbackManager(Callback):
-    """
-    Sequential execution of callback functions.
-
-    Execute Callback functions at certain points.
-
-    Args:
-        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.
-    """
-
-    def __init__(self, callbacks):
-        self._callbacks, self._stack = [], None
-        if isinstance(callbacks, Callback):
-            self._callbacks.append(callbacks)
-        elif callbacks is not None:
-            for cb in callbacks:
-                if not isinstance(cb, Callback):
-                    raise TypeError("%r is not an instance of %r" % (cb, Callback))
-                self._callbacks.append(cb)
-
-    def __enter__(self):
-        if self._stack is None:
-            self._stack = ExitStack().__enter__()
-            self._callbacks = [self._stack.enter_context(cb) for cb in self._callbacks]
-        return self
-
-    def __exit__(self, *err):
-        return self._stack.__exit__(*err)
-
-    def begin(self, run_context):
-        """Called once before network training."""
-        for cb in self._callbacks:
-            cb.begin(run_context)
-
-    def epoch_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.epoch_begin(run_context)
-
-    def epoch_end(self, run_context):
-        """Called after each epoch finished."""
-        for cb in self._callbacks:
-            cb.epoch_end(run_context)
-
-    def step_begin(self, run_context):
-        """Called before each epoch begin."""
-        for cb in self._callbacks:
-            cb.step_begin(run_context)
-
-    def step_end(self, run_context):
-        """Called after each step finished."""
-        for cb in self._callbacks:
-            cb.step_end(run_context)
-
-    def end(self, run_context):
-        """Called once after network training."""
-        for cb in self._callbacks:
-            cb.end(run_context)
-
-
-class SummaryStep(Callback):
-    """
-    The summary callback class.
-
-    Args:
-        summary (Object): Summary recode object.
-        flush_step (int): Number of interval steps to execute. Default: 10.
-    """
-    def __init__(self, summary, flush_step=10):
-        super(SummaryStep, self).__init__()
-        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
-            raise ValueError("`flush_step` should be int and greater than 0")
-        self._summary = summary
-        self._flush_step = flush_step
-
-    def __enter__(self):
-        self._summary.__enter__()
-        return self
-
-    def __exit__(self, *err):
-        return self._summary.__exit__(*err)
-
-    def step_end(self, run_context):
-        """
-        Save summary.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        if cb_params.cur_step_num % self._flush_step == 0:
-            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
-
-    @property
-    def summary_file_name(self):
-        return self._summary.full_file_name
-
-
-class _InternalCallbackParam(dict):
-    """Internal callback object's parameters."""
-
-    def __getattr__(self, key):
-        return self[key]
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
-
-class RunContext:
-    """
-    Provides information about the model.
-
-    Run call being made. Provides information about original request to model function.
-    callback objects can stop the loop by calling request_stop() of run_context.
-
-    Args:
-        original_args (dict): Holding the related information of model etc.
-    """
-    def __init__(self, original_args):
-        if not isinstance(original_args, dict):
-            raise TypeError("The arg of RunContext should be dict type.")
-        self._original_args = original_args
-        self._stop_requested = False
-
-    def original_args(self):
-        """
-        Get the _original_args object.
-
-        Returns:
-            Dict, a object holding the original arguments of model.
-        """
-        return self._original_args
-
-    def request_stop(self):
-        """
-        Sets stop requested during training.
-
-        Callbacks can use this function to request stop of iterations.
-        model.train() checks whether this is called or not.
-        """
-        self._stop_requested = True
-
-    def get_stop_requested(self):
-        """
-        Returns whether a stop is requested or not.
-
-        Returns:
-            bool, if true, model.train() stops iterations.
-        """
-        return self._stop_requested
-
-
 class ModelCheckpoint(Callback):
     """
@@ -553,7 +209,7 @@ class ModelCheckpoint(Callback):
         self._config = config
 
         # get existing checkpoint files
-        self._manager = _CheckpointManager()
+        self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._graph_saved = False
@@ -633,7 +289,7 @@ class ModelCheckpoint(Callback):
             self._last_triggered_step = cb_params.cur_step_num
 
             if context.get_context("enable_ge"):
-                _set_cur_net(cb_params.train_network)
+                set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
             _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
@@ -648,57 +304,66 @@ class ModelCheckpoint(Callback):
         return self._latest_ckpt_file_name
 
 
-class LossMonitor(Callback):
-    """
-    Monitor the loss in training.
-
-    If the loss is NAN or INF, it will terminate training.
-
-    Note:
-        If per_print_times is 0 do not print loss.
-
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
-
-    Raises:
-        ValueError: If print_step is not int or less than zero.
-    """
-    def __init__(self, per_print_times=1):
-        super(LossMonitor, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0.")
-        self._per_print_times = per_print_times
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        loss = cb_params.net_outputs
-
-        if isinstance(loss, (tuple, list)):
-            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
-                loss = loss[0]
-
-        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
-            loss = np.mean(loss.asnumpy())
-
-        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-
-        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
-            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
-                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
-        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
-            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
-
-
-class TimeMonitor(Callback):
-    """Time Monitor."""
-    def __init__(self, data_size):
-        super(TimeMonitor, self).__init__()
-        self.data_size = data_size
-
-    def epoch_begin(self, run_context):
-        self.epoch_time = time.time()
-
-    def epoch_end(self, run_context):
-        epoch_mseconds = (time.time() - self.epoch_time) * 1000
-        per_step_mseconds = epoch_mseconds / self.data_size
-        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
+class CheckpointManager:
+    """Manage checkpoint files according to train_config of checkpoint."""
+    def __init__(self):
+        self._ckpoint_filelist = []
+
+    @property
+    def ckpoint_filelist(self):
+        """Get all the related checkpoint files managed here."""
+        return self._ckpoint_filelist
+
+    @property
+    def ckpoint_num(self):
+        """Get the number of the related checkpoint files managed here."""
+        return len(self._ckpoint_filelist)
+
+    def update_ckpoint_filelist(self, directory, prefix):
+        """Update the checkpoint file list."""
+        self._ckpoint_filelist = []
+        files = os.listdir(directory)
+        for filename in files:
+            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
+                mid_name = filename[len(prefix):-5]
+                flag = True
+                for char in mid_name:
+                    if char.isalpha():
+                        flag = False
+                if flag:
+                    self._ckpoint_filelist.append(directory + '/' + filename)
+
+    def remove_ckpoint_file(self, file_name):
+        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        try:
+            os.chmod(file_name, stat.S_IWRITE)
+            os.remove(file_name)
+            self._ckpoint_filelist.remove(file_name)
+        except OSError:
+            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
+        except ValueError:
+            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
+
+    def remove_oldest_ckpoint_file(self):
+        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
+        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
+        self.remove_ckpoint_file(ckpoint_files[0])
+
+    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
+        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
+        movs = []
+        oldest_file = ''
+        oldest_time = cur_time
+        for ck_file in self._ckpoint_filelist:
+            modify_time = os.path.getmtime(ck_file)
+            if cur_time - modify_time < 60 * minutes:
+                movs.append(ck_file)
+
+                if modify_time < oldest_time:
+                    oldest_time = modify_time
+                    oldest_file = ck_file
+
+        for mv_file in movs:
+            if mv_file == oldest_file:
+                continue
+            self.remove_ckpoint_file(mv_file)
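`CheckpointManager` (moved here from the top of the file, and renamed from `_CheckpointManager` now that the whole module is private) only tracks and prunes `.ckpt` files; the policy decisions stay in `CheckpointConfig` and `ModelCheckpoint`. A hedged sketch of the file-management side alone, using a temporary directory instead of a real training run:

    # Sketch: exercising CheckpointManager's pruning by itself.
    # Assumes this revision's module layout; no training run is involved.
    import os, tempfile
    from mindspore.train.callback import _CheckpointManager

    tmp = tempfile.mkdtemp()
    for step in (100, 200, 300):
        open(os.path.join(tmp, "net-1_%d.ckpt" % step), "w").close()

    manager = _CheckpointManager()
    manager.update_ckpoint_filelist(tmp, "net")
    print(manager.ckpoint_num)           # 3: prefix matched, suffix has no letters
    manager.remove_oldest_ckpoint_file()
    print(manager.ckpoint_num)           # 2: the oldest file by mtime was deleted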
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LossMonitor Callback class."""
+
+import numpy as np
+from mindspore.common.tensor import Tensor
+
+from ._callback import Callback
+
+
+class LossMonitor(Callback):
+    """
+    Monitor the loss in training.
+
+    If the loss is NAN or INF, it will terminate training.
+
+    Note:
+        If per_print_times is 0, the loss is not printed.
+
+    Args:
+        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
+
+    Raises:
+        ValueError: If per_print_times is not an int or is less than zero.
+    """
+
+    def __init__(self, per_print_times=1):
+        super(LossMonitor, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("print_step must be int and >= 0.")
+        self._per_print_times = per_print_times
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = np.mean(loss.asnumpy())
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
+            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
+                cb_params.cur_epoch_num, cur_step_in_epoch))
+        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
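The monitor first normalizes `net_outputs` (a tuple/list or a `Tensor`) down to a scalar, which is why it can act both as a progress printer and as a NaN/Inf tripwire. Typical usage is unchanged by the move into `_loss_monitor.py` (a sketch, assuming `model` and `dataset` are a constructed `Model` and a dataset as in the docstring example earlier):

    # Sketch: attaching LossMonitor to a training run.
    from mindspore.train.callback import LossMonitor

    loss_cb = LossMonitor(per_print_times=100)  # print every 100 steps; 0 disables printing
    model.train(10, dataset, callbacks=[loss_cb])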
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SummaryStep Callback class."""
+
+from ._callback import Callback
+
+
+class SummaryStep(Callback):
+    """
+    The summary callback class.
+
+    Args:
+        summary (Object): Summary record object.
+        flush_step (int): Number of interval steps to execute. Default: 10.
+    """
+
+    def __init__(self, summary, flush_step=10):
+        super(SummaryStep, self).__init__()
+        if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
+            raise ValueError("`flush_step` should be int and greater than 0")
+        self._summary = summary
+        self._flush_step = flush_step
+
+    def __enter__(self):
+        self._summary.__enter__()
+        return self
+
+    def __exit__(self, *err):
+        return self._summary.__exit__(*err)
+
+    def step_end(self, run_context):
+        """
+        Save summary.
+
+        Args:
+            run_context (RunContext): Context of the train running.
+        """
+        cb_params = run_context.original_args()
+        if cb_params.cur_step_num % self._flush_step == 0:
+            self._summary.record(cb_params.cur_step_num, cb_params.train_network)
+
+    @property
+    def summary_file_name(self):
+        return self._summary.full_file_name
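`SummaryStep` is a thin adapter: it forwards context-manager entry and exit to the wrapped summary object and calls its `record(step, network)` every `flush_step` steps. A sketch of wiring it up, assuming a summary record object such as `SummaryRecord` from `mindspore.train.summary` that matches the `record` / `full_file_name` interface used above (the log directory is illustrative):

    # Sketch: SummaryStep wrapping a summary record object.
    from mindspore.train.summary import SummaryRecord
    from mindspore.train.callback import SummaryStep

    summary_record = SummaryRecord(log_dir="./summary")
    summary_cb = SummaryStep(summary_record, flush_step=10)
    model.train(10, dataset, callbacks=[summary_cb])  # model/dataset assumed as before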
@@ -0,0 +1,35 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""TimeMonitor Callback class."""
+
+import time
+
+from ._callback import Callback
+
+
+class TimeMonitor(Callback):
+    """Time Monitor."""
+
+    def __init__(self, data_size):
+        super(TimeMonitor, self).__init__()
+        self.data_size = data_size
+
+    def epoch_begin(self, run_context):
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / self.data_size
+        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
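`TimeMonitor` only needs the number of steps per epoch (`data_size`) to turn an epoch's wall-clock time into a per-step average. A sketch of typical use, assuming a standard MindSpore dataset that exposes `get_dataset_size()`:

    # Sketch: per-epoch and per-step timing during training.
    from mindspore.train.callback import TimeMonitor

    time_cb = TimeMonitor(data_size=dataset.get_dataset_size())  # steps per epoch
    model.train(10, dataset, callbacks=[time_cb])
    # epoch_end prints: "epoch time: <ms>, per step time: <ms>"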
@@ -19,7 +19,7 @@ from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
 from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
-from .callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from mindspore.train import amp
-from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _CallbackManager
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
 from mindspore.train.parallel_utils import ParallelMode
 
 from .dataset_helper import DatasetHelper
@@ -26,10 +26,10 @@ from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
-    _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
-    _CallbackManager, Callback, CheckpointConfig, _set_cur_net
+from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
+    _CallbackManager, Callback, CheckpointConfig
+from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op
+from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
 
 
 class Net(nn.Cell):
     """Net definition."""
@@ -187,7 +187,7 @@ def test_checkpoint_cb_for_save_op():
     one_param['name'] = "conv1.weight"
     one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
     parameter_list.append(one_param)
-    _checkpoint_cb_for_save_op(parameter_list)
+    checkpoint_cb_for_save_op(parameter_list)
 
 
 def test_checkpoint_cb_for_save_op_update_net():
@@ -198,8 +198,8 @@ def test_checkpoint_cb_for_save_op_update_net():
     one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32)
     parameter_list.append(one_param)
     net = Net()
-    _set_cur_net(net)
-    _checkpoint_cb_for_save_op(parameter_list)
+    set_cur_net(net)
+    checkpoint_cb_for_save_op(parameter_list)
     assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1
@@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
-from mindspore.train.callback.callback import _CheckpointManager
+from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
     _exec_save_checkpoint, export, _save_graph
 from ..ut_filter import non_graph_engine