diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md index ae80e442a14..48423e32f70 100644 --- a/model_zoo/official/nlp/bert/README.md +++ b/model_zoo/official/nlp/bert/README.md @@ -312,6 +312,7 @@ Parameters for training and evaluation can be set in file `config.py` and `finet ``` config for lossscale and etc. bert_network version of BERT model: base | nezha, default is base + batch_size batch size of input dataset: N, default is 16 loss_scale_value initial value of loss scale: N, default is 2^32 scale_factor factor used to update loss scale: N, default is 2 scale_window steps for once updatation of loss scale: N, default is 1000 @@ -321,7 +322,6 @@ config for lossscale and etc. ### Parameters: ``` Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): - batch_size batch size of input dataset: N, default is 16 seq_length length of input sequence: N, default is 128 vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136 hidden_size size of bert encoder layers: N, default is 768 @@ -335,8 +335,6 @@ Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): type_vocab_size size of token type vocab: N, default is 16 initializer_range initialization value of TruncatedNormal: Q, default is 0.02 use_relative_positions use relative positions or not: True | False, default is False - input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True - token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32 compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16 diff --git a/model_zoo/official/nlp/bert/mindspore_hub_conf.py b/model_zoo/official/nlp/bert/mindspore_hub_conf.py index 8378b3f34eb..ea4b63ac20f 100644 --- a/model_zoo/official/nlp/bert/mindspore_hub_conf.py +++ b/model_zoo/official/nlp/bert/mindspore_hub_conf.py @@ -19,7 +19,6 @@ from src.bert_model import BertModel from src.bert_model import BertConfig import mindspore.common.dtype as mstype bert_net_cfg_base = BertConfig( - batch_size=32, seq_length=128, vocab_size=21128, hidden_size=768, @@ -33,13 +32,10 @@ bert_net_cfg_base = BertConfig( type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16 ) bert_net_cfg_nezha = BertConfig( - batch_size=32, seq_length=128, vocab_size=21128, hidden_size=1024, @@ -53,8 +49,6 @@ bert_net_cfg_nezha = BertConfig( type_vocab_size=2, initializer_range=0.02, use_relative_positions=True, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16 ) @@ -63,15 +57,11 @@ def create_network(name, *args, **kwargs): Create bert network for base and nezha. ''' if name == 'bert_base': - if "batch_size" in kwargs: - bert_net_cfg_base.batch_size = kwargs["batch_size"] if "seq_length" in kwargs: bert_net_cfg_base.seq_length = kwargs["seq_length"] is_training = kwargs.get("is_training", default=False) return BertModel(bert_net_cfg_base, is_training, *args) if name == 'bert_nezha': - if "batch_size" in kwargs: - bert_net_cfg_nezha.batch_size = kwargs["batch_size"] if "seq_length" in kwargs: bert_net_cfg_nezha.seq_length = kwargs["seq_length"] is_training = kwargs.get("is_training", default=False) diff --git a/model_zoo/official/nlp/bert/pretrain_eval.py b/model_zoo/official/nlp/bert/pretrain_eval.py index d2924854d79..b8cb760dd01 100644 --- a/model_zoo/official/nlp/bert/pretrain_eval.py +++ b/model_zoo/official/nlp/bert/pretrain_eval.py @@ -131,7 +131,7 @@ def bert_predict(): ''' devid = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) - dataset = get_enwiki_512_dataset(bert_net_cfg.batch_size, 1) + dataset = get_enwiki_512_dataset(cfg.batch_size, 1) net_for_pretraining = BertPretrainEva(bert_net_cfg) net_for_pretraining.set_train(False) param_dict = load_checkpoint(cfg.finetune_ckpt) diff --git a/model_zoo/official/nlp/bert/run_classifier.py b/model_zoo/official/nlp/bert/run_classifier.py index 236e6947bbd..2cbf637f230 100644 --- a/model_zoo/official/nlp/bert/run_classifier.py +++ b/model_zoo/official/nlp/bert/run_classifier.py @@ -188,7 +188,7 @@ def run_classifier(): assessment_method=assessment_method) if args_opt.do_train.lower() == "true": - ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + ds = create_classification_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1, assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, schema_file_path=args_opt.schema_file_path, @@ -204,7 +204,7 @@ def run_classifier(): ds.get_dataset_size(), epoch_num, "classifier") if args_opt.do_eval.lower() == "true": - ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + ds = create_classification_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1, assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, schema_file_path=args_opt.schema_file_path, diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py index 33d272c373e..a2ebdc5b3ed 100644 --- a/model_zoo/official/nlp/bert/run_ner.py +++ b/model_zoo/official/nlp/bert/run_ner.py @@ -104,9 +104,9 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_meth if load_checkpoint_path == "": raise ValueError("Finetune model missed, evaluation task must load finetune model!") if assessment_method == "clue_benchmark": - bert_net_cfg.batch_size = 1 - net_for_pretraining = network(bert_net_cfg, False, num_class, use_crf=(use_crf.lower() == "true"), - tag_to_index=tag_to_index) + optimizer_cfg.batch_size = 1 + net_for_pretraining = network(bert_net_cfg, optimizer_cfg.batch_size, False, num_class, + use_crf=(use_crf.lower() == "true"), tag_to_index=tag_to_index) net_for_pretraining.set_train(False) param_dict = load_checkpoint(load_checkpoint_path) load_param_into_net(net_for_pretraining, param_dict) @@ -211,11 +211,11 @@ def run_ner(): number_labels = len(tag_to_index) else: number_labels = args_opt.num_class - netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels, + netwithloss = BertNER(bert_net_cfg, optimizer_cfg.batch_size, True, num_labels=number_labels, use_crf=(args_opt.use_crf.lower() == "true"), tag_to_index=tag_to_index, dropout_prob=0.1) if args_opt.do_train.lower() == "true": - ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + ds = create_ner_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1, assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, schema_file_path=args_opt.schema_file_path, do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 538b73779b9..29532228ed8 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -108,7 +108,7 @@ def run_pretrain(): if args_opt.accumulation_steps > 1: logger.info("accumulation steps: {}".format(args_opt.accumulation_steps)) - logger.info("global batch size: {}".format(bert_net_cfg.batch_size * args_opt.accumulation_steps)) + logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps)) if args_opt.enable_data_sink == "true": args_opt.data_sink_steps *= args_opt.accumulation_steps logger.info("data sink steps: {}".format(args_opt.data_sink_steps)) diff --git a/model_zoo/official/nlp/bert/run_squad.py b/model_zoo/official/nlp/bert/run_squad.py index bd45ffcc1fc..78cdb4e7e49 100644 --- a/model_zoo/official/nlp/bert/run_squad.py +++ b/model_zoo/official/nlp/bert/run_squad.py @@ -123,7 +123,7 @@ def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", start = logits[1].asnumpy() end = logits[2].asnumpy() - for i in range(bert_net_cfg.batch_size): + for i in range(optimizer_cfg.batch_size): unique_id = int(ids[i]) start_logits = [float(x) for x in start[i].flat] end_logits = [float(x) for x in end[i].flat] @@ -193,7 +193,7 @@ def run_squad(): netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) if args_opt.do_train.lower() == "true": - ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + ds = create_squad_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1, data_file_path=args_opt.train_data_file_path, schema_file_path=args_opt.schema_file_path, do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) @@ -207,7 +207,7 @@ def run_squad(): ds.get_dataset_size(), epoch_num, "squad") if args_opt.do_eval.lower() == "true": - ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1, + ds = create_squad_dataset(batch_size=optimizer_cfg.batch_size, repeat_count=1, data_file_path=args_opt.eval_data_file_path, schema_file_path=args_opt.schema_file_path, is_training=False, do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) diff --git a/model_zoo/official/nlp/bert/src/bert_for_finetune.py b/model_zoo/official/nlp/bert/src/bert_for_finetune.py index fdfdd6e4068..d31adbb02b6 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_finetune.py +++ b/model_zoo/official/nlp/bert/src/bert_for_finetune.py @@ -274,15 +274,15 @@ class BertNER(nn.Cell): """ Train interface for sequence labeling finetuning task. """ - def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, - use_one_hot_embeddings=False): + def __init__(self, config, batch_size, is_training, num_labels=11, use_crf=False, + tag_to_index=None, dropout_prob=0.0, use_one_hot_embeddings=False): super(BertNER, self).__init__() self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) if use_crf: if not tag_to_index: raise Exception("The dict for tag-index mapping should be provided for CRF.") from src.CRF import CRF - self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) + self.loss = CRF(tag_to_index, batch_size, config.seq_length, is_training) else: self.loss = CrossEntropyCalculation(is_training) self.num_labels = num_labels diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 64f6c96a721..65070cc8c1b 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -92,9 +92,8 @@ class GetMaskedLMOutput(nn.Cell): self.matmul = P.MatMul(transpose_b=True) self.log_softmax = nn.LogSoftmax(axis=-1) self.shape_flat_offsets = (-1, 1) - self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32)) self.last_idx = (-1,) - self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) + self.shape_flat_sequence_tensor = (-1, self.width) self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32)) self.cast = P.Cast() self.compute_type = config.compute_type @@ -105,8 +104,8 @@ class GetMaskedLMOutput(nn.Cell): output_weights, positions): """Get output log_probs""" - flat_offsets = self.reshape( - self.rng * self.seq_length_tensor, self.shape_flat_offsets) + rng = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0])) + flat_offsets = self.reshape(rng * self.seq_length_tensor, self.shape_flat_offsets) flat_position = self.reshape(positions + flat_offsets, self.last_idx) flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) diff --git a/model_zoo/official/nlp/bert/src/bert_model.py b/model_zoo/official/nlp/bert/src/bert_model.py index c305ddc1a2c..ea973030a75 100644 --- a/model_zoo/official/nlp/bert/src/bert_model.py +++ b/model_zoo/official/nlp/bert/src/bert_model.py @@ -32,7 +32,6 @@ class BertConfig: Configuration for `BertModel`. Args: - batch_size (int): Batch size of input dataset. seq_length (int): Length of input sequence. Default: 128. vocab_size (int): The shape of each embedding vector. Default: 32000. hidden_size (int): Size of the bert encoder layers. Default: 768. @@ -52,15 +51,10 @@ class BertConfig: type_vocab_size (int): Size of token type vocab. Default: 16. initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from - dataset. Default: True. - token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded - from dataset. Default: True. dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. """ def __init__(self, - batch_size, seq_length=128, vocab_size=32000, hidden_size=768, @@ -74,11 +68,8 @@ class BertConfig: type_vocab_size=16, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float32): - self.batch_size = batch_size self.seq_length = seq_length self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -91,8 +82,6 @@ class BertConfig: self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.input_mask_from_dataset = input_mask_from_dataset - self.token_type_ids_from_dataset = token_type_ids_from_dataset self.use_relative_positions = use_relative_positions self.dtype = dtype self.compute_type = compute_type @@ -385,7 +374,6 @@ class BertAttention(nn.Cell): Apply multi-headed attention from "from_tensor" to "to_tensor". Args: - batch_size (int): Batch size of input datasets. from_tensor_width (int): Size of last dim of from_tensor. to_tensor_width (int): Size of last dim of to_tensor. from_seq_length (int): Length of from_tensor sequence. @@ -406,7 +394,6 @@ class BertAttention(nn.Cell): compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. """ def __init__(self, - batch_size, from_tensor_width, to_tensor_width, from_seq_length, @@ -425,7 +412,6 @@ class BertAttention(nn.Cell): compute_type=mstype.float32): super(BertAttention, self).__init__() - self.batch_size = batch_size self.from_seq_length = from_seq_length self.to_seq_length = to_seq_length self.num_attention_heads = num_attention_heads @@ -452,9 +438,8 @@ class BertAttention(nn.Cell): activation=value_act, weight_init=weight).to_float(compute_type) - self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) - self.shape_to = ( - batch_size, to_seq_length, num_attention_heads, size_per_head) + self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head) self.matmul_trans_b = P.BatchMatMul(transpose_b=True) self.multiply = P.Mul() @@ -463,7 +448,6 @@ class BertAttention(nn.Cell): self.trans_shape_relative = (2, 0, 1, 3) self.trans_shape_position = (1, 2, 0, 3) self.multiply_data = -10000.0 - self.batch_num = batch_size * num_attention_heads self.matmul = P.BatchMatMul() self.softmax = nn.Softmax() @@ -476,9 +460,9 @@ class BertAttention(nn.Cell): self.cast = P.Cast() self.get_dtype = P.DType() if do_return_2d_tensor: - self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + self.shape_return = (-1, num_attention_heads * size_per_head) else: - self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head) self.cast_compute_type = SaturateCast(dst_type=compute_type) if self.use_relative_positions: @@ -514,7 +498,7 @@ class BertAttention(nn.Cell): # query_layer_r is [F, B * N, H] query_layer_r = self.reshape(query_layer_t, (self.from_seq_length, - self.batch_num, + -1, self.size_per_head)) # key_position_scores is [F, B * N, F|T] key_position_scores = self.matmul_trans_b(query_layer_r, @@ -522,7 +506,7 @@ class BertAttention(nn.Cell): # key_position_scores_r is [F, B, N, F|T] key_position_scores_r = self.reshape(key_position_scores, (self.from_seq_length, - self.batch_size, + -1, self.num_attention_heads, self.from_seq_length)) # key_position_scores_r_t is [B, N, F, F|T] @@ -585,7 +569,6 @@ class BertSelfAttention(nn.Cell): Apply self-attention. Args: - batch_size (int): Batch size of input dataset. seq_length (int): Length of input sequence. hidden_size (int): Size of the bert encoder layers. num_attention_heads (int): Number of attention heads. Default: 12. @@ -598,7 +581,6 @@ class BertSelfAttention(nn.Cell): compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. """ def __init__(self, - batch_size, seq_length, hidden_size, num_attention_heads=12, @@ -616,7 +598,6 @@ class BertSelfAttention(nn.Cell): self.size_per_head = int(hidden_size / num_attention_heads) self.attention = BertAttention( - batch_size=batch_size, from_tensor_width=hidden_size, to_tensor_width=hidden_size, from_seq_length=seq_length, @@ -651,7 +632,6 @@ class BertEncoderCell(nn.Cell): Encoder cells used in BertTransformer. Args: - batch_size (int): Batch size of input dataset. hidden_size (int): Size of the bert encoder layers. Default: 768. seq_length (int): Length of input sequence. Default: 512. num_attention_heads (int): Number of attention heads. Default: 12. @@ -666,7 +646,6 @@ class BertEncoderCell(nn.Cell): compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. """ def __init__(self, - batch_size, hidden_size=768, seq_length=512, num_attention_heads=12, @@ -680,7 +659,6 @@ class BertEncoderCell(nn.Cell): compute_type=mstype.float32): super(BertEncoderCell, self).__init__() self.attention = BertSelfAttention( - batch_size=batch_size, hidden_size=hidden_size, seq_length=seq_length, num_attention_heads=num_attention_heads, @@ -715,7 +693,6 @@ class BertTransformer(nn.Cell): Multi-layer bert transformer. Args: - batch_size (int): Batch size of input dataset. hidden_size (int): Size of the encoder layers. seq_length (int): Length of input sequence. num_hidden_layers (int): Number of hidden layers in encoder cells. @@ -732,7 +709,6 @@ class BertTransformer(nn.Cell): return_all_encoders (bool): Specifies whether to return all encoders. Default: False. """ def __init__(self, - batch_size, hidden_size, seq_length, num_hidden_layers, @@ -751,8 +727,7 @@ class BertTransformer(nn.Cell): layers = [] for _ in range(num_hidden_layers): - layer = BertEncoderCell(batch_size=batch_size, - hidden_size=hidden_size, + layer = BertEncoderCell(hidden_size=hidden_size, seq_length=seq_length, num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, @@ -769,7 +744,7 @@ class BertTransformer(nn.Cell): self.reshape = P.Reshape() self.shape = (-1, hidden_size) - self.out_shape = (batch_size, seq_length, hidden_size) + self.out_shape = (-1, seq_length, hidden_size) def construct(self, input_tensor, attention_mask): """Multi-layer bert transformer.""" @@ -799,24 +774,12 @@ class CreateAttentionMaskFromInputMask(nn.Cell): """ def __init__(self, config): super(CreateAttentionMaskFromInputMask, self).__init__() - self.input_mask_from_dataset = config.input_mask_from_dataset self.input_mask = None - - if not self.input_mask_from_dataset: - self.input_mask = initializer( - "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() - self.cast = P.Cast() self.reshape = P.Reshape() - self.shape = (config.batch_size, 1, config.seq_length) - self.broadcast_ones = initializer( - "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() - self.batch_matmul = P.BatchMatMul() + self.shape = (-1, 1, config.seq_length) def construct(self, input_mask): - if not self.input_mask_from_dataset: - input_mask = self.input_mask - attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) return attention_mask @@ -840,9 +803,6 @@ class BertModel(nn.Cell): config.hidden_dropout_prob = 0.0 config.attention_probs_dropout_prob = 0.0 - self.input_mask_from_dataset = config.input_mask_from_dataset - self.token_type_ids_from_dataset = config.token_type_ids_from_dataset - self.batch_size = config.batch_size self.seq_length = config.seq_length self.hidden_size = config.hidden_size self.num_hidden_layers = config.num_hidden_layers @@ -850,12 +810,7 @@ class BertModel(nn.Cell): self.token_type_ids = None self.last_idx = self.num_hidden_layers - 1 - output_embedding_shape = [self.batch_size, self.seq_length, - self.embedding_size] - - if not self.token_type_ids_from_dataset: - self.token_type_ids = initializer( - "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + output_embedding_shape = [-1, self.seq_length, self.embedding_size] self.bert_embedding_lookup = EmbeddingLookup( vocab_size=config.vocab_size, @@ -876,7 +831,6 @@ class BertModel(nn.Cell): dropout_prob=config.hidden_dropout_prob) self.bert_encoder = BertTransformer( - batch_size=self.batch_size, hidden_size=self.hidden_size, seq_length=self.seq_length, num_attention_heads=config.num_attention_heads, @@ -905,8 +859,6 @@ class BertModel(nn.Cell): def construct(self, input_ids, token_type_ids, input_mask): """Bidirectional Encoder Representations from Transformers.""" # embedding - if not self.token_type_ids_from_dataset: - token_type_ids = self.token_type_ids word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings) @@ -921,9 +873,10 @@ class BertModel(nn.Cell): sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) # pooler + batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, (0, 0, 0), - (self.batch_size, 1, self.hidden_size), + (batch_size, 1, self.hidden_size), (1, 1, 1)) first_token = self.squeeze_1(sequence_slice) pooled_output = self.dense(first_token) diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py index 32c8bb86a80..7faca2e3bd5 100644 --- a/model_zoo/official/nlp/bert/src/config.py +++ b/model_zoo/official/nlp/bert/src/config.py @@ -19,6 +19,7 @@ from easydict import EasyDict as edict import mindspore.common.dtype as mstype from .bert_model import BertConfig cfg = edict({ + 'batch_size': 32, 'bert_network': 'base', 'loss_scale_value': 65536, 'scale_factor': 2, @@ -57,7 +58,6 @@ large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which ''' if cfg.bert_network == 'base': bert_net_cfg = BertConfig( - batch_size=64, seq_length=128, vocab_size=21128, hidden_size=768, @@ -71,14 +71,11 @@ if cfg.bert_network == 'base': type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16 ) if cfg.bert_network == 'nezha': bert_net_cfg = BertConfig( - batch_size=96, seq_length=128, vocab_size=21128, hidden_size=1024, @@ -92,14 +89,11 @@ if cfg.bert_network == 'nezha': type_vocab_size=2, initializer_range=0.02, use_relative_positions=True, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16 ) if cfg.bert_network == 'large': bert_net_cfg = BertConfig( - batch_size=24, seq_length=512, vocab_size=30522, hidden_size=1024, @@ -113,8 +107,6 @@ if cfg.bert_network == 'large': type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16 ) diff --git a/model_zoo/official/nlp/bert/src/dataset.py b/model_zoo/official/nlp/bert/src/dataset.py index 868dab5bca7..3c1a8c8b90e 100644 --- a/model_zoo/official/nlp/bert/src/dataset.py +++ b/model_zoo/official/nlp/bert/src/dataset.py @@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C from mindspore import log as logger -from .config import bert_net_cfg +from .config import cfg def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None): @@ -46,7 +46,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, ds = ds.map(operations=type_cast_op, input_columns="input_mask") ds = ds.map(operations=type_cast_op, input_columns="input_ids") # apply batch operations - ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) + ds = ds.batch(cfg.batch_size, drop_remainder=True) logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("repeat count: {}".format(ds.get_repeat_count())) return ds diff --git a/model_zoo/official/nlp/bert/src/finetune_eval_config.py b/model_zoo/official/nlp/bert/src/finetune_eval_config.py index 4a9f05a3fc0..7a387d5aba3 100644 --- a/model_zoo/official/nlp/bert/src/finetune_eval_config.py +++ b/model_zoo/official/nlp/bert/src/finetune_eval_config.py @@ -22,6 +22,7 @@ import mindspore.common.dtype as mstype from .bert_model import BertConfig optimizer_cfg = edict({ + 'batch_size': 16, 'optimizer': 'Lamb', 'AdamWeightDecay': edict({ 'learning_rate': 2e-5, @@ -45,7 +46,6 @@ optimizer_cfg = edict({ }) bert_net_cfg = BertConfig( - batch_size=16, seq_length=128, vocab_size=21128, hidden_size=768, @@ -59,8 +59,6 @@ bert_net_cfg = BertConfig( type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16, ) diff --git a/model_zoo/official/nlp/bert/src/finetune_eval_model.py b/model_zoo/official/nlp/bert/src/finetune_eval_model.py index 5cfb2020820..7ee101438c8 100644 --- a/model_zoo/official/nlp/bert/src/finetune_eval_model.py +++ b/model_zoo/official/nlp/bert/src/finetune_eval_model.py @@ -107,7 +107,7 @@ class BertNERModel(nn.Cell): self.reshape = P.Reshape() self.shape = (-1, config.hidden_size) self.use_crf = use_crf - self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) + self.origin_shape = (-1, config.seq_length, self.num_labels) def construct(self, input_ids, input_mask, token_type_id): """Return the final logits as the results of log_softmax.""" diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py index 13bb8cef99d..adc88dd7e91 100644 --- a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py @@ -41,11 +41,10 @@ DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" -def get_config(version='base', batch_size=1): +def get_config(version='base'): """get config""" if version == 'base': bert_config = BertConfig( - batch_size=batch_size, seq_length=128, vocab_size=21136, hidden_size=768, @@ -59,13 +58,10 @@ def get_config(version='base', batch_size=1): type_vocab_size=2, initializer_range=0.02, use_relative_positions=True, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float32) elif version == 'large': bert_config = BertConfig( - batch_size=batch_size, seq_length=128, vocab_size=21136, hidden_size=1024, @@ -79,12 +75,10 @@ def get_config(version='base', batch_size=1): type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16) else: - bert_config = BertConfig(batch_size=batch_size) + bert_config = BertConfig() return bert_config @@ -186,8 +180,7 @@ def test_bert_performance(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True) version = os.getenv('VERSION', 'large') - batch_size = 16 - config = get_config(version=version, batch_size=batch_size) + config = get_config(version=version) netwithloss = BertNetworkWithLoss(config, True) lr = BertLearningRate(decay_steps=sink_size * new_repeat_count, diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py index a8bada0257b..5d4c4a2297e 100644 --- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py @@ -41,11 +41,10 @@ DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" -def get_config(version='base', batch_size=1): +def get_config(version='base'): """get config""" if version == 'base': bert_config = BertConfig( - batch_size=batch_size, seq_length=128, vocab_size=21136, hidden_size=768, @@ -59,13 +58,10 @@ def get_config(version='base', batch_size=1): type_vocab_size=2, initializer_range=0.02, use_relative_positions=True, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float32) elif version == 'large': bert_config = BertConfig( - batch_size=batch_size, seq_length=128, vocab_size=21136, hidden_size=1024, @@ -79,12 +75,10 @@ def get_config(version='base', batch_size=1): type_vocab_size=2, initializer_range=0.02, use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, dtype=mstype.float32, compute_type=mstype.float16) else: - bert_config = BertConfig(batch_size=batch_size) + bert_config = BertConfig() return bert_config @@ -185,8 +179,7 @@ def test_bert_percision(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) ds, new_repeat_count, _ = me_de_train_dataset() version = os.getenv('VERSION', 'large') - batch_size = 16 - config = get_config(version=version, batch_size=batch_size) + config = get_config(version=version) netwithloss = BertNetworkWithLoss(config, True) lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count, learning_rate=5e-5, end_learning_rate=1e-9,