forked from mindspore-Ecosystem/mindspore
fix performance of bert
This commit is contained in:
parent
63479f8e7c
commit
04bc2a938e
|
@ -686,7 +686,7 @@ bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &va
|
|||
MS_EXCEPTION_IF_NULL(equiv1_node);
|
||||
auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
|
||||
MS_EXCEPTION_IF_NULL(equiv2_node);
|
||||
return equiv1_node == equiv2_node;
|
||||
return *equiv1_node == *equiv2_node;
|
||||
}
|
||||
|
||||
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
|
||||
|
|
|
@ -180,7 +180,7 @@ class Lamb(Optimizer):
|
|||
beta2=0.999,
|
||||
eps=1e-6,
|
||||
weight_decay=0.0,
|
||||
decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
|
||||
|
||||
super(Lamb, self).__init__(start_learning_rate, params)
|
||||
if self.is_group:
|
||||
|
|
|
@ -191,8 +191,8 @@ def get_bprop_mul(self):
|
|||
mul_func = P.Mul()
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
bc_dx = mul_func(dout, y)
|
||||
bc_dy = mul_func(dout, x)
|
||||
bc_dx = mul_func(y, dout)
|
||||
bc_dy = mul_func(x, dout)
|
||||
return binop_grad_common(x, y, bc_dx, bc_dy)
|
||||
return bprop
|
||||
|
||||
|
|
Loading…
Reference in New Issue