forked from mindspore-Ecosystem/mindspore

!21338 fix doc problem

Merge pull request !21338 from JichenZhao/master

commit d41e3796c6
@@ -124,7 +124,7 @@ weight_decay: 0.00001
 epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 
 # Number of threads used to process the dataset in parallel
@@ -122,10 +122,10 @@ batch_size: 2
 loss_scale: 256
 momentum: 0.91
 weight_decay: 0.00001
-epoch_size: 5
+epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 
 # Number of threads used to process the dataset in parallel
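Both config hunks touch the same checkpoint knobs: save_checkpoint_epochs sets how often a checkpoint is written, and keep_checkpoint_max caps how many files are kept on disk, so dropping it from 20 to 5 bounds the size of the output directory. Below is a minimal sketch of how such keys typically feed MindSpore's checkpoint callback; cfg and steps_per_epoch are illustrative stand-ins, not names from this repository:

# Sketch only: cfg mirrors the YAML keys above; steps_per_epoch is assumed.
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

cfg = {"save_checkpoint_epochs": 1, "keep_checkpoint_max": 5,
       "save_checkpoint_path": "./"}
steps_per_epoch = 100  # placeholder for the real dataset size

ck_cfg = CheckpointConfig(
    save_checkpoint_steps=cfg["save_checkpoint_epochs"] * steps_per_epoch,
    keep_checkpoint_max=cfg["keep_checkpoint_max"])  # oldest files rotated out
ckpt_cb = ModelCheckpoint(prefix="net", directory=cfg["save_checkpoint_path"],
                          config=ck_cfg)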
@@ -26,6 +26,7 @@ from mindspore.train.model import Model, ParallelMode
 from mindspore import dtype as mstype
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.communication.management import init
+from mindspore.communication import management as MutiDev
 from mindspore.parallel import _cost_model_context as cost_model_context
 from mindspore.parallel import set_algo_parameters
 
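The added import aliases mindspore.communication.management as MutiDev so the training loop below can ask for the process's rank. A hedged sketch of the pattern, assuming init() has been called for a distributed job; this is not the repository's code:

# Sketch: query this process's global rank after init().
from mindspore.communication.management import init, get_rank

init()             # join the communication group first
rank = get_rank()  # 0 .. device_num - 1
if rank % 8 == 0:  # e.g. one checkpoint writer per 8-card group
    print("rank", rank, "saves checkpoints")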
@@ -140,22 +141,26 @@ if __name__ == "__main__":
 
     model = Model(train_net, optimizer=optimizer)
 
     time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
     loss_cb = LossMonitor()
     cb = [time_cb, loss_cb]
     config_ck = CheckpointConfig(
-        save_checkpoint_steps=60, keep_checkpoint_max=20)
+        save_checkpoint_steps=60, keep_checkpoint_max=5)
     if args.modelarts:
         ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                   directory='/cache/train_output/')
         cb.append(ckpt_cb)
     else:
-        ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
-                                  directory=args.train_url)
-        time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
-        loss_cb = LossMonitor()
-        cb = [ckpt_cb, time_cb, loss_cb]
-    if args.device_id == 0 or args.device_num == 1:
-        model.train(train_epoch, train_dataset,
-                    callbacks=cb, dataset_sink_mode=True)
-    else:
-        model.train(train_epoch, train_dataset, dataset_sink_mode=True)
+        if args.device_num == 8 and MutiDev.get_rank() % 8 == 0:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+        if args.device_num == 1:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+
+    model.train(train_epoch, train_dataset, callbacks=cb, dataset_sink_mode=True)
     if args.modelarts:
         mox.file.copy_parallel(
             src_url='/cache/train_output', dst_url=args.train_url)
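Net effect of this hunk: the old else branch rebuilt time_cb/loss_cb into a second list and, on non-zero devices, trained with no callbacks at all; the new code appends the checkpoint callback only on the writer process and always trains with cb. A hypothetical condensation of that selection logic (same args fields as above, not the repository's code):

# Hypothetical condensation of the callback selection above.
from mindspore.train.callback import ModelCheckpoint
from mindspore.communication import management as MutiDev

def build_callbacks(args, config_ck, base_cbs):
    cbs = list(base_cbs)  # e.g. [time_cb, loss_cb]
    if args.modelarts:
        cbs.append(ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                   directory='/cache/train_output/'))
    elif args.device_num == 1 or (args.device_num == 8
                                  and MutiDev.get_rank() % 8 == 0):
        cbs.append(ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                   directory=args.train_url))
    return cbs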
@@ -44,7 +44,7 @@ parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
 parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
 parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
-parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
+parser.add_argument('--train_url', type=str, default=None, help='Location of training outputs.')
 parser.add_argument("--config", type=str, required=True, help="model config json file path.")
 parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
 args = parser.parse_args()
 
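This one-line fix matters because argparse treats required=True as "this flag must appear on the command line", so the old definition aborted any invocation that omitted --train_url even though a default was declared. A standalone demonstration of the corrected behavior:

# Standalone argparse demo of the fixed, optional --train_url.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--train_url', type=str, default=None)  # optional now
args = p.parse_args([])        # simulate a run with no --train_url
assert args.train_url is None  # falls back to the default instead of exiting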