forked from mindspore-Ecosystem/mindspore
!213 use StridedSlice instead of Slice in absolute position embedding code in bert model
Merge pull request !213 from yoonlee666/master
This commit is contained in:
commit
865b8a79db
|
@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.gather = P.GatherV2()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.Slice()
|
||||
self.slice = P.StridedSlice()
|
||||
self.full_position_embeddings = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[max_position_embeddings,
|
||||
|
@ -216,7 +216,7 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
output += token_type_embeddings
|
||||
if not self.use_relative_positions:
|
||||
_, seq, width = self.shape
|
||||
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
|
||||
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
|
||||
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
|
||||
output += position_embeddings
|
||||
output = self.layernorm(output)
|
||||
|
|
Loading…
Reference in New Issue