forked from OSSInnovation/mindspore
!14617 Fix Gelu in select ops
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
commit fc1c0e0952
@@ -415,7 +415,7 @@ class GELU(Cell):

     def __init__(self):
         super(GELU, self).__init__()
-        self.gelu = _selected_ops.Gelu()
+        self.gelu = _selected_ops.GeLU()

     def construct(self, x):
         return self.gelu(x)
@@ -458,7 +458,7 @@ class FastGelu(Cell):

     def __init__(self):
         super(FastGelu, self).__init__()
-        self.fast_gelu = _selected_ops.FastGelu()
+        self.fast_gelu = _selected_ops.FastGeLU()

     def construct(self, x):
        return self.fast_gelu(x)
@@ -73,13 +73,13 @@ class Tanh:


 @op_selector
-class Gelu:
+class GeLU:
     def __call__(self, *args):
         pass


 @op_selector
-class FastGelu:
+class FastGeLU:
     def __call__(self, *args):
         pass
@@ -499,7 +499,7 @@ class FeedForward(nn.Cell):

         self.layernorm = LayerNorm(in_channels=in_channels)
         self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout)
-        self.gelu_act = P.Gelu()
+        self.gelu_act = P.GeLU()
         self.dropout = nn.Dropout(1 - hidden_dropout)
         self.use_dropout = hidden_dropout > 0
         self.reshape = P.Reshape()
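The rename only touches the op-selector and primitive class names; the public nn Cell keeps its existing interface. A minimal usage sketch, assuming a MindSpore build that includes this change (the input values are purely illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# nn.GELU now resolves to _selected_ops.GeLU() internally (see the first hunk above)
x = Tensor(np.array([0.5, -1.0, 2.0], dtype=np.float32))
gelu = nn.GELU()
print(gelu(x))  # elementwise Gaussian Error Linear Unit activation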