diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_model.py b/mindspore/model_zoo/Bert_NEZHA/bert_model.py index f20c57dd751..d7f9355b3c4 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_model.py +++ b/mindspore/model_zoo/Bert_NEZHA/bert_model.py @@ -165,6 +165,7 @@ class EmbeddingPostprocessor(nn.Cell): def __init__(self, embedding_size, embedding_shape, + use_relative_positions=False, use_token_type=False, token_type_vocab_size=16, use_one_hot_embeddings=False, @@ -192,6 +193,13 @@ class EmbeddingPostprocessor(nn.Cell): self.layernorm = nn.LayerNorm(embedding_size) self.dropout = nn.Dropout(1 - dropout_prob) self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.Slice() + self.full_position_embeddings = Parameter(initializer + (TruncatedNormal(initializer_range), + [max_position_embeddings, + embedding_size]), + name='full_position_embeddings') def construct(self, token_type_ids, word_embeddings): output = word_embeddings @@ -206,6 +214,11 @@ class EmbeddingPostprocessor(nn.Cell): token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) token_type_embeddings = self.reshape(token_type_embeddings, self.shape) output += token_type_embeddings + if not self.use_relative_positions: + _, seq, width = self.shape + position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width]) + position_embeddings = self.reshape(position_embeddings, (1, seq, width)) + output += position_embeddings output = self.layernorm(output) output = self.dropout(output) return output @@ -853,6 +866,7 @@ class BertModel(nn.Cell): self.bert_embedding_postprocessor = EmbeddingPostprocessor( embedding_size=self.embedding_size, embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, use_token_type=True, token_type_vocab_size=config.type_vocab_size, use_one_hot_embeddings=use_one_hot_embeddings, diff --git a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py index 9cc11997e68..5b6268505b6 100644 --- a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py @@ -103,9 +103,9 @@ def me_de_train_dataset(): """test me de train dataset""" # apply repeat operations repeat_count = 1 - ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", + ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", - "masked_lm_ids", "masked_lm_weights"]) + "masked_lm_ids", "masked_lm_weights"], shuffle=False) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)