fix issues on ntsnet and dcgan

gengdongjie 2021-08-13 17:24:18 +08:00
parent a480d07dd2
commit ac58821fb0
4 changed files with 47 additions and 300 deletions

View File (dcgan: src/cell.py)

@@ -13,21 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """dcgan cell"""
-import os
-import threading
-import time
 import numpy as np
-from mindspore import nn, ops, context
+from mindspore import nn, ops
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.common.initializer import Initializer, _assignment
-from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
-from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import _save_graph, save_checkpoint
-from mindspore.train.callback import Callback
-from mindspore.train.callback._callback import set_cur_net
-from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
-    _chg_ckpt_file_name_if_same_exist
+from mindspore.common.initializer import Initializer
+from mindspore.train.callback import ModelCheckpoint
 class Reshape(nn.Cell):
@@ -40,156 +31,35 @@ class Reshape(nn.Cell):
 class Normal(Initializer):
+    """normal initializer"""
     def __init__(self, mean=0.0, sigma=0.01):
         super(Normal, self).__init__()
         self.sigma = sigma
         self.mean = mean
     def _initialize(self, arr):
+        """inhert method"""
         np.random.seed(999)
-        arr_normal = np.random.normal(self.mean, self.sigma, arr.shape)
-        _assignment(arr, arr_normal)
+        num = np.random.normal(self.mean, self.sigma, arr.shape)
+        if arr.shape == ():
+            arr = arr.reshape((1))
+            arr[:] = num
+            arr = arr.reshape(())
+        else:
+            if isinstance(num, np.ndarray):
+                arr[:] = num[:]
+            else:
+                arr[:] = num
-class ModelCheckpoint(Callback):
-    """
-    The checkpoint callback class.
-    It is called to combine with train process and save the model and network parameters after traning.
-    Args:
-        prefix (str): The prefix name of checkpoint files. Default: "CKP".
-        directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
-        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
-    Raises:
-        ValueError: If the prefix is invalid.
-        TypeError: If the config is not CheckpointConfig type.
-    """
-    def __init__(self, prefix='CKP', directory=None, config=None):
-        super(ModelCheckpoint, self).__init__()
-        self._latest_ckpt_file_name = ""
-        self._init_time = time.time()
-        self._last_time = time.time()
-        self._last_time_for_keep = time.time()
-        self._last_triggered_step = 0
-        if _check_file_name_prefix(prefix):
-            self._prefix = prefix
-        else:
-            raise ValueError("Prefix {} for checkpoint file name invalid, "
-                             "please check and correct it and then continue.".format(prefix))
-        if directory is not None:
-            self._directory = _make_directory(directory)
-        else:
-            self._directory = _cur_dir
-        if config is None:
-            self._config = CheckpointConfig()
-        else:
-            if not isinstance(config, CheckpointConfig):
-                raise TypeError("config should be CheckpointConfig type.")
-            self._config = config
-        # get existing checkpoint files
-        self._manager = CheckpointManager()
-        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
-        self._graph_saved = False
-    def step_end(self, run_context):
-        """
-        Save the checkpoint at the end of step.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        # save graph (only once)
-        if not self._graph_saved:
-            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
-            _save_graph(cb_params.train_network, graph_file_name)
-            self._graph_saved = True
-        self.save_ckpt(cb_params)
-    def end(self, run_context):
-        """
-        Save the last checkpoint after training finished.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        _to_save_last_ckpt = True
-        self.save_ckpt(cb_params, _to_save_last_ckpt)
-        thread_list = threading.enumerate()
-        if len(thread_list) > 1:
-            for thread in thread_list:
-                if thread.getName() == "asyn_save_ckpt":
-                    thread.join()
-        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-        destroy_allgather_cell()
-    def _check_save_ckpt(self, cb_params, force_to_save):
-        """Check whether save checkpoint files or not."""
-        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
-            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
-                    or force_to_save is True:
-                return True
-        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
-            self._cur_time = time.time()
-            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
-                self._last_time = self._cur_time
-                return True
-        return False
+class DcganModelCheckpoint(ModelCheckpoint):
+    """inherit official ModelCheckpoint"""
+    def __init__(self, config, directory, prefix='dcgan'):
+        super().__init__(prefix, directory, config)
     def save_ckpt(self, cb_params, force_to_save=False):
-        """Save checkpoint files."""
-        if cb_params.cur_step_num == self._last_triggered_step:
-            return
-        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-        if save_ckpt:
-            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-                               + str(step_num_in_epoch) + ".ckpt"
-            if _is_role_pserver():
-                cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
-            # update checkpoint file list.
-            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
-            # keep checkpoint files number equal max number.
-            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
-                self._manager.remove_oldest_ckpoint_file()
-            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
-                self._cur_time_for_keep = time.time()
-                if (self._cur_time_for_keep - self._last_time_for_keep) \
-                        < self._config.keep_checkpoint_per_n_minutes * 60:
-                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
-                                                               self._cur_time_for_keep)
-            # generate the new checkpoint file and rename it.
-            cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            self._last_time_for_keep = time.time()
-            self._last_triggered_step = cb_params.cur_step_num
-            if context.get_context("enable_ge"):
-                set_cur_net(cb_params.train_network)
-                cb_params.train_network.exec_checkpoint_graph()
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
-                            self._config.async_save)
-            self._latest_ckpt_file_name = cur_file
-    @property
-    def latest_ckpt_file_name(self):
-        """Return the latest checkpoint path and file name."""
-        return self._latest_ckpt_file_name
+        """save ckpt"""
+        super()._save_ckpt(cb_params, force_to_save)
 class WithLossCellD(nn.Cell):
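The point of this file's rewrite is to stop copying MindSpore's internal checkpoint machinery (private names such as _assignment, _save_graph and the _checkpoint helpers) and to subclass the public ModelCheckpoint callback instead. A minimal usage sketch, assuming MindSpore is installed and that src/cell.py exposes the classes shown in the diff above; the layer shapes and hyper-parameters below are illustrative only, not taken from this repository:

import mindspore.nn as nn
from mindspore.train.callback import CheckpointConfig

from src.cell import Normal, DcganModelCheckpoint  # classes defined in the diff above

# The custom Normal initializer is passed to layers as a weight initializer.
conv = nn.Conv2d(3, 64, 4, stride=2, pad_mode="pad", padding=1,
                 weight_init=Normal(mean=0.0, sigma=0.02))  # sigma chosen for illustration

# The callback now reuses the official ModelCheckpoint; save_ckpt() simply forwards to the
# parent's private _save_ckpt(), so DCGAN's hand-written training loop can trigger saves
# without going through Model.train().
ckpt_config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = DcganModelCheckpoint(config=ckpt_config, directory="./ckpt", prefix="dcgan")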

View File (dcgan: training script)

@@ -22,14 +22,14 @@ import numpy as np
 from mindspore import context
 from mindspore import nn, Tensor
-from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam
+from mindspore.train.callback import CheckpointConfig
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_group_size
 from src.dataset import create_dataset_imagenet
 from src.config import dcgan_imagenet_cfg as cfg
 from src.generator import Generator
 from src.discriminator import Discriminator
-from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint
+from src.cell import WithLossCellD, WithLossCellG, DcganModelCheckpoint
 from src.dcgan import DCGAN
 if __name__ == '__main__':
@@ -80,9 +80,18 @@ if __name__ == '__main__':
     # checkpoint save
     ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                    keep_checkpoint_max=cfg.epoch_size)
-    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
-    cb_params = _InternalCallbackParam()
+    ckpt_cb = DcganModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
+    class CallbackParam(dict):
+        """Internal callback object's parameters."""
+        def __getattr__(self, key):
+            return self[key]
+        def __setattr__(self, key, value):
+            self[key] = value
+    cb_params = CallbackParam()
     cb_params.train_network = dcgan
     cb_params.batch_num = steps_per_epoch
     cb_params.epoch_num = cfg.epoch_size
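CallbackParam replaces the removed private _InternalCallbackParam: it is nothing more than a dict whose keys can also be read and written as attributes, which is all the checkpoint callback needs from cb_params. A self-contained sketch of the same idea (plain Python, no MindSpore required):

class CallbackParam(dict):
    """Dict with attribute-style access, standing in for _InternalCallbackParam."""
    def __getattr__(self, key):
        # Note: a missing key raises KeyError here rather than AttributeError.
        return self[key]
    def __setattr__(self, key, value):
        self[key] = value

params = CallbackParam()
params.cur_step_num = 1        # stored as params["cur_step_num"]
params["cur_epoch_num"] = 2    # readable back as params.cur_epoch_num
print(params.cur_step_num, params.cur_epoch_num)   # -> 1 2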
@@ -105,9 +114,9 @@ if __name__ == '__main__':
             latent_code = Tensor(data["latent_code"])
             netD_loss, netG_loss = dcgan(real_data, latent_code)
             if i % 50 == 0:
-                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
-                      cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
-                      netG_loss)
+                time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+                print("Date time: ", time, "\tepoch: ", epoch, "/", cfg.epoch_size, "\tstep: ", i,
+                      "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ", netG_loss)
             D_losses.append(netD_loss.asnumpy())
             G_losses.append(netG_loss.asnumpy())
             cb_params.cur_step_num = cb_params.cur_step_num + 1
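Because DCGAN trains through a hand-written loop rather than Model.train, the callback state has to be maintained manually, as the hunk above shows. The skeleton below is only a sketch of how those pieces fit together; the dataset iteration and the exact place where save_ckpt is invoked are not visible in this diff and are assumptions here:

cb_params.cur_step_num = 0
cb_params.cur_epoch_num = 0
for epoch in range(1, cfg.epoch_size + 1):
    cb_params.cur_epoch_num = epoch
    for i, data in enumerate(data_loader):      # data_loader: assumed name for the dataset iterator
        # ... one discriminator/generator step and loss logging, as in the hunk above ...
        cb_params.cur_step_num = cb_params.cur_step_num + 1
    # With save_checkpoint_steps=steps_per_epoch this fires once per epoch;
    # force_to_save=True would write a checkpoint unconditionally.
    ckpt_cb.save_ckpt(cb_params, force_to_save=False)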

View File (ntsnet: src/network.py)

@@ -16,20 +16,12 @@
 import math
 import os
 import time
-import threading
 import numpy as np
 from mindspore import ops, load_checkpoint, load_param_into_net, Tensor, nn
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-import mindspore.context as context
 import mindspore.common.dtype as mstype
-from mindspore.train.callback import Callback
-from mindspore.train.callback._callback import set_cur_net
-from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
-    _chg_ckpt_file_name_if_same_exist
-from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import save_checkpoint, _save_graph
-from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
+from mindspore.train.callback import Callback, ModelCheckpoint
 from src.resnet import resnet50
 from src.config import config
@@ -321,7 +313,7 @@ class WithLossCell(nn.Cell):
         return self._backbone
-class ModelCheckpoint(Callback):
+class NtsnetModelCheckpoint(ModelCheckpoint):
     """
     The checkpoint callback class.
     It is called to combine with train process and save the model and network parameters after training.
@@ -339,142 +331,17 @@ class ModelCheckpoint(Callback):
     def __init__(self, prefix='CKP', directory=None, ckconfig=None,
                  device_num=1, device_id=0, args=None, run_modelart=False):
-        super(ModelCheckpoint, self).__init__()
-        self._latest_ckpt_file_name = ""
-        self._init_time = time.time()
-        self._last_time = time.time()
-        self._last_time_for_keep = time.time()
-        self._last_triggered_step = 0
+        super(NtsnetModelCheckpoint, self).__init__(prefix, directory, ckconfig)
         self.run_modelart = run_modelart
-        if _check_file_name_prefix(prefix):
-            self._prefix = prefix
-        else:
-            raise ValueError("Prefix {} for checkpoint file name invalid, "
-                             "please check and correct it and then continue.".format(prefix))
-        if directory is not None:
-            self._directory = _make_directory(directory)
-        else:
-            self._directory = _cur_dir
-        if ckconfig is None:
-            self._config = CheckpointConfig()
-        else:
-            if not isinstance(ckconfig, CheckpointConfig):
-                raise TypeError("ckconfig should be CheckpointConfig type.")
-            self._config = ckconfig
-        # get existing checkpoint files
-        self._manager = CheckpointManager()
-        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
-        self._graph_saved = False
-        self._need_flush_from_cache = True
         self.device_num = device_num
         self.device_id = device_id
         self.args = args
-    def step_end(self, run_context):
-        """
-        Save the checkpoint at the end of step.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        if _is_role_pserver():
-            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
-        cb_params = run_context.original_args()
-        _make_directory(self._directory)
-        # save graph (only once)
-        if not self._graph_saved:
-            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
-            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
-                os.remove(graph_file_name)
-            _save_graph(cb_params.train_network, graph_file_name)
-            self._graph_saved = True
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-        self._save_ckpt(cb_params)
-    def end(self, run_context):
-        """
-        Save the last checkpoint after training finished.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, _to_save_last_ckpt)
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-        destroy_allgather_cell()
-    def _check_save_ckpt(self, cb_params, force_to_save):
-        """Check whether save checkpoint files or not."""
-        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
-            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
-                    or force_to_save is True:
-                return True
-        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
-            self._cur_time = time.time()
-            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
-                self._last_time = self._cur_time
-                return True
-        return False
     def _save_ckpt(self, cb_params, force_to_save=False):
-        """Save checkpoint files."""
-        if cb_params.cur_step_num == self._last_triggered_step:
-            return
-        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-        if save_ckpt:
-            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-                               + str(step_num_in_epoch) + ".ckpt"
-            # update checkpoint file list.
-            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
-            # keep checkpoint files number equal max number.
-            if self._config.keep_checkpoint_max and \
-                    0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
-                self._manager.remove_oldest_ckpoint_file()
-            elif self._config.keep_checkpoint_per_n_minutes and \
-                    self._config.keep_checkpoint_per_n_minutes > 0:
-                self._cur_time_for_keep = time.time()
-                if (self._cur_time_for_keep - self._last_time_for_keep) \
-                        < self._config.keep_checkpoint_per_n_minutes * 60:
-                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
-                                                               self._cur_time_for_keep)
-            # generate the new checkpoint file and rename it.
-            cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            self._last_time_for_keep = time.time()
-            self._last_triggered_step = cb_params.cur_step_num
-            if context.get_context("enable_ge"):
-                set_cur_net(cb_params.train_network)
-                cb_params.train_network.exec_checkpoint_graph()
-            network = self._config.saved_network if self._config.saved_network is not None \
-                else cb_params.train_network
-            save_checkpoint(network, cur_file, self._config.integrated_save,
-                            self._config.async_save)
-            self._latest_ckpt_file_name = cur_file
-            if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
-                import moxing as mox
-                mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
-    def _flush_from_cache(self, cb_params):
-        """Flush cache data to host if tensor is cache enable."""
-        has_cache_params = False
-        params = cb_params.train_network.get_parameters()
-        for param in params:
-            if param.cache_enable:
-                has_cache_params = True
-                Tensor(param).flush_from_cache()
-        if not has_cache_params:
-            self._need_flush_from_cache = False
-    @property
-    def latest_ckpt_file_name(self):
-        """Return the latest checkpoint path and file name."""
-        return self._latest_ckpt_file_name
+        super()._save_ckpt(cb_params, force_to_save)
+        if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
+            import moxing as mox
+            mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
 class LossCallBack(Callback):
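Note that, as rendered above, the new _save_ckpt override passes cur_file and cur_ckpoint_file to moxing, but neither name is defined inside the override (in the removed code they were locals of the old method). The parent ModelCheckpoint does expose the last written file through its latest_ckpt_file_name property, so one plausible way to write the override is sketched below. This is a hedged illustration under that assumption, not the code in this commit:

import os
from mindspore.train.callback import ModelCheckpoint

class NtsnetModelCheckpointSketch(ModelCheckpoint):
    """Hedged sketch only: copy the last checkpoint written by the parent to OBS via moxing."""
    def __init__(self, prefix='CKP', directory=None, ckconfig=None,
                 device_num=1, device_id=0, args=None, run_modelart=False):
        super().__init__(prefix, directory, ckconfig)
        self.run_modelart = run_modelart
        self.device_num = device_num
        self.device_id = device_id
        self.args = args

    def _save_ckpt(self, cb_params, force_to_save=False):
        super()._save_ckpt(cb_params, force_to_save)
        cur_file = self.latest_ckpt_file_name   # set by the parent after a successful save
        if cur_file and self.run_modelart and (self.device_num == 1 or self.device_id == 0):
            import moxing as mox                # ModelArts-only dependency
            mox.file.copy_parallel(src_url=cur_file,
                                   dst_url=os.path.join(self.args.train_url,
                                                        os.path.basename(cur_file)))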

View File (ntsnet: training script)

@@ -24,7 +24,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
 from src.config import config
 from src.dataset import create_dataset_train
 from src.lr_generator import get_lr
-from src.network import NTS_NET, WithLossCell, LossCallBack, ModelCheckpoint
+from src.network import NTS_NET, WithLossCell, LossCallBack, NtsnetModelCheckpoint
 parser = argparse.ArgumentParser(description='ntsnet train running')
 parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
@@ -113,8 +113,9 @@ if __name__ == '__main__':
                                       keep_checkpoint_max=config.keep_checkpoint_max)
         save_checkpoint_path = os.path.join(local_output_url, "ckpt_" + str(rank) + "/")
-        ckpoint_cb = ModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
-                                     device_num=device_num, device_id=device_id, args=args, run_modelart=run_modelart)
+        ckpoint_cb = NtsnetModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
+                                           device_num=device_num, device_id=device_id, args=args,
+                                           run_modelart=run_modelart)
         cb += [ckpoint_cb]
     model = Model(oneStepNTSNet, amp_level="O3", keep_batchnorm_fp32=False)