!21338 fix doc problem

Merge pull request !21338 from JichenZhao/master
This commit is contained in:
i-robot 2021-08-06 03:09:55 +00:00 committed by Gitee
commit d41e3796c6
4 changed files with 20 additions and 15 deletions

View File

@@ -124,7 +124,7 @@ weight_decay: 0.00001
epoch_size: 20
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 20
keep_checkpoint_max: 5
save_checkpoint_path: "./"
# Number of threads used to process the dataset in parallel

View File

@@ -122,10 +122,10 @@ batch_size: 2
loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
epoch_size: 5
epoch_size: 20
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 20
keep_checkpoint_max: 5
save_checkpoint_path: "./"
# Number of threads used to process the dataset in parallel

View File

@@ -26,6 +26,7 @@ from mindspore.train.model import Model, ParallelMode
from mindspore import dtype as mstype
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.communication.management import init
from mindspore.communication import management as MutiDev
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
@@ -140,22 +141,26 @@ if __name__ == "__main__":
model = Model(train_net, optimizer=optimizer)
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
config_ck = CheckpointConfig(
save_checkpoint_steps=60, keep_checkpoint_max=20)
save_checkpoint_steps=60, keep_checkpoint_max=5)
if args.modelarts:
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
directory='/cache/train_output/')
cb.append(ckpt_cb)
else:
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
directory=args.train_url)
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()
cb = [ckpt_cb, time_cb, loss_cb]
if args.device_id == 0 or args.device_num == 1:
model.train(train_epoch, train_dataset,
callbacks=cb, dataset_sink_mode=True)
else:
model.train(train_epoch, train_dataset, dataset_sink_mode=True)
if args.device_num == 8 and MutiDev.get_rank() % 8 == 0:
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
directory=args.train_url)
cb.append(ckpt_cb)
if args.device_num == 1:
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
directory=args.train_url)
cb.append(ckpt_cb)
model.train(train_epoch, train_dataset, callbacks=cb, dataset_sink_mode=True)
if args.modelarts:
mox.file.copy_parallel(
src_url='/cache/train_output', dst_url=args.train_url)

View File

@@ -44,7 +44,7 @@ parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
parser.add_argument('--train_url', type=str, default=None, help='Location of training outputs.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
args = parser.parse_args()