!2167 Add a callback module to avoid the size of the callback.py file too large

Merge pull request !2167 from ougongchang/adjust_callback
This commit is contained in:
mindspore-ci-bot 2020-06-17 11:25:19 +08:00 committed by Gitee
commit bb622877e8
9 changed files with 29 additions and 12 deletions

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from model.dataset_helper import DatasetHelper from model.dataset_helper import DatasetHelper

View File

@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace callbacks { namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback"; const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char kSummary[] = "Summary"; const char kSummary[] = "Summary";

View File

@ -25,7 +25,7 @@
namespace mindspore { namespace mindspore {
namespace callbacks { namespace callbacks {
const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback"; const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback";
const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op";
const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op";
const char kSummary[] = "Summary"; const char kSummary[] = "Summary";

View File

@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryStep", "CheckpointConfig", "RunContext"]

View File

@ -26,10 +26,7 @@ from mindspore.train._utils import _make_directory
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative, check_bool from mindspore._checkparam import check_int_non_negative, check_bool
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from .summary.summary_record import _cache_summary_tensor_data from mindspore.train.summary.summary_record import _cache_summary_tensor_data
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
_cur_dir = os.getcwd() _cur_dir = os.getcwd()

View File

@ -19,7 +19,7 @@ from mindspore import log as logger
from ..common.tensor import Tensor from ..common.tensor import Tensor
from ..nn.metrics import get_metrics from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback import _InternalCallbackParam, RunContext, _build_callbacks from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from .. import context from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper

View File

@ -25,8 +25,8 @@ from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \ from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \
LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
_build_callbacks, CheckpointConfig, _set_cur_net _build_callbacks, CheckpointConfig, _set_cur_net

View File

@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager from mindspore.train.callback.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
_exec_save_checkpoint, export, _save_graph _exec_save_checkpoint, export, _save_graph
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine