From 32846d791bc91fa1bf20ad01af0d3d3fab80dfed Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Tue, 11 Aug 2020 14:40:06 +0800 Subject: [PATCH] bugfix bert script --- model_zoo/official/nlp/bert/run_pretrain.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index df0c2c433b6..291a7844417 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -51,7 +51,7 @@ def run_pretrain(): parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") - parser.add_argument("--save_checkpoint_path", type=str, default=None, help="Save checkpoint path") + parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " "default is 1000.") @@ -145,7 +145,8 @@ def run_pretrain(): if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0: config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, keep_checkpoint_max=args_opt.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', + directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck) callback.append(ckpoint_cb) if args_opt.load_checkpoint_path: