!21338 fix doc problem

Merge pull request !21338 from JichenZhao/master
i-robot 2021-08-06 03:09:55 +00:00 committed by Gitee
commit d41e3796c6
4 changed files with 20 additions and 15 deletions


@@ -124,7 +124,7 @@
 weight_decay: 0.00001
 epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 # Number of threads used to process the dataset in parallel
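For context, keep_checkpoint_max in MindSpore's CheckpointConfig caps how many checkpoint files ModelCheckpoint keeps on disk; once the cap is reached, the oldest file is rotated out before a new one is written, so lowering it from 20 to 5 only bounds disk usage and does not change training. A minimal sketch of the pattern this config drives (the prefix and directory are placeholders, not from this repo):

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Keep at most 5 checkpoint files; older ones are deleted automatically.
config_ck = CheckpointConfig(save_checkpoint_steps=60, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="model", directory="./", config=config_ck)
# ckpt_cb is then passed in the callbacks list of Model.train().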


@@ -122,10 +122,10 @@
 batch_size: 2
 loss_scale: 256
 momentum: 0.91
 weight_decay: 0.00001
-epoch_size: 5
+epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 # Number of threads used to process the dataset in parallel


@@ -26,6 +26,7 @@
 from mindspore.train.model import Model, ParallelMode
 from mindspore import dtype as mstype
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.communication.management import init
+from mindspore.communication import management as MutiDev
 from mindspore.parallel import _cost_model_context as cost_model_context
 from mindspore.parallel import set_algo_parameters
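The added import aliases mindspore.communication.management so the training script below can query its process rank. A hedged sketch of how that API is typically used (assumes a distributed job where init() can set up the communication backend):

from mindspore.communication.management import init, get_rank

init()             # set up the communication backend before any rank query
rank = get_rank()  # global rank of this process, 0..device_num-1
if rank == 0:
    print("only this process should write checkpoints")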
@@ -140,22 +141,26 @@ if __name__ == "__main__":
     model = Model(train_net, optimizer=optimizer)
+    time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
+    loss_cb = LossMonitor()
+    cb = [time_cb, loss_cb]
     config_ck = CheckpointConfig(
-        save_checkpoint_steps=60, keep_checkpoint_max=20)
+        save_checkpoint_steps=60, keep_checkpoint_max=5)
     if args.modelarts:
         ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                   directory='/cache/train_output/')
+        cb.append(ckpt_cb)
     else:
-        ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
-                                  directory=args.train_url)
-    time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
-    loss_cb = LossMonitor()
-    cb = [ckpt_cb, time_cb, loss_cb]
-    if args.device_id == 0 or args.device_num == 1:
-        model.train(train_epoch, train_dataset,
-                    callbacks=cb, dataset_sink_mode=True)
-    else:
-        model.train(train_epoch, train_dataset, dataset_sink_mode=True)
+        if args.device_num == 8 and MutiDev.get_rank() % 8 == 0:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+        if args.device_num == 1:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+    model.train(train_epoch, train_dataset, callbacks=cb, dataset_sink_mode=True)
     if args.modelarts:
         mox.file.copy_parallel(
             src_url='/cache/train_output', dst_url=args.train_url)
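The shape of this refactor, as a hedged sketch with placeholder names (steps, should_write_checkpoints, and epochs are illustrative, not repo identifiers): build the callback list once, append the checkpoint writer only on the device that should save, and make a single unconditional model.train() call instead of one per branch:

from mindspore.train.callback import ModelCheckpoint, LossMonitor, TimeMonitor

cb = [TimeMonitor(data_size=steps), LossMonitor()]  # monitors run on every device
if should_write_checkpoints:                        # e.g. rank 0 only, or a single-device run
    cb.append(ModelCheckpoint(prefix="ArcFace-", config=config_ck))
model.train(epochs, train_dataset, callbacks=cb, dataset_sink_mode=True)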


@@ -44,7 +44,7 @@
 parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
 parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
 parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
-parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
+parser.add_argument('--train_url', type=str, default=None, help='Location of training outputs.')
 parser.add_argument("--config", type=str, required=True, help="model config json file path.")
 parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
 args = parser.parse_args()
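The --train_url change matters because argparse aborts parsing when a required argument is missing: the old required=True made the flag mandatory even for runs that never use it, while the fixed optional form simply falls back to None. A standalone illustration of the difference (not repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_url', type=str, default=None,
                    help='Location of training outputs.')
print(parser.parse_args([]).train_url)  # prints None; with required=True this call would error out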