!6895 bugfix tinybert

Merge pull request !6895 from yoonlee666/master
This commit is contained in:
mindspore-ci-bot 2020-09-27 09:28:06 +08:00 committed by Gitee
commit 05e55ef467
1 changed files with 4 additions and 2 deletions

View File

@ -844,9 +844,10 @@ class BertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)