forked from mindspore-Ecosystem/mindspore
!24454 Fix FeedForward Reshape Error
Merge pull request !24454 from huangxinjing/fix_feedforward_reshape_master
commit f4668633eb
@@ -381,9 +381,9 @@ class _Linear(Cell):
         x = self.matmul(x, weight)
         if self.has_bias:
             x = self.bias_add(x, self.cast(self.bias, self.dtype))
-        output = P.Reshape()(x, out_shape)
         if self.activation_flag:
-            output = self.activation(output)
+            x = self.activation(x)
+        output = P.Reshape()(x, out_shape)
         return output

     def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
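Note on the _Linear change: the activation is now applied to x before it is reshaped back to out_shape, so the strategy passed to shard() for the activation describes the tensor as it actually exists at that point. Because the activation is elementwise, swapping it with the reshape does not change the computed values. A minimal plain-NumPy sketch of that equivalence (the shapes and the ReLU stand-in are made up for illustration; this is not the MindSpore code):

# Shows that an elementwise activation commutes with reshape, so moving it
# before the final P.Reshape is numerically equivalent; the reorder only
# matters for how the parallel shard strategy attaches to the activation op.
import numpy as np

x = np.random.randn(6, 4)             # 2D matmul output: (batch*seq, hidden)
out_shape = (2, 3, 4)                  # original 3D shape to restore

relu = lambda t: np.maximum(t, 0)      # stand-in elementwise activation

before_fix = relu(x.reshape(out_shape))   # old order: reshape, then activate
after_fix = relu(x).reshape(out_shape)    # new order: activate, then reshape

assert np.array_equal(before_fix, after_fix)
print("elementwise activation commutes with reshape:", after_fix.shape)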
@@ -350,7 +350,7 @@ class FeedForward(Cell):
         if expert_num > 1:
             self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
                                strategy_bias=((ep, 1, mp), (mp,)),
-                               strategy_activation=((ep, mp),))
+                               strategy_activation=((ep, 1, mp),))
         else:
             self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
                                strategy_bias=((dp, mp), (mp,)),
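Note on the FeedForward change: with expert_num > 1 the mapping layer works on a 3D tensor (the matmul strategy ((ep, 1, 1), (ep, 1, mp)) already has three entries per input), and since the activation now runs before the reshape it presumably sees that 3D tensor as well, so its strategy grows from ((ep, mp),) to ((ep, 1, mp),). A rough sanity-check helper illustrating the rule a strategy has to follow (hypothetical, not a MindSpore API; ep, mp and the shape are illustrative values):

# A shard strategy needs one slice count per tensor dimension, and each
# dimension must divide evenly by its slice count.
def strategy_matches(shape, strategy):
    return len(shape) == len(strategy) and all(
        dim % cuts == 0 for dim, cuts in zip(shape, strategy))

ep, mp = 2, 4
expert_tensor = (2, 16, 64)                       # (expert_num, tokens, ffn_hidden)

print(strategy_matches(expert_tensor, (ep, mp)))      # False: rank mismatch
print(strategy_matches(expert_tensor, (ep, 1, mp)))   # True: one entry per dim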
@@ -787,7 +787,6 @@ class MultiHeadAttention(Cell):
                   value_past=None, batch_valid_length=None):
         self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
                            value_past, batch_valid_length)
-        batch_size = F.shape(attention_mask)[0]
         query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
                                                                                                     key_tensor,
                                                                                                     value_tensor,
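Note on the MultiHeadAttention change: the separate batch_size = F.shape(attention_mask)[0] is dropped because the very next call to self._convert_to_2d_tensor already returns batch_size (together with ori_shape), so the extra lookup was redundant and immediately overwritten. A hypothetical plain-NumPy stand-in for such a flatten-to-2D helper, only to show why it reports batch_size and ori_shape (this is not the actual MindSpore implementation):

# Flattens (batch, seq, hidden) inputs to 2D and reports the batch size and
# original shape it flattened away, so callers do not need to recompute them.
import numpy as np

def convert_to_2d(query, key, value):
    ori_shape = query.shape                       # (batch, seq, hidden)
    batch_size = ori_shape[0]
    flat = lambda t: t.reshape(-1, t.shape[-1])   # merge batch and seq dims
    return flat(query), flat(key), flat(value), batch_size, ori_shape

q = k = v = np.zeros((2, 8, 16))
q2d, k2d, v2d, batch_size, ori_shape = convert_to_2d(q, k, v)
print(q2d.shape, batch_size, ori_shape)           # (16, 16) 2 (2, 8, 16)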