diff --git a/model_zoo/research/cv/dcgan/src/cell.py b/model_zoo/research/cv/dcgan/src/cell.py index 7dea850a974..901416b2f8c 100644 --- a/model_zoo/research/cv/dcgan/src/cell.py +++ b/model_zoo/research/cv/dcgan/src/cell.py @@ -13,21 +13,12 @@ # limitations under the License. # ============================================================================ """dcgan cell""" -import os -import threading -import time import numpy as np -from mindspore import nn, ops, context +from mindspore import nn, ops from mindspore.ops import functional as F from mindspore.common import dtype as mstype -from mindspore.common.initializer import Initializer, _assignment -from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank -from mindspore.train._utils import _make_directory -from mindspore.train.serialization import _save_graph, save_checkpoint -from mindspore.train.callback import Callback -from mindspore.train.callback._callback import set_cur_net -from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \ - _chg_ckpt_file_name_if_same_exist +from mindspore.common.initializer import Initializer +from mindspore.train.callback import ModelCheckpoint class Reshape(nn.Cell): @@ -40,156 +31,35 @@ class Reshape(nn.Cell): class Normal(Initializer): + """normal initializer""" def __init__(self, mean=0.0, sigma=0.01): super(Normal, self).__init__() self.sigma = sigma self.mean = mean def _initialize(self, arr): + """inhert method""" np.random.seed(999) - arr_normal = np.random.normal(self.mean, self.sigma, arr.shape) - _assignment(arr, arr_normal) - - -class ModelCheckpoint(Callback): - """ - The checkpoint callback class. - - It is called to combine with train process and save the model and network parameters after traning. - - Args: - prefix (str): The prefix name of checkpoint files. Default: "CKP". - directory (str): The path of the folder which will be saved in the checkpoint file. Default: None. - config (CheckpointConfig): Checkpoint strategy configuration. Default: None. - - Raises: - ValueError: If the prefix is invalid. - TypeError: If the config is not CheckpointConfig type. - """ - - def __init__(self, prefix='CKP', directory=None, config=None): - super(ModelCheckpoint, self).__init__() - self._latest_ckpt_file_name = "" - self._init_time = time.time() - self._last_time = time.time() - self._last_time_for_keep = time.time() - self._last_triggered_step = 0 - - if _check_file_name_prefix(prefix): - self._prefix = prefix + num = np.random.normal(self.mean, self.sigma, arr.shape) + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) else: - raise ValueError("Prefix {} for checkpoint file name invalid, " - "please check and correct it and then continue.".format(prefix)) + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num - if directory is not None: - self._directory = _make_directory(directory) - else: - self._directory = _cur_dir - if config is None: - self._config = CheckpointConfig() - else: - if not isinstance(config, CheckpointConfig): - raise TypeError("config should be CheckpointConfig type.") - self._config = config - - # get existing checkpoint files - self._manager = CheckpointManager() - self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) - self._graph_saved = False - - def step_end(self, run_context): - """ - Save the checkpoint at the end of step. - - Args: - run_context (RunContext): Context of the train running. - """ - cb_params = run_context.original_args() - # save graph (only once) - if not self._graph_saved: - graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') - _save_graph(cb_params.train_network, graph_file_name) - self._graph_saved = True - self.save_ckpt(cb_params) - - def end(self, run_context): - """ - Save the last checkpoint after training finished. - - Args: - run_context (RunContext): Context of the train running. - """ - cb_params = run_context.original_args() - _to_save_last_ckpt = True - self.save_ckpt(cb_params, _to_save_last_ckpt) - - thread_list = threading.enumerate() - if len(thread_list) > 1: - for thread in thread_list: - if thread.getName() == "asyn_save_ckpt": - thread.join() - - from mindspore.parallel._cell_wrapper import destroy_allgather_cell - destroy_allgather_cell() - - def _check_save_ckpt(self, cb_params, force_to_save): - """Check whether save checkpoint files or not.""" - if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: - if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ - or force_to_save is True: - return True - elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: - self._cur_time = time.time() - if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True: - self._last_time = self._cur_time - return True - - return False +class DcganModelCheckpoint(ModelCheckpoint): + """inherit official ModelCheckpoint""" + def __init__(self, config, directory, prefix='dcgan'): + super().__init__(prefix, directory, config) def save_ckpt(self, cb_params, force_to_save=False): - """Save checkpoint files.""" - if cb_params.cur_step_num == self._last_triggered_step: - return - - save_ckpt = self._check_save_ckpt(cb_params, force_to_save) - step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 - - if save_ckpt: - cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ - + str(step_num_in_epoch) + ".ckpt" - if _is_role_pserver(): - cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file - # update checkpoint file list. - self._manager.update_ckpoint_filelist(self._directory, self._prefix) - # keep checkpoint files number equal max number. - if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: - self._manager.remove_oldest_ckpoint_file() - elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: - self._cur_time_for_keep = time.time() - if (self._cur_time_for_keep - self._last_time_for_keep) \ - < self._config.keep_checkpoint_per_n_minutes * 60: - self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, - self._cur_time_for_keep) - - # generate the new checkpoint file and rename it. - cur_file = os.path.join(self._directory, cur_ckpoint_file) - self._last_time_for_keep = time.time() - self._last_triggered_step = cb_params.cur_step_num - - if context.get_context("enable_ge"): - set_cur_net(cb_params.train_network) - cb_params.train_network.exec_checkpoint_graph() - - save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, - self._config.async_save) - - self._latest_ckpt_file_name = cur_file - - @property - def latest_ckpt_file_name(self): - """Return the latest checkpoint path and file name.""" - return self._latest_ckpt_file_name + """save ckpt""" + super()._save_ckpt(cb_params, force_to_save) class WithLossCellD(nn.Cell): diff --git a/model_zoo/research/cv/dcgan/train.py b/model_zoo/research/cv/dcgan/train.py index 7d64294dde4..01f9195f866 100644 --- a/model_zoo/research/cv/dcgan/train.py +++ b/model_zoo/research/cv/dcgan/train.py @@ -22,14 +22,14 @@ import numpy as np from mindspore import context from mindspore import nn, Tensor -from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam +from mindspore.train.callback import CheckpointConfig from mindspore.context import ParallelMode from mindspore.communication.management import init, get_group_size from src.dataset import create_dataset_imagenet from src.config import dcgan_imagenet_cfg as cfg from src.generator import Generator from src.discriminator import Discriminator -from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint +from src.cell import WithLossCellD, WithLossCellG, DcganModelCheckpoint from src.dcgan import DCGAN if __name__ == '__main__': @@ -80,9 +80,18 @@ if __name__ == '__main__': # checkpoint save ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=cfg.epoch_size) - ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan') + ckpt_cb = DcganModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan') - cb_params = _InternalCallbackParam() + class CallbackParam(dict): + """Internal callback object's parameters.""" + + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + self[key] = value + + cb_params = CallbackParam() cb_params.train_network = dcgan cb_params.batch_num = steps_per_epoch cb_params.epoch_num = cfg.epoch_size @@ -105,9 +114,9 @@ if __name__ == '__main__': latent_code = Tensor(data["latent_code"]) netD_loss, netG_loss = dcgan(real_data, latent_code) if i % 50 == 0: - print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/", - cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ", - netG_loss) + time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print("Date time: ", time, "\tepoch: ", epoch, "/", cfg.epoch_size, "\tstep: ", i, + "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ", netG_loss) D_losses.append(netD_loss.asnumpy()) G_losses.append(netG_loss.asnumpy()) cb_params.cur_step_num = cb_params.cur_step_num + 1 diff --git a/model_zoo/research/cv/ntsnet/src/network.py b/model_zoo/research/cv/ntsnet/src/network.py index 87c9bad1601..7cf4080f096 100644 --- a/model_zoo/research/cv/ntsnet/src/network.py +++ b/model_zoo/research/cv/ntsnet/src/network.py @@ -16,20 +16,12 @@ import math import os import time -import threading import numpy as np from mindspore import ops, load_checkpoint, load_param_into_net, Tensor, nn from mindspore.ops import functional as F from mindspore.ops import operations as P -import mindspore.context as context import mindspore.common.dtype as mstype -from mindspore.train.callback import Callback -from mindspore.train.callback._callback import set_cur_net -from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \ - _chg_ckpt_file_name_if_same_exist -from mindspore.train._utils import _make_directory -from mindspore.train.serialization import save_checkpoint, _save_graph -from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank +from mindspore.train.callback import Callback, ModelCheckpoint from src.resnet import resnet50 from src.config import config @@ -321,7 +313,7 @@ class WithLossCell(nn.Cell): return self._backbone -class ModelCheckpoint(Callback): +class NtsnetModelCheckpoint(ModelCheckpoint): """ The checkpoint callback class. It is called to combine with train process and save the model and network parameters after training. @@ -339,142 +331,17 @@ class ModelCheckpoint(Callback): def __init__(self, prefix='CKP', directory=None, ckconfig=None, device_num=1, device_id=0, args=None, run_modelart=False): - super(ModelCheckpoint, self).__init__() - self._latest_ckpt_file_name = "" - self._init_time = time.time() - self._last_time = time.time() - self._last_time_for_keep = time.time() - self._last_triggered_step = 0 + super(NtsnetModelCheckpoint, self).__init__(prefix, directory, ckconfig) self.run_modelart = run_modelart - if _check_file_name_prefix(prefix): - self._prefix = prefix - else: - raise ValueError("Prefix {} for checkpoint file name invalid, " - "please check and correct it and then continue.".format(prefix)) - if directory is not None: - self._directory = _make_directory(directory) - else: - self._directory = _cur_dir - if ckconfig is None: - self._config = CheckpointConfig() - else: - if not isinstance(ckconfig, CheckpointConfig): - raise TypeError("ckconfig should be CheckpointConfig type.") - self._config = ckconfig - # get existing checkpoint files - self._manager = CheckpointManager() - self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) - self._graph_saved = False - self._need_flush_from_cache = True self.device_num = device_num self.device_id = device_id self.args = args - def step_end(self, run_context): - """ - Save the checkpoint at the end of step. - Args: - run_context (RunContext): Context of the train running. - """ - if _is_role_pserver(): - self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix - cb_params = run_context.original_args() - _make_directory(self._directory) - # save graph (only once) - if not self._graph_saved: - graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') - if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: - os.remove(graph_file_name) - _save_graph(cb_params.train_network, graph_file_name) - self._graph_saved = True - thread_list = threading.enumerate() - for thread in thread_list: - if thread.getName() == "asyn_save_ckpt": - thread.join() - self._save_ckpt(cb_params) - - def end(self, run_context): - """ - Save the last checkpoint after training finished. - Args: - run_context (RunContext): Context of the train running. - """ - cb_params = run_context.original_args() - _to_save_last_ckpt = True - self._save_ckpt(cb_params, _to_save_last_ckpt) - thread_list = threading.enumerate() - for thread in thread_list: - if thread.getName() == "asyn_save_ckpt": - thread.join() - from mindspore.parallel._cell_wrapper import destroy_allgather_cell - destroy_allgather_cell() - - def _check_save_ckpt(self, cb_params, force_to_save): - """Check whether save checkpoint files or not.""" - if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: - if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ - or force_to_save is True: - return True - elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: - self._cur_time = time.time() - if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True: - self._last_time = self._cur_time - return True - return False - def _save_ckpt(self, cb_params, force_to_save=False): - """Save checkpoint files.""" - if cb_params.cur_step_num == self._last_triggered_step: - return - save_ckpt = self._check_save_ckpt(cb_params, force_to_save) - step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) - if save_ckpt: - cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ - + str(step_num_in_epoch) + ".ckpt" - # update checkpoint file list. - self._manager.update_ckpoint_filelist(self._directory, self._prefix) - # keep checkpoint files number equal max number. - if self._config.keep_checkpoint_max and \ - 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: - self._manager.remove_oldest_ckpoint_file() - elif self._config.keep_checkpoint_per_n_minutes and \ - self._config.keep_checkpoint_per_n_minutes > 0: - self._cur_time_for_keep = time.time() - if (self._cur_time_for_keep - self._last_time_for_keep) \ - < self._config.keep_checkpoint_per_n_minutes * 60: - self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, - self._cur_time_for_keep) - # generate the new checkpoint file and rename it. - cur_file = os.path.join(self._directory, cur_ckpoint_file) - self._last_time_for_keep = time.time() - self._last_triggered_step = cb_params.cur_step_num - if context.get_context("enable_ge"): - set_cur_net(cb_params.train_network) - cb_params.train_network.exec_checkpoint_graph() - network = self._config.saved_network if self._config.saved_network is not None \ - else cb_params.train_network - save_checkpoint(network, cur_file, self._config.integrated_save, - self._config.async_save) - self._latest_ckpt_file_name = cur_file - if self.run_modelart and (self.device_num == 1 or self.device_id == 0): - import moxing as mox - mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file)) - - def _flush_from_cache(self, cb_params): - """Flush cache data to host if tensor is cache enable.""" - has_cache_params = False - params = cb_params.train_network.get_parameters() - for param in params: - if param.cache_enable: - has_cache_params = True - Tensor(param).flush_from_cache() - if not has_cache_params: - self._need_flush_from_cache = False - - @property - def latest_ckpt_file_name(self): - """Return the latest checkpoint path and file name.""" - return self._latest_ckpt_file_name + super()._save_ckpt(cb_params, force_to_save) + if self.run_modelart and (self.device_num == 1 or self.device_id == 0): + import moxing as mox + mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file)) class LossCallBack(Callback): diff --git a/model_zoo/research/cv/ntsnet/train.py b/model_zoo/research/cv/ntsnet/train.py index 117dc7e00a9..87af3d5d9c2 100644 --- a/model_zoo/research/cv/ntsnet/train.py +++ b/model_zoo/research/cv/ntsnet/train.py @@ -24,7 +24,7 @@ from mindspore.communication.management import init, get_rank, get_group_size from src.config import config from src.dataset import create_dataset_train from src.lr_generator import get_lr -from src.network import NTS_NET, WithLossCell, LossCallBack, ModelCheckpoint +from src.network import NTS_NET, WithLossCell, LossCallBack, NtsnetModelCheckpoint parser = argparse.ArgumentParser(description='ntsnet train running') parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.") @@ -113,8 +113,9 @@ if __name__ == '__main__': keep_checkpoint_max=config.keep_checkpoint_max) save_checkpoint_path = os.path.join(local_output_url, "ckpt_" + str(rank) + "/") - ckpoint_cb = ModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig, - device_num=device_num, device_id=device_id, args=args, run_modelart=run_modelart) + ckpoint_cb = NtsnetModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig, + device_num=device_num, device_id=device_id, args=args, + run_modelart=run_modelart) cb += [ckpoint_cb] model = Model(oneStepNTSNet, amp_level="O3", keep_batchnorm_fp32=False)