forked from mindspore-Ecosystem/mindspore
Only save checkpoints on rank 0 for Transformer
This commit is contained in:
parent
3d377c51b9
commit
3ce29513db
|
@@ -147,6 +147,7 @@ def run_transformer_train():
|
||||||
|
|
||||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
|
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
|
||||||
if args.enable_save_ckpt == "true":
|
if args.enable_save_ckpt == "true":
|
||||||
|
if device_num == 1 or (device_num > 1 and rank_id == 0):
|
||||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
|
||||||
keep_checkpoint_max=args.save_checkpoint_num)
|
keep_checkpoint_max=args.save_checkpoint_num)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
|
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
|
||||||
|
|
Loading…
Reference in New Issue