forked from mindspore-Ecosystem/mindspore
!149 Use TFRecordDataset instead of StorageDataset in Bert model integration test and add absolute position embedding code in bert model
Merge pull request !149 from yoonlee666/master
commit 3f5136302a
@@ -165,6 +165,7 @@ class EmbeddingPostprocessor(nn.Cell):
     def __init__(self,
                  embedding_size,
                  embedding_shape,
+                 use_relative_positions=False,
                  use_token_type=False,
                  token_type_vocab_size=16,
                  use_one_hot_embeddings=False,
@@ -192,6 +193,13 @@ class EmbeddingPostprocessor(nn.Cell):
         self.layernorm = nn.LayerNorm(embedding_size)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.gather = P.GatherV2()
+        self.use_relative_positions = use_relative_positions
+        self.slice = P.Slice()
+        self.full_position_embeddings = Parameter(initializer
+                                                  (TruncatedNormal(initializer_range),
+                                                   [max_position_embeddings,
+                                                    embedding_size]),
+                                                  name='full_position_embeddings')

     def construct(self, token_type_ids, word_embeddings):
         output = word_embeddings
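For reference, the new full_position_embeddings parameter is a single learned table with one row per position up to the longest supported sequence. A minimal stand-alone sketch of the equivalent construction follows; the sizes and the initializer_range value are illustrative assumptions, not taken from this diff:

from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, TruncatedNormal

max_position_embeddings = 512   # assumed, BERT-base style
embedding_size = 768            # assumed hidden size
initializer_range = 0.02        # assumed truncated-normal stddev

# One weight matrix holding a position vector for every position up to
# max_position_embeddings; it is sliced down to the sequence length at run time.
full_position_embeddings = Parameter(
    initializer(TruncatedNormal(initializer_range),
                [max_position_embeddings, embedding_size]),
    name='full_position_embeddings')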
@@ -206,6 +214,11 @@ class EmbeddingPostprocessor(nn.Cell):
             token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
             token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
             output += token_type_embeddings
+        if not self.use_relative_positions:
+            _, seq, width = self.shape
+            position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
+            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
+            output += position_embeddings
         output = self.layernorm(output)
         output = self.dropout(output)
         return output
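What the new branch in construct computes, as a hedged NumPy sketch of the same shapes (the batch/seq/width values are illustrative only): slice the table down to the current sequence length, prepend a broadcast axis, and add it to the token embeddings.

import numpy as np

batch, seq, width = 2, 128, 768               # illustrative sizes
max_position_embeddings = 512

full_position_embeddings = np.random.randn(max_position_embeddings, width)
output = np.random.randn(batch, seq, width)   # word + token-type embeddings

# Mirrors self.slice(..., [0, 0], [seq, width]) followed by
# self.reshape(..., (1, seq, width)) and the in-place add:
position_embeddings = full_position_embeddings[:seq, :].reshape(1, seq, width)
output = output + position_embeddings         # broadcasts over the batch axis

Because the same position vectors are added to every example in the batch, the reshape to (1, seq, width) is all that is needed for the addition to broadcast.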
@@ -853,6 +866,7 @@ class BertModel(nn.Cell):
         self.bert_embedding_postprocessor = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
             embedding_shape=output_embedding_shape,
+            use_relative_positions=config.use_relative_positions,
             use_token_type=True,
             token_type_vocab_size=config.type_vocab_size,
             use_one_hot_embeddings=use_one_hot_embeddings,
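A hedged instantiation sketch mirroring the wiring above; every value is a placeholder rather than the model's real config, and parameters not shown in this diff (initializer_range, max_position_embeddings, dropout_prob) are assumed to keep their defaults:

# Hypothetical values; in BertModel the config.* fields supply these.
post = EmbeddingPostprocessor(embedding_size=768,
                              embedding_shape=[-1, 128, 768],
                              use_relative_positions=False,  # absolute positions on
                              use_token_type=True,
                              token_type_vocab_size=2,
                              use_one_hot_embeddings=False)

With use_relative_positions=False the layer adds the absolute position embeddings introduced by this commit; with True it skips them, leaving position information to relative-position logic elsewhere in the model.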
@@ -103,9 +103,9 @@ def me_de_train_dataset():
     """test me de train dataset"""
     # apply repeat operations
     repeat_count = 1
-    ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
+    ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                "next_sentence_labels", "masked_lm_positions",
-                                                               "masked_lm_ids", "masked_lm_weights"])
+                                                               "masked_lm_ids", "masked_lm_weights"], shuffle=False)
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
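For completeness, a hedged end-to-end sketch of the rewritten pipeline with the usual import aliases spelled out; the file paths are placeholders, and the transform module path follows the MindSpore API of that era and may differ in later releases:

import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype

DATA_DIR = ["/path/to/bert_train.tfrecord"]   # placeholder
SCHEMA_DIR = "/path/to/datasetSchema.json"    # placeholder

ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR,
                        columns_list=["input_ids", "input_mask", "segment_ids",
                                      "next_sentence_labels", "masked_lm_positions",
                                      "masked_lm_ids", "masked_lm_weights"],
                        shuffle=False)  # deterministic order for the integration test
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)

# Sanity check: pull one record and inspect a column's shape.
for item in ds.create_dict_iterator():
    print(item["input_ids"].shape)
    break

Passing shuffle=False keeps the record order fixed across runs, which is what lets the integration test compare losses against recorded values.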