diff --git a/model_zoo/official/nlp/bert/mindspore_hub_conf.py b/model_zoo/official/nlp/bert/mindspore_hub_conf.py index 012ac95017..8378b3f34e 100644 --- a/model_zoo/official/nlp/bert/mindspore_hub_conf.py +++ b/model_zoo/official/nlp/bert/mindspore_hub_conf.py @@ -67,11 +67,13 @@ def create_network(name, *args, **kwargs): bert_net_cfg_base.batch_size = kwargs["batch_size"] if "seq_length" in kwargs: bert_net_cfg_base.seq_length = kwargs["seq_length"] - return BertModel(bert_net_cfg_base, *args) + is_training = kwargs.get("is_training", default=False) + return BertModel(bert_net_cfg_base, is_training, *args) if name == 'bert_nezha': if "batch_size" in kwargs: bert_net_cfg_nezha.batch_size = kwargs["batch_size"] if "seq_length" in kwargs: bert_net_cfg_nezha.seq_length = kwargs["seq_length"] - return BertModel(bert_net_cfg_nezha, *args) + is_training = kwargs.get("is_training", default=False) + return BertModel(bert_net_cfg_nezha, is_training, *args) raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh b/model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh index ab7d2046ef..d09f49760d 100644 --- a/model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh +++ b/model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh @@ -38,5 +38,5 @@ mpirun --allow-run-as-root -n $RANK_SIZE \ --data_dir=$DATA_DIR \ --schema_dir=$SCHEMA_DIR \ --dataset_type="tfrecord" \ - --enable_data_sink=False \ + --enable_data_sink="false" \ --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &