!24454 Fix FeedForward Reshape Error

Merge pull request !24454 from huangxinjing/fix_feedforward_reshape_master
This commit is contained in:
i-robot 2021-09-30 07:37:48 +00:00 committed by Gitee
commit f4668633eb
2 changed files with 3 additions and 4 deletions

View File

@ -381,9 +381,9 @@ class _Linear(Cell):
x = self.matmul(x, weight)
if self.has_bias:
x = self.bias_add(x, self.cast(self.bias, self.dtype))
output = P.Reshape()(x, out_shape)
if self.activation_flag:
output = self.activation(output)
x = self.activation(x)
output = P.Reshape()(x, out_shape)
return output
def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):

View File

@ -350,7 +350,7 @@ class FeedForward(Cell):
if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, mp),))
strategy_activation=((ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
@ -787,7 +787,6 @@ class MultiHeadAttention(Cell):
value_past=None, batch_valid_length=None):
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
value_past, batch_valid_length)
batch_size = F.shape(attention_mask)[0]
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,