bugfix tinybert

This commit is contained in:
yoonlee666 2020-09-25 18:02:26 +08:00
parent a95ed7e8e0
commit a2af684dd4
1 changed files with 4 additions and 2 deletions

View File

@ -844,9 +844,10 @@ class BertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)