forked from mindspore-Ecosystem/mindspore
fix issues on ntsnet and dcgan
parent a480d07dd2
commit ac58821fb0
@@ -13,21 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """dcgan cell"""
-import os
-import threading
-import time
 import numpy as np
-from mindspore import nn, ops, context
+from mindspore import nn, ops
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.common.initializer import Initializer, _assignment
-from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
-from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import _save_graph, save_checkpoint
-from mindspore.train.callback import Callback
-from mindspore.train.callback._callback import set_cur_net
-from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
-    _chg_ckpt_file_name_if_same_exist
+from mindspore.common.initializer import Initializer
+from mindspore.train.callback import ModelCheckpoint


 class Reshape(nn.Cell):
@@ -40,156 +31,35 @@ class Reshape(nn.Cell):


 class Normal(Initializer):
+    """normal initializer"""
     def __init__(self, mean=0.0, sigma=0.01):
         super(Normal, self).__init__()
         self.sigma = sigma
         self.mean = mean

     def _initialize(self, arr):
+        """inhert method"""
         np.random.seed(999)
-        arr_normal = np.random.normal(self.mean, self.sigma, arr.shape)
-        _assignment(arr, arr_normal)
-
-
-class ModelCheckpoint(Callback):
-    """
-    The checkpoint callback class.
-
-    It is called to combine with train process and save the model and network parameters after traning.
-
-    Args:
-        prefix (str): The prefix name of checkpoint files. Default: "CKP".
-        directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
-        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
-
-    Raises:
-        ValueError: If the prefix is invalid.
-        TypeError: If the config is not CheckpointConfig type.
-    """
-
-    def __init__(self, prefix='CKP', directory=None, config=None):
-        super(ModelCheckpoint, self).__init__()
-        self._latest_ckpt_file_name = ""
-        self._init_time = time.time()
-        self._last_time = time.time()
-        self._last_time_for_keep = time.time()
-        self._last_triggered_step = 0
-
-        if _check_file_name_prefix(prefix):
-            self._prefix = prefix
-        else:
-            raise ValueError("Prefix {} for checkpoint file name invalid, "
-                             "please check and correct it and then continue.".format(prefix))
-
-        if directory is not None:
-            self._directory = _make_directory(directory)
-        else:
-            self._directory = _cur_dir
-
-        if config is None:
-            self._config = CheckpointConfig()
-        else:
-            if not isinstance(config, CheckpointConfig):
-                raise TypeError("config should be CheckpointConfig type.")
-            self._config = config
-
-        # get existing checkpoint files
-        self._manager = CheckpointManager()
-        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
-        self._graph_saved = False
-
-    def step_end(self, run_context):
-        """
-        Save the checkpoint at the end of step.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        # save graph (only once)
-        if not self._graph_saved:
-            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
-            _save_graph(cb_params.train_network, graph_file_name)
-            self._graph_saved = True
-        self.save_ckpt(cb_params)
-
-    def end(self, run_context):
-        """
-        Save the last checkpoint after training finished.
-
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        _to_save_last_ckpt = True
-        self.save_ckpt(cb_params, _to_save_last_ckpt)
-
-        thread_list = threading.enumerate()
-        if len(thread_list) > 1:
-            for thread in thread_list:
-                if thread.getName() == "asyn_save_ckpt":
-                    thread.join()
-
-        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-        destroy_allgather_cell()
-
-    def _check_save_ckpt(self, cb_params, force_to_save):
-        """Check whether save checkpoint files or not."""
-        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
-            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
-                    or force_to_save is True:
-                return True
-        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
-            self._cur_time = time.time()
-            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
-                self._last_time = self._cur_time
-                return True
-
-        return False
-
-    def save_ckpt(self, cb_params, force_to_save=False):
-        """Save checkpoint files."""
-        if cb_params.cur_step_num == self._last_triggered_step:
-            return
-
-        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-
-        if save_ckpt:
-            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-                               + str(step_num_in_epoch) + ".ckpt"
-            if _is_role_pserver():
-                cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
-            # update checkpoint file list.
-            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
-            # keep checkpoint files number equal max number.
-            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
-                self._manager.remove_oldest_ckpoint_file()
-            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
-                self._cur_time_for_keep = time.time()
-                if (self._cur_time_for_keep - self._last_time_for_keep) \
-                        < self._config.keep_checkpoint_per_n_minutes * 60:
-                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
-                                                               self._cur_time_for_keep)
-
-            # generate the new checkpoint file and rename it.
-            cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            self._last_time_for_keep = time.time()
-            self._last_triggered_step = cb_params.cur_step_num
-
-            if context.get_context("enable_ge"):
-                set_cur_net(cb_params.train_network)
-                cb_params.train_network.exec_checkpoint_graph()
-
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
-                            self._config.async_save)
-
-            self._latest_ckpt_file_name = cur_file
-
-    @property
-    def latest_ckpt_file_name(self):
-        """Return the latest checkpoint path and file name."""
-        return self._latest_ckpt_file_name
+        num = np.random.normal(self.mean, self.sigma, arr.shape)
+        if arr.shape == ():
+            arr = arr.reshape((1))
+            arr[:] = num
+            arr = arr.reshape(())
+        else:
+            if isinstance(num, np.ndarray):
+                arr[:] = num[:]
+            else:
+                arr[:] = num
+
+
+class DcganModelCheckpoint(ModelCheckpoint):
+    """inherit official ModelCheckpoint"""
+    def __init__(self, config, directory, prefix='dcgan'):
+        super().__init__(prefix, directory, config)
+
+    def save_ckpt(self, cb_params, force_to_save=False):
+        """save ckpt"""
+        super()._save_ckpt(cb_params, force_to_save)


 class WithLossCellD(nn.Cell):
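The two hunks above appear to be the DCGAN model's src/cell.py (the training script below switches its import to from src.cell import ..., DcganModelCheckpoint). The reworked Normal._initialize no longer calls the private _assignment helper and instead writes the sampled values into the parameter array in place, with a dedicated branch for 0-d (scalar) parameters. A minimal NumPy-only sketch of that in-place fill; the function name and the calls at the bottom are illustrative and not part of the commit:

import numpy as np

def fill_normal(arr, mean=0.0, sigma=0.01, seed=999):
    """Write N(mean, sigma) samples into arr in place, covering the 0-d case."""
    np.random.seed(seed)
    num = np.random.normal(mean, sigma, arr.shape)
    if arr.shape == ():
        # a 0-d array cannot be sliced, so write through a (1,)-shaped view
        view = arr.reshape((1))
        view[:] = num
    else:
        arr[:] = num[:]

weights = np.zeros((2, 3), dtype=np.float32)
fill_normal(weights)   # weights now holds the sampled values
bias = np.zeros((), dtype=np.float32)
fill_normal(bias)      # the scalar parameter is updated through the view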
@@ -22,14 +22,14 @@ import numpy as np

 from mindspore import context
 from mindspore import nn, Tensor
-from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam
+from mindspore.train.callback import CheckpointConfig
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_group_size
 from src.dataset import create_dataset_imagenet
 from src.config import dcgan_imagenet_cfg as cfg
 from src.generator import Generator
 from src.discriminator import Discriminator
-from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint
+from src.cell import WithLossCellD, WithLossCellG, DcganModelCheckpoint
 from src.dcgan import DCGAN

 if __name__ == '__main__':
@@ -80,9 +80,18 @@ if __name__ == '__main__':
     # checkpoint save
     ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                    keep_checkpoint_max=cfg.epoch_size)
-    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
+    ckpt_cb = DcganModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')

-    cb_params = _InternalCallbackParam()
+    class CallbackParam(dict):
+        """Internal callback object's parameters."""
+
+        def __getattr__(self, key):
+            return self[key]
+
+        def __setattr__(self, key, value):
+            self[key] = value
+
+    cb_params = CallbackParam()
     cb_params.train_network = dcgan
     cb_params.batch_num = steps_per_epoch
     cb_params.epoch_num = cfg.epoch_size
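_InternalCallbackParam is a private MindSpore class, so the training script above now builds its own attribute-style dict instead. A short, self-contained illustration of how that CallbackParam behaves; the field values below are made up for the example:

class CallbackParam(dict):
    """Dict whose entries can also be read and written as attributes."""

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails; missing keys raise KeyError
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value

cb_params = CallbackParam()
cb_params.batch_num = 100          # stored as cb_params["batch_num"]
cb_params.cur_step_num = 1
cb_params.cur_step_num += 1
print(cb_params["batch_num"], cb_params.cur_step_num)   # -> 100 2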
@@ -105,9 +114,9 @@ if __name__ == '__main__':
             latent_code = Tensor(data["latent_code"])
             netD_loss, netG_loss = dcgan(real_data, latent_code)
             if i % 50 == 0:
-                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
-                      cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
-                      netG_loss)
+                time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+                print("Date time: ", time, "\tepoch: ", epoch, "/", cfg.epoch_size, "\tstep: ", i,
+                      "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ", netG_loss)
             D_losses.append(netD_loss.asnumpy())
             G_losses.append(netG_loss.asnumpy())
             cb_params.cur_step_num = cb_params.cur_step_num + 1
@@ -16,20 +16,12 @@
 import math
 import os
 import time
-import threading
 import numpy as np
 from mindspore import ops, load_checkpoint, load_param_into_net, Tensor, nn
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-import mindspore.context as context
 import mindspore.common.dtype as mstype
-from mindspore.train.callback import Callback
-from mindspore.train.callback._callback import set_cur_net
-from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
-    _chg_ckpt_file_name_if_same_exist
-from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import save_checkpoint, _save_graph
-from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
+from mindspore.train.callback import Callback, ModelCheckpoint
 from src.resnet import resnet50
 from src.config import config

@@ -321,7 +313,7 @@ class WithLossCell(nn.Cell):
         return self._backbone


-class ModelCheckpoint(Callback):
+class NtsnetModelCheckpoint(ModelCheckpoint):
     """
     The checkpoint callback class.
     It is called to combine with train process and save the model and network parameters after training.
@@ -339,142 +331,17 @@ class NtsnetModelCheckpoint(ModelCheckpoint):

     def __init__(self, prefix='CKP', directory=None, ckconfig=None,
                  device_num=1, device_id=0, args=None, run_modelart=False):
-        super(ModelCheckpoint, self).__init__()
-        self._latest_ckpt_file_name = ""
-        self._init_time = time.time()
-        self._last_time = time.time()
-        self._last_time_for_keep = time.time()
-        self._last_triggered_step = 0
+        super(NtsnetModelCheckpoint, self).__init__(prefix, directory, ckconfig)
         self.run_modelart = run_modelart
-        if _check_file_name_prefix(prefix):
-            self._prefix = prefix
-        else:
-            raise ValueError("Prefix {} for checkpoint file name invalid, "
-                             "please check and correct it and then continue.".format(prefix))
-        if directory is not None:
-            self._directory = _make_directory(directory)
-        else:
-            self._directory = _cur_dir
-        if ckconfig is None:
-            self._config = CheckpointConfig()
-        else:
-            if not isinstance(ckconfig, CheckpointConfig):
-                raise TypeError("ckconfig should be CheckpointConfig type.")
-            self._config = ckconfig
-        # get existing checkpoint files
-        self._manager = CheckpointManager()
-        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
-        self._graph_saved = False
-        self._need_flush_from_cache = True
         self.device_num = device_num
         self.device_id = device_id
         self.args = args

-    def step_end(self, run_context):
-        """
-        Save the checkpoint at the end of step.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        if _is_role_pserver():
-            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
-        cb_params = run_context.original_args()
-        _make_directory(self._directory)
-        # save graph (only once)
-        if not self._graph_saved:
-            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
-            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
-                os.remove(graph_file_name)
-            _save_graph(cb_params.train_network, graph_file_name)
-            self._graph_saved = True
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-        self._save_ckpt(cb_params)
-
-    def end(self, run_context):
-        """
-        Save the last checkpoint after training finished.
-        Args:
-            run_context (RunContext): Context of the train running.
-        """
-        cb_params = run_context.original_args()
-        _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, _to_save_last_ckpt)
-        thread_list = threading.enumerate()
-        for thread in thread_list:
-            if thread.getName() == "asyn_save_ckpt":
-                thread.join()
-        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
-        destroy_allgather_cell()
-
-    def _check_save_ckpt(self, cb_params, force_to_save):
-        """Check whether save checkpoint files or not."""
-        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
-            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
-                    or force_to_save is True:
-                return True
-        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
-            self._cur_time = time.time()
-            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
-                self._last_time = self._cur_time
-                return True
-        return False
-
     def _save_ckpt(self, cb_params, force_to_save=False):
-        """Save checkpoint files."""
-        if cb_params.cur_step_num == self._last_triggered_step:
-            return
-        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
-        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
-        if save_ckpt:
-            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
-                               + str(step_num_in_epoch) + ".ckpt"
-            # update checkpoint file list.
-            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
-            # keep checkpoint files number equal max number.
-            if self._config.keep_checkpoint_max and \
-                    0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
-                self._manager.remove_oldest_ckpoint_file()
-            elif self._config.keep_checkpoint_per_n_minutes and \
-                    self._config.keep_checkpoint_per_n_minutes > 0:
-                self._cur_time_for_keep = time.time()
-                if (self._cur_time_for_keep - self._last_time_for_keep) \
-                        < self._config.keep_checkpoint_per_n_minutes * 60:
-                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
-                                                               self._cur_time_for_keep)
-            # generate the new checkpoint file and rename it.
-            cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            self._last_time_for_keep = time.time()
-            self._last_triggered_step = cb_params.cur_step_num
-            if context.get_context("enable_ge"):
-                set_cur_net(cb_params.train_network)
-                cb_params.train_network.exec_checkpoint_graph()
-            network = self._config.saved_network if self._config.saved_network is not None \
-                else cb_params.train_network
-            save_checkpoint(network, cur_file, self._config.integrated_save,
-                            self._config.async_save)
-            self._latest_ckpt_file_name = cur_file
-            if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
-                import moxing as mox
-                mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
-
-    def _flush_from_cache(self, cb_params):
-        """Flush cache data to host if tensor is cache enable."""
-        has_cache_params = False
-        params = cb_params.train_network.get_parameters()
-        for param in params:
-            if param.cache_enable:
-                has_cache_params = True
-                Tensor(param).flush_from_cache()
-        if not has_cache_params:
-            self._need_flush_from_cache = False
-
-    @property
-    def latest_ckpt_file_name(self):
-        """Return the latest checkpoint path and file name."""
-        return self._latest_ckpt_file_name
+        super()._save_ckpt(cb_params, force_to_save)
+        if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
+            import moxing as mox
+            mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))


 class LossCallBack(Callback):
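Both models now follow the same pattern: subclass the official mindspore.train.callback.ModelCheckpoint, let the parent handle naming, rotation and the actual save, and keep only the model-specific step. A reduced template of that pattern; the class name and the post-save comment are illustrative, while _save_ckpt is the same parent hook the commit overrides and latest_ckpt_file_name is the parent's public property:

from mindspore.train.callback import ModelCheckpoint

class CustomModelCheckpoint(ModelCheckpoint):
    """Checkpoint callback that adds one post-save step on top of the official one."""

    def __init__(self, prefix, directory, config):
        super().__init__(prefix, directory, config)

    def _save_ckpt(self, cb_params, force_to_save=False):
        # the parent decides whether to save and performs the save itself
        super()._save_ckpt(cb_params, force_to_save)
        # model-specific follow-up would go here, e.g. copying
        # self.latest_ckpt_file_name to remote storage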
@@ -24,7 +24,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
 from src.config import config
 from src.dataset import create_dataset_train
 from src.lr_generator import get_lr
-from src.network import NTS_NET, WithLossCell, LossCallBack, ModelCheckpoint
+from src.network import NTS_NET, WithLossCell, LossCallBack, NtsnetModelCheckpoint

 parser = argparse.ArgumentParser(description='ntsnet train running')
 parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
@@ -113,8 +113,9 @@ if __name__ == '__main__':
                                   keep_checkpoint_max=config.keep_checkpoint_max)
     save_checkpoint_path = os.path.join(local_output_url, "ckpt_" + str(rank) + "/")

-    ckpoint_cb = ModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
-                                 device_num=device_num, device_id=device_id, args=args, run_modelart=run_modelart)
+    ckpoint_cb = NtsnetModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
+                                       device_num=device_num, device_id=device_id, args=args,
+                                       run_modelart=run_modelart)
     cb += [ckpoint_cb]

     model = Model(oneStepNTSNet, amp_level="O3", keep_batchnorm_fp32=False)