!1809 modify the Adam optimizer and the BERT script to match the fusion rule patterns

Merge pull request !1809 from shibeiji/bert
mindspore-ci-bot 2020-06-04 14:44:52 +08:00 committed by Gitee
commit 18ecafcf0e
3 changed files with 6 additions and 5 deletions


@@ -67,9 +67,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                             - beta2, op_square(gradient_fp32))
-    update = next_m / (op_sqrt(next_v) + eps)
+    update = next_m / (eps + op_sqrt(next_v))
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
+        update = op_mul(weight_decay_tensor, param_fp32) + update
     update_with_lr = op_mul(lr, update)
     next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
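Both rewrites above only swap the operands of commutative additions, so the AdamWeightDecay update is numerically unchanged; the new ordering simply matches what the fusion rule expects. A minimal NumPy sketch (toy shapes and hyperparameters, not MindSpore code) checking that the two orderings produce the same step:

```python
import numpy as np

# Toy stand-ins for param_fp32, next_m, next_v.
rng = np.random.default_rng(0)
param = rng.normal(size=(3, 4))
next_m = rng.normal(size=(3, 4))
next_v = rng.random((3, 4))
eps, weight_decay, lr = 1e-6, 0.01, 1e-3

# Old operand order.
update_old = next_m / (np.sqrt(next_v) + eps)
update_old = update_old + weight_decay * param

# New operand order used to match the fusion pattern.
update_new = next_m / (eps + np.sqrt(next_v))
update_new = weight_decay * param + update_new

# The parameter update is identical either way.
assert np.allclose(param - lr * update_old, param - lr * update_new)
```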


@@ -26,6 +26,8 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
     .attr("data_format", "required", "str", "all") \
     .input(0, "output_backprop", False, "required", "all") \
     .output(0, "output", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
     .get_op_info()
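Each .dtype_format(input, output) call adds one supported (input dtype/format, output dtype/format) pair; the two new FracNZ rows let the TBE BiasAddGrad kernel be selected when its input tensor stays in FRACTAL_NZ layout while still producing a DefaultFormat output. A sketch of how such an op-info object is typically attached to a registration stub in MindSpore's TBE op files, assuming the usual op_info_register pattern (the stub function name is illustrative and not part of this diff):

```python
from mindspore.ops.op_info_register import op_info_register

# bias_add_grad_op_info is the object built by the TBERegOp chain shown above.
@op_info_register(bias_add_grad_op_info)
def _bias_add_grad_tbe():
    """BiasAddGrad TBE register"""
    return
```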


@@ -261,7 +261,7 @@ class BertOutput(nn.Cell):
     def construct(self, hidden_status, input_tensor):
         output = self.dense(hidden_status)
         output = self.dropout(output)
-        output = self.add(output, input_tensor)
+        output = self.add(input_tensor, output)
         output = self.layernorm(output)
         return output
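TensorAdd is commutative, so swapping the operands does not change the residual connection; it only puts the residual input first so the Add followed by LayerNorm subgraph matches the operand order the fusion rule looks for. A minimal sketch of a cell with the same dense, dropout, add, layernorm structure (the class name, sizes, and keep probability are illustrative, not the actual BertOutput definition):

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class ResidualOutput(nn.Cell):
    """Illustrative cell showing the add-then-layernorm operand order."""
    def __init__(self, size, keep_prob=0.9):
        super(ResidualOutput, self).__init__()
        self.dense = nn.Dense(size, size)
        self.dropout = nn.Dropout(keep_prob)
        self.add = P.TensorAdd()
        self.layernorm = nn.LayerNorm((size,))

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        # Residual input goes first, matching the fusion pattern's operand order.
        output = self.add(input_tensor, output)
        return self.layernorm(output)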
@@ -832,8 +832,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
         if not self.input_mask_from_dataset:
             input_mask = self.input_mask
-        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
-        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
+        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
         return attention_mask
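The removed batch_matmul expanded the mask by multiplying a column of ones against the row mask; dropping it means the mask is kept at a broadcastable shape and expanded implicitly where it is applied. A NumPy sketch under assumed shapes (batch=2, seq=4; the exact value of self.shape is not shown in this hunk) illustrating that a (batch, 1, seq) mask broadcast against the attention scores gives the same result as the explicitly expanded (batch, seq, seq) mask, assuming the standard BERT-style additive masking downstream:

```python
import numpy as np

batch, seq = 2, 4
# Two toy sequences with valid lengths 3 and 2.
input_mask = (np.arange(seq)[None, :] < np.array([[3], [2]])).astype(np.float32)  # (batch, seq)

# Old style: expand to (batch, seq, seq) with an explicit batched matmul.
broadcast_ones = np.ones((batch, seq, 1), dtype=np.float32)
expanded = broadcast_ones @ input_mask.reshape(batch, 1, seq)   # (batch, seq, seq)

# New style: keep the mask compact and rely on broadcasting.
compact = input_mask.reshape(batch, 1, seq)                      # (batch, 1, seq)

scores = np.random.default_rng(0).normal(size=(batch, seq, seq)).astype(np.float32)
masked_old = scores + -10000.0 * (1.0 - expanded)
masked_new = scores + -10000.0 * (1.0 - compact)
assert np.allclose(masked_old, masked_new)
```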