forked from mindspore-Ecosystem/mindspore

commit d41e3796c6
!21338 fix doc problem
Merge pull request !21338 from JichenZhao/master
@@ -124,7 +124,7 @@ weight_decay: 0.00001
 epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 
 # Number of threads used to process the dataset in parallel
@@ -122,10 +122,10 @@ batch_size: 2
 loss_scale: 256
 momentum: 0.91
 weight_decay: 0.00001
-epoch_size: 5
+epoch_size: 20
 save_checkpoint: True
 save_checkpoint_epochs: 1
-keep_checkpoint_max: 20
+keep_checkpoint_max: 5
 save_checkpoint_path: "./"
 
 # Number of threads used to process the dataset in parallel
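
Both config hunks touch the same two knobs: how long to train (epoch_size) and how many checkpoint files to retain on disk (keep_checkpoint_max). For orientation, here is a minimal sketch of how such YAML fields are typically wired into MindSpore's checkpoint callback; the config file name, yaml loading, and steps_per_epoch value are illustrative assumptions, while CheckpointConfig and ModelCheckpoint are the real mindspore.train.callback classes.

# Sketch only: wiring the YAML knobs above into MindSpore checkpointing.
# The config path and steps_per_epoch are assumptions for illustration.
import yaml
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

with open("default_config.yaml") as f:           # hypothetical path
    cfg = yaml.safe_load(f)

steps_per_epoch = 1000                           # e.g. dataset_size // batch_size

if cfg["save_checkpoint"]:
    ck = CheckpointConfig(
        save_checkpoint_steps=cfg["save_checkpoint_epochs"] * steps_per_epoch,
        keep_checkpoint_max=cfg["keep_checkpoint_max"])   # 5 after this change
    ckpt_cb = ModelCheckpoint(prefix="net", config=ck,
                              directory=cfg["save_checkpoint_path"])

Lowering keep_checkpoint_max from 20 to 5 only changes how many .ckpt files are rotated on disk; it does not affect training itself.
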
@@ -26,6 +26,7 @@ from mindspore.train.model import Model, ParallelMode
 from mindspore import dtype as mstype
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.communication.management import init
+from mindspore.communication import management as MutiDev
 from mindspore.parallel import _cost_model_context as cost_model_context
 from mindspore.parallel import set_algo_parameters
 
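
The added import gives the training script a handle on the collective-communication helpers, in particular get_rank(), which the next hunk uses to make only one process per 8-device group write checkpoints. A minimal sketch of that pattern, assuming distributed init() has already succeeded:

# Sketch only: rank-gated checkpoint writing, assuming init() succeeds.
from mindspore.communication import management as MutiDev
from mindspore.communication.management import init

init()                        # set up collectives before querying rank
rank = MutiDev.get_rank()     # global rank of this training process
if rank % 8 == 0:             # one checkpoint writer per 8-device group
    print(f"rank {rank} saves checkpoints; the rest only train")
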
@@ -140,22 +141,26 @@ if __name__ == "__main__":
 
     model = Model(train_net, optimizer=optimizer)
 
+    time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
+    loss_cb = LossMonitor()
+    cb = [time_cb, loss_cb]
     config_ck = CheckpointConfig(
-        save_checkpoint_steps=60, keep_checkpoint_max=20)
+        save_checkpoint_steps=60, keep_checkpoint_max=5)
     if args.modelarts:
         ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                   directory='/cache/train_output/')
+        cb.append(ckpt_cb)
     else:
-        ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
-                                  directory=args.train_url)
-        time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
-        loss_cb = LossMonitor()
-        cb = [ckpt_cb, time_cb, loss_cb]
-    if args.device_id == 0 or args.device_num == 1:
-        model.train(train_epoch, train_dataset,
-                    callbacks=cb, dataset_sink_mode=True)
-    else:
-        model.train(train_epoch, train_dataset, dataset_sink_mode=True)
+        if args.device_num == 8 and MutiDev.get_rank() % 8 == 0:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+        if args.device_num == 1:
+            ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
+                                      directory=args.train_url)
+            cb.append(ckpt_cb)
+
+    model.train(train_epoch, train_dataset, callbacks=cb, dataset_sink_mode=True)
 
     if args.modelarts:
         mox.file.copy_parallel(
             src_url='/cache/train_output', dst_url=args.train_url)
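
After this change every rank calls model.train() with the same time/loss monitors, and only saving ranks additionally carry a ModelCheckpoint; previously non-zero devices trained with no callbacks at all. The three near-identical ModelCheckpoint blocks could be folded into one helper. A sketch that mirrors the post-change branch logic, assuming args exposes modelarts, device_num, and train_url as in the hunk:

def maybe_add_checkpoint_cb(cb, args, config_ck):
    # Sketch only: same branch logic as the hunk above, deduplicated.
    if args.modelarts:
        directory = '/cache/train_output/'
    elif args.device_num == 1 or (args.device_num == 8
                                  and MutiDev.get_rank() % 8 == 0):
        directory = args.train_url
    else:
        return  # non-saving rank: train with time/loss monitors only
    cb.append(ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                              directory=directory))
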
@@ -44,7 +44,7 @@ parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
 
 parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
 parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
-parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
+parser.add_argument('--train_url', type=str, default=None, help='Location of training outputs.')
 parser.add_argument("--config", type=str, required=True, help="model config json file path.")
 parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
 args = parser.parse_args()
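
The --train_url fix resolves a contradiction: required=True made argparse abort whenever the flag was absent, so default=None was dead code and the script could never run without a training-output URL even when one was not needed. A quick standalone illustration of the difference:

import argparse

parser = argparse.ArgumentParser()
# Old form (kept as a comment): parse_args([]) would exit with
# "error: the following arguments are required: --train_url"
# parser.add_argument('--train_url', required=True, default=None)

# New form: omitting the flag now yields None instead of a hard error.
parser.add_argument('--train_url', type=str, default=None)
print(parser.parse_args([]).train_url)   # -> None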