forked from mindspore-Ecosystem/mindspore
!2167 Add a callback module to avoid the size of the callback.py file too large
Merge pull request !2167 from ougongchang/adjust_callback
This commit is contained in:
commit
bb622877e8
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
from model.dataset_helper import DatasetHelper
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace callbacks {
|
||||
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback";
|
||||
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
|
||||
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
|
||||
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
|
||||
const char kSummary[] = "Summary";
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace callbacks {
|
||||
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback";
|
||||
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
|
||||
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
|
||||
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
|
||||
const char kSummary[] = "Summary";
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Callback related classes and functions."""
|
||||
|
||||
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
|
||||
|
||||
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
|
||||
"SummaryStep", "CheckpointConfig", "RunContext"]
|
|
@ -26,10 +26,7 @@ from mindspore.train._utils import _make_directory
|
|||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_int_non_negative, check_bool
|
||||
from mindspore.common.tensor import Tensor
|
||||
from .summary.summary_record import _cache_summary_tensor_data
|
||||
|
||||
|
||||
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
|
||||
from mindspore.train.summary.summary_record import _cache_summary_tensor_data
|
||||
|
||||
|
||||
_cur_dir = os.getcwd()
|
|
@ -19,7 +19,7 @@ from mindspore import log as logger
|
|||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
||||
from .callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
from .dataset_helper import DatasetHelper
|
||||
|
|
|
@ -25,8 +25,8 @@ from mindspore.common.api import ms_function
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \
|
||||
LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
|
||||
from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
|
||||
_checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
|
||||
_build_callbacks, CheckpointConfig, _set_cur_net
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
|
|||
from mindspore.nn import WithLossCell, TrainOneStepCell
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import _CheckpointManager
|
||||
from mindspore.train.callback.callback import _CheckpointManager
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
|
||||
_exec_save_checkpoint, export, _save_graph
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
|
Loading…
Reference in New Issue