forked from mindspore-Ecosystem/mindspore
Add load&save ckpt path for distribute training
This commit is contained in:
parent
6e7a38ac67
commit
9d74cfd312
|
@ -68,7 +68,8 @@ def run_pretrain():
|
|||
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
|
||||
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
||||
"default is 1000.")
|
||||
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
|
||||
|
@ -81,7 +82,7 @@ def run_pretrain():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
|
||||
ckpt_save_dir = args_opt.checkpoint_path
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
if args_opt.distribute == "true":
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
|
@ -91,7 +92,7 @@ def run_pretrain():
|
|||
D.init('nccl')
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/'
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
|
@ -150,8 +151,8 @@ def run_pretrain():
|
|||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
|
||||
callback.append(ckpoint_cb)
|
||||
|
||||
if args_opt.checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
if args_opt.load_checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
|
||||
load_param_into_net(netwithloss, param_dict)
|
||||
|
||||
if args_opt.enable_lossscale == "true":
|
||||
|
|
|
@ -64,7 +64,7 @@ do
|
|||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=100 \
|
||||
--checkpoint_path="" \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
|
|
|
@ -36,7 +36,7 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
|
|||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=1 \
|
||||
--checkpoint_path="" \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
|
|
|
@ -38,7 +38,7 @@ python run_pretrain.py \
|
|||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=1 \
|
||||
--checkpoint_path="" \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
|
|
|
@ -40,7 +40,8 @@ python run_pretrain.py \
|
|||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=1 \
|
||||
--checkpoint_path="" \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
|
|
Loading…
Reference in New Issue