forked from OSSInnovation/mindspore
bugfix tinybert
This commit is contained in:
parent
a95ed7e8e0
commit
a2af684dd4
|
@ -844,9 +844,10 @@ class BertModel(nn.Cell):
|
|||
attention_mask)
|
||||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(-1, 1, self.hidden_size),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
|
@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell):
|
|||
attention_mask)
|
||||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(-1, 1, self.hidden_size),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
|
|
Loading…
Reference in New Issue