forked from mindspore-Ecosystem/mindspore
!21768 fix issues for ntsnet and dcgan
Merge pull request !21768 from gengdongjie/code_docs_fix_issues
This commit is contained in:
commit
3232db907f
|
@ -13,21 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan cell"""
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import nn, ops, context
|
||||
from mindspore import nn, ops
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import Initializer, _assignment
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import _save_graph, save_checkpoint
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.callback._callback import set_cur_net
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
|
||||
_chg_ckpt_file_name_if_same_exist
|
||||
from mindspore.common.initializer import Initializer
|
||||
from mindspore.train.callback import ModelCheckpoint
|
||||
|
||||
|
||||
class Reshape(nn.Cell):
|
||||
|
@ -40,156 +31,35 @@ class Reshape(nn.Cell):
|
|||
|
||||
|
||||
class Normal(Initializer):
|
||||
"""normal initializer"""
|
||||
def __init__(self, mean=0.0, sigma=0.01):
|
||||
super(Normal, self).__init__()
|
||||
self.sigma = sigma
|
||||
self.mean = mean
|
||||
|
||||
def _initialize(self, arr):
|
||||
"""inhert method"""
|
||||
np.random.seed(999)
|
||||
arr_normal = np.random.normal(self.mean, self.sigma, arr.shape)
|
||||
_assignment(arr, arr_normal)
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
"""
|
||||
The checkpoint callback class.
|
||||
|
||||
It is called to combine with train process and save the model and network parameters after traning.
|
||||
|
||||
Args:
|
||||
prefix (str): The prefix name of checkpoint files. Default: "CKP".
|
||||
directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
|
||||
config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prefix is invalid.
|
||||
TypeError: If the config is not CheckpointConfig type.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='CKP', directory=None, config=None):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self._latest_ckpt_file_name = ""
|
||||
self._init_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = 0
|
||||
|
||||
if _check_file_name_prefix(prefix):
|
||||
self._prefix = prefix
|
||||
num = np.random.normal(self.mean, self.sigma, arr.shape)
|
||||
if arr.shape == ():
|
||||
arr = arr.reshape((1))
|
||||
arr[:] = num
|
||||
arr = arr.reshape(())
|
||||
else:
|
||||
raise ValueError("Prefix {} for checkpoint file name invalid, "
|
||||
"please check and correct it and then continue.".format(prefix))
|
||||
if isinstance(num, np.ndarray):
|
||||
arr[:] = num[:]
|
||||
else:
|
||||
arr[:] = num
|
||||
|
||||
if directory is not None:
|
||||
self._directory = _make_directory(directory)
|
||||
else:
|
||||
self._directory = _cur_dir
|
||||
|
||||
if config is None:
|
||||
self._config = CheckpointConfig()
|
||||
else:
|
||||
if not isinstance(config, CheckpointConfig):
|
||||
raise TypeError("config should be CheckpointConfig type.")
|
||||
self._config = config
|
||||
|
||||
# get existing checkpoint files
|
||||
self._manager = CheckpointManager()
|
||||
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
|
||||
self._graph_saved = False
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Save the checkpoint at the end of step.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
# save graph (only once)
|
||||
if not self._graph_saved:
|
||||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
self.save_ckpt(cb_params)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
Save the last checkpoint after training finished.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self.save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
|
||||
thread_list = threading.enumerate()
|
||||
if len(thread_list) > 1:
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
def _check_save_ckpt(self, cb_params, force_to_save):
|
||||
"""Check whether save checkpoint files or not."""
|
||||
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
|
||||
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
|
||||
or force_to_save is True:
|
||||
return True
|
||||
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
|
||||
self._cur_time = time.time()
|
||||
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
|
||||
self._last_time = self._cur_time
|
||||
return True
|
||||
|
||||
return False
|
||||
class DcganModelCheckpoint(ModelCheckpoint):
|
||||
"""inherit official ModelCheckpoint"""
|
||||
def __init__(self, config, directory, prefix='dcgan'):
|
||||
super().__init__(prefix, directory, config)
|
||||
|
||||
def save_ckpt(self, cb_params, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
|
||||
save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
|
||||
step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if save_ckpt:
|
||||
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
||||
+ str(step_num_in_epoch) + ".ckpt"
|
||||
if _is_role_pserver():
|
||||
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
|
||||
# update checkpoint file list.
|
||||
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
||||
# keep checkpoint files number equal max number.
|
||||
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
|
||||
self._manager.remove_oldest_ckpoint_file()
|
||||
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
|
||||
self._cur_time_for_keep = time.time()
|
||||
if (self._cur_time_for_keep - self._last_time_for_keep) \
|
||||
< self._config.keep_checkpoint_per_n_minutes * 60:
|
||||
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
|
||||
self._cur_time_for_keep)
|
||||
|
||||
# generate the new checkpoint file and rename it.
|
||||
cur_file = os.path.join(self._directory, cur_ckpoint_file)
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = cb_params.cur_step_num
|
||||
|
||||
if context.get_context("enable_ge"):
|
||||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
|
||||
@property
|
||||
def latest_ckpt_file_name(self):
|
||||
"""Return the latest checkpoint path and file name."""
|
||||
return self._latest_ckpt_file_name
|
||||
"""save ckpt"""
|
||||
super()._save_ckpt(cb_params, force_to_save)
|
||||
|
||||
|
||||
class WithLossCellD(nn.Cell):
|
||||
|
|
|
@ -22,14 +22,14 @@ import numpy as np
|
|||
|
||||
from mindspore import context
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam
|
||||
from mindspore.train.callback import CheckpointConfig
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_group_size
|
||||
from src.dataset import create_dataset_imagenet
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
from src.generator import Generator
|
||||
from src.discriminator import Discriminator
|
||||
from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint
|
||||
from src.cell import WithLossCellD, WithLossCellG, DcganModelCheckpoint
|
||||
from src.dcgan import DCGAN
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -80,9 +80,18 @@ if __name__ == '__main__':
|
|||
# checkpoint save
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
|
||||
keep_checkpoint_max=cfg.epoch_size)
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
|
||||
ckpt_cb = DcganModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
class CallbackParam(dict):
|
||||
"""Internal callback object's parameters."""
|
||||
|
||||
def __getattr__(self, key):
|
||||
return self[key]
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
cb_params = CallbackParam()
|
||||
cb_params.train_network = dcgan
|
||||
cb_params.batch_num = steps_per_epoch
|
||||
cb_params.epoch_num = cfg.epoch_size
|
||||
|
@ -105,9 +114,9 @@ if __name__ == '__main__':
|
|||
latent_code = Tensor(data["latent_code"])
|
||||
netD_loss, netG_loss = dcgan(real_data, latent_code)
|
||||
if i % 50 == 0:
|
||||
print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
|
||||
cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
|
||||
netG_loss)
|
||||
time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
print("Date time: ", time, "\tepoch: ", epoch, "/", cfg.epoch_size, "\tstep: ", i,
|
||||
"/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ", netG_loss)
|
||||
D_losses.append(netD_loss.asnumpy())
|
||||
G_losses.append(netG_loss.asnumpy())
|
||||
cb_params.cur_step_num = cb_params.cur_step_num + 1
|
||||
|
|
|
@ -16,20 +16,12 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from mindspore import ops, load_checkpoint, load_param_into_net, Tensor, nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.callback._callback import set_cur_net
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
|
||||
_chg_ckpt_file_name_if_same_exist
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint
|
||||
from src.resnet import resnet50
|
||||
from src.config import config
|
||||
|
||||
|
@ -321,7 +313,7 @@ class WithLossCell(nn.Cell):
|
|||
return self._backbone
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
class NtsnetModelCheckpoint(ModelCheckpoint):
|
||||
"""
|
||||
The checkpoint callback class.
|
||||
It is called to combine with train process and save the model and network parameters after training.
|
||||
|
@ -339,142 +331,17 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
def __init__(self, prefix='CKP', directory=None, ckconfig=None,
|
||||
device_num=1, device_id=0, args=None, run_modelart=False):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self._latest_ckpt_file_name = ""
|
||||
self._init_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = 0
|
||||
super(NtsnetModelCheckpoint, self).__init__(prefix, directory, ckconfig)
|
||||
self.run_modelart = run_modelart
|
||||
if _check_file_name_prefix(prefix):
|
||||
self._prefix = prefix
|
||||
else:
|
||||
raise ValueError("Prefix {} for checkpoint file name invalid, "
|
||||
"please check and correct it and then continue.".format(prefix))
|
||||
if directory is not None:
|
||||
self._directory = _make_directory(directory)
|
||||
else:
|
||||
self._directory = _cur_dir
|
||||
if ckconfig is None:
|
||||
self._config = CheckpointConfig()
|
||||
else:
|
||||
if not isinstance(ckconfig, CheckpointConfig):
|
||||
raise TypeError("ckconfig should be CheckpointConfig type.")
|
||||
self._config = ckconfig
|
||||
# get existing checkpoint files
|
||||
self._manager = CheckpointManager()
|
||||
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
|
||||
self._graph_saved = False
|
||||
self._need_flush_from_cache = True
|
||||
self.device_num = device_num
|
||||
self.device_id = device_id
|
||||
self.args = args
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Save the checkpoint at the end of step.
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
if _is_role_pserver():
|
||||
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
|
||||
cb_params = run_context.original_args()
|
||||
_make_directory(self._directory)
|
||||
# save graph (only once)
|
||||
if not self._graph_saved:
|
||||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
|
||||
os.remove(graph_file_name)
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
thread_list = threading.enumerate()
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
self._save_ckpt(cb_params)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
Save the last checkpoint after training finished.
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
thread_list = threading.enumerate()
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
def _check_save_ckpt(self, cb_params, force_to_save):
|
||||
"""Check whether save checkpoint files or not."""
|
||||
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
|
||||
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
|
||||
or force_to_save is True:
|
||||
return True
|
||||
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
|
||||
self._cur_time = time.time()
|
||||
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
|
||||
self._last_time = self._cur_time
|
||||
return True
|
||||
return False
|
||||
|
||||
def _save_ckpt(self, cb_params, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
|
||||
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
||||
if save_ckpt:
|
||||
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
||||
+ str(step_num_in_epoch) + ".ckpt"
|
||||
# update checkpoint file list.
|
||||
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
||||
# keep checkpoint files number equal max number.
|
||||
if self._config.keep_checkpoint_max and \
|
||||
0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
|
||||
self._manager.remove_oldest_ckpoint_file()
|
||||
elif self._config.keep_checkpoint_per_n_minutes and \
|
||||
self._config.keep_checkpoint_per_n_minutes > 0:
|
||||
self._cur_time_for_keep = time.time()
|
||||
if (self._cur_time_for_keep - self._last_time_for_keep) \
|
||||
< self._config.keep_checkpoint_per_n_minutes * 60:
|
||||
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
|
||||
self._cur_time_for_keep)
|
||||
# generate the new checkpoint file and rename it.
|
||||
cur_file = os.path.join(self._directory, cur_ckpoint_file)
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = cb_params.cur_step_num
|
||||
if context.get_context("enable_ge"):
|
||||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
network = self._config.saved_network if self._config.saved_network is not None \
|
||||
else cb_params.train_network
|
||||
save_checkpoint(network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
|
||||
|
||||
def _flush_from_cache(self, cb_params):
|
||||
"""Flush cache data to host if tensor is cache enable."""
|
||||
has_cache_params = False
|
||||
params = cb_params.train_network.get_parameters()
|
||||
for param in params:
|
||||
if param.cache_enable:
|
||||
has_cache_params = True
|
||||
Tensor(param).flush_from_cache()
|
||||
if not has_cache_params:
|
||||
self._need_flush_from_cache = False
|
||||
|
||||
@property
|
||||
def latest_ckpt_file_name(self):
|
||||
"""Return the latest checkpoint path and file name."""
|
||||
return self._latest_ckpt_file_name
|
||||
super()._save_ckpt(cb_params, force_to_save)
|
||||
if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
|
|||
from src.config import config
|
||||
from src.dataset import create_dataset_train
|
||||
from src.lr_generator import get_lr
|
||||
from src.network import NTS_NET, WithLossCell, LossCallBack, ModelCheckpoint
|
||||
from src.network import NTS_NET, WithLossCell, LossCallBack, NtsnetModelCheckpoint
|
||||
|
||||
parser = argparse.ArgumentParser(description='ntsnet train running')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
|
@ -113,8 +113,9 @@ if __name__ == '__main__':
|
|||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_checkpoint_path = os.path.join(local_output_url, "ckpt_" + str(rank) + "/")
|
||||
|
||||
ckpoint_cb = ModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
|
||||
device_num=device_num, device_id=device_id, args=args, run_modelart=run_modelart)
|
||||
ckpoint_cb = NtsnetModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
|
||||
device_num=device_num, device_id=device_id, args=args,
|
||||
run_modelart=run_modelart)
|
||||
cb += [ckpoint_cb]
|
||||
|
||||
model = Model(oneStepNTSNet, amp_level="O3", keep_batchnorm_fp32=False)
|
||||
|
|
Loading…
Reference in New Issue