From 1ffa37ffbc45529234b66a8cc81b1c83f7502a8d Mon Sep 17 00:00:00 2001 From: ougongchang Date: Tue, 16 Jun 2020 12:48:03 +0800 Subject: [PATCH] Add a callback module to avoid the size of the callback.py file to large --- .../model/model_thor.py | 2 +- mindspore/ccsrc/utils/callbacks.cc | 2 +- mindspore/ccsrc/utils/callbacks_ge.cc | 2 +- mindspore/train/callback/__init__.py | 20 +++++++++++++++++++ mindspore/train/{ => callback}/callback.py | 5 +---- mindspore/train/model.py | 2 +- .../models/resnet50/src_thor/model_thor.py | 2 +- tests/ut/python/utils/test_callback.py | 4 ++-- tests/ut/python/utils/test_serialize.py | 2 +- 9 files changed, 29 insertions(+), 12 deletions(-) create mode 100644 mindspore/train/callback/__init__.py rename mindspore/train/{ => callback}/callback.py (99%) diff --git a/example/resnet50_imagenet2012_THOR/model/model_thor.py b/example/resnet50_imagenet2012_THOR/model/model_thor.py index f3418437a36..b8cd27470c2 100644 --- a/example/resnet50_imagenet2012_THOR/model/model_thor.py +++ b/example/resnet50_imagenet2012_THOR/model/model_thor.py @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check from mindspore.train import amp -from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks +from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks from mindspore.train.parallel_utils import ParallelMode from model.dataset_helper import DatasetHelper diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index ad9751c3322..4f21002470d 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace callbacks { -const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback"; +const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback"; const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; const char kSummary[] = "Summary"; diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 151b78d0106..f45e0d5955c 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace callbacks { -const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback"; +const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback.callback"; const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "_checkpoint_cb_for_save_op"; const char PYTHON_FUN_PROCESS_SUMMARY[] = "_summary_cb_for_save_op"; const char kSummary[] = "Summary"; diff --git a/mindspore/train/callback/__init__.py b/mindspore/train/callback/__init__.py new file mode 100644 index 00000000000..1f81f0de414 --- /dev/null +++ b/mindspore/train/callback/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Callback related classes and functions.""" + +from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryStep, CheckpointConfig, RunContext + +__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", + "SummaryStep", "CheckpointConfig", "RunContext"] diff --git a/mindspore/train/callback.py b/mindspore/train/callback/callback.py similarity index 99% rename from mindspore/train/callback.py rename to mindspore/train/callback/callback.py index e691cfd8373..7df804af0f0 100644 --- a/mindspore/train/callback.py +++ b/mindspore/train/callback/callback.py @@ -26,10 +26,7 @@ from mindspore.train._utils import _make_directory from mindspore import log as logger from mindspore._checkparam import check_int_non_negative, check_bool from mindspore.common.tensor import Tensor -from .summary.summary_record import _cache_summary_tensor_data - - -__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"] +from mindspore.train.summary.summary_record import _cache_summary_tensor_data _cur_dir = os.getcwd() diff --git a/mindspore/train/model.py b/mindspore/train/model.py index b711cf675e9..fe76fa900f1 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -19,7 +19,7 @@ from mindspore import log as logger from ..common.tensor import Tensor from ..nn.metrics import get_metrics from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool -from .callback import _InternalCallbackParam, RunContext, _build_callbacks +from .callback.callback import _InternalCallbackParam, RunContext, _build_callbacks from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check diff --git a/tests/st/networks/models/resnet50/src_thor/model_thor.py b/tests/st/networks/models/resnet50/src_thor/model_thor.py index 9bb9639bc8c..c633d913aca 100644 --- a/tests/st/networks/models/resnet50/src_thor/model_thor.py +++ b/tests/st/networks/models/resnet50/src_thor/model_thor.py @@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check from mindspore.train import amp -from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks +from mindspore.train.callback.callback import _InternalCallbackParam, RunContext, _build_callbacks from mindspore.train.parallel_utils import ParallelMode from .dataset_helper import DatasetHelper diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index 42504f29c6e..da564e3f9c2 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -25,8 +25,8 @@ from mindspore.common.api import ms_function from mindspore.common.tensor import Tensor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Momentum -from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \ - LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ +from mindspore.train.callback.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, \ + _checkpoint_cb_for_save_op, LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \ _build_callbacks, CheckpointConfig, _set_cur_net diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index b312f9f7d12..f5046bb1ec3 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -28,7 +28,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn.optim.momentum import Momentum from mindspore.ops import operations as P -from mindspore.train.callback import _CheckpointManager +from mindspore.train.callback.callback import _CheckpointManager from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ _exec_save_checkpoint, export, _save_graph from ..ut_filter import non_graph_engine