forked from mindspore-Ecosystem/mindspore
extract bert embedding tables in construct
This commit is contained in:
parent
4719a66be6
commit
3ec3f038ad
|
@ -805,7 +805,6 @@ class BertModel(nn.Cell):
|
|||
vocab_size=config.vocab_size,
|
||||
embedding_size=self.embedding_size,
|
||||
use_one_hot=use_one_hot_embeddings)
|
||||
self.embedding_tables = self.bert_embedding_lookup.embedding_table
|
||||
|
||||
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
|
@ -847,7 +846,7 @@ class BertModel(nn.Cell):
|
|||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
"""Bidirectional Encoder Representations from Transformers."""
|
||||
# embedding
|
||||
embedding_tables = self.embedding_tables
|
||||
embedding_tables = self.bert_embedding_lookup.embedding_table
|
||||
word_embeddings = self.bert_embedding_lookup(input_ids)
|
||||
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
|
||||
word_embeddings)
|
||||
|
|
Loading…
Reference in New Issue