forked from mindspore-Ecosystem/mindspore
!1809 modify adam optimizer and script of bert to match the patterns of fusion rule
Merge pull request !1809 from shibeiji/bert
This commit is contained in:
commit
18ecafcf0e
|
@ -67,9 +67,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
|
||||
- beta2, op_square(gradient_fp32))
|
||||
|
||||
update = next_m / (op_sqrt(next_v) + eps)
|
||||
update = next_m / (eps + op_sqrt(next_v))
|
||||
if decay_flag:
|
||||
update = update + op_mul(weight_decay_tensor, param_fp32)
|
||||
update = op_mul(weight_decay_tensor, param_fp32) + update
|
||||
|
||||
update_with_lr = op_mul(lr, update)
|
||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
||||
|
|
|
@ -26,6 +26,8 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
|
|||
.attr("data_format", "required", "str", "all") \
|
||||
.input(0, "output_backprop", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -261,7 +261,7 @@ class BertOutput(nn.Cell):
|
|||
def construct(self, hidden_status, input_tensor):
|
||||
output = self.dense(hidden_status)
|
||||
output = self.dropout(output)
|
||||
output = self.add(output, input_tensor)
|
||||
output = self.add(input_tensor, output)
|
||||
output = self.layernorm(output)
|
||||
return output
|
||||
|
||||
|
@ -832,8 +832,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
|
|||
if not self.input_mask_from_dataset:
|
||||
input_mask = self.input_mask
|
||||
|
||||
input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
|
||||
attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
|
||||
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
|
||||
return attention_mask
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue