Only save checkpoints on rank 0 for Transformer
This commit is contained in:
parent
3d377c51b9
commit
3ce29513db
|
@@ -147,6 +147,7 @@ def run_transformer_train():
|
|||
|
||||
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
|
||||
if args.enable_save_ckpt == "true":
|
||||
if device_num == 1 or (device_num > 1 and rank_id == 0):
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args.save_checkpoint_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
|
||||
|
|
Loading…
Reference in New Issue