Modify the Adam optimizer and the BERT script to match the fusion rule patterns
This commit is contained in:
parent 5cba231ba9
commit 178952afbc
@@ -67,9 +67,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                             - beta2, op_square(gradient_fp32))

-    update = next_m / (op_sqrt(next_v) + eps)
+    update = next_m / (eps + op_sqrt(next_v))
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
+        update = op_mul(weight_decay_tensor, param_fp32) + update

     update_with_lr = op_mul(lr, update)
     next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
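Both reordered lines are value-preserving (the additions are commutative), but they arrange the operands in the order the graph-kernel fusion rule matches, so the whole Adam step can be fused into a single kernel. For reference, a minimal NumPy sketch of the decayed-Adam update this function computes; the helper name and plain-array style are illustrative, not the MindSpore implementation:

import numpy as np

def adam_step(param, m, v, grad, beta1, beta2, eps, lr, weight_decay, decay_flag):
    # Sketch of the update performed by _update_run_op, in plain NumPy.
    next_m = beta1 * m + (1.0 - beta1) * grad       # first-moment estimate
    next_v = beta2 * v + (1.0 - beta2) * grad ** 2  # second-moment estimate
    update = next_m / (eps + np.sqrt(next_v))       # eps written first, as in the patch
    if decay_flag:
        update = weight_decay * param + update      # decay term written first, as in the patch
    next_param = param - lr * update
    return next_param, next_m, next_v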
@@ -26,6 +26,8 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
     .attr("data_format", "required", "str", "all") \
     .input(0, "output_backprop", False, "required", "all") \
     .output(0, "output", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
     .get_op_info()
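Each .dtype_format(input0, output0) row registers one supported dtype/format combination for the TBE kernel; the two added rows let BiasAddGrad run in float16 (default layout and FRACTAL_NZ input layout), so fp16 fused subgraphs can still dispatch to it. A small sketch for inspecting what those rows mean, assuming MindSpore is installed; the exact contents of the constants are version-dependent:

from mindspore.ops.op_info_register import DataType

# Each DataType constant names a (dtype, tensor-format) combination used by
# the registration rows above; printing them shows what each row permits.
for name in ("F16_Default", "F16_FracNZ", "F32_Default", "F32_FracNZ"):
    print(name, "->", getattr(DataType, name))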
@@ -261,7 +261,7 @@ class BertOutput(nn.Cell):
     def construct(self, hidden_status, input_tensor):
         output = self.dense(hidden_status)
         output = self.dropout(output)
-        output = self.add(output, input_tensor)
+        output = self.add(input_tensor, output)
         output = self.layernorm(output)
         return output
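The add here is the residual connection; swapping its operands does not change the sum, but it puts the residual input first, which is the operand order the dropout-add-layernorm fusion pattern looks for. A rough self-contained sketch of the same block written against the older MindSpore API (the class name and channel sizes are placeholders, not values from the real BertOutput):

import mindspore.nn as nn
import mindspore.ops.operations as P

class ResidualOutput(nn.Cell):
    # Dense -> dropout -> residual add -> layernorm, mirroring BertOutput.construct.
    def __init__(self, in_channels=768, out_channels=768, dropout_prob=0.1):
        super(ResidualOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels)
        self.dropout = nn.Dropout(keep_prob=1.0 - dropout_prob)  # keep_prob in older MindSpore
        self.add = P.TensorAdd()
        self.layernorm = nn.LayerNorm((out_channels,))

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)  # residual input first, as in the patch
        output = self.layernorm(output)
        return output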
@@ -832,8 +832,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
         if not self.input_mask_from_dataset:
             input_mask = self.input_mask

-        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
-        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
+        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
         return attention_mask
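Removing the batch_matmul drops a ones-times-mask product whose only effect was to tile the mask along the query axis; the reshaped mask can instead be broadcast by the elementwise ops that consume it, a shape the fusion rule can handle directly. A NumPy sketch of the equivalence, assuming self.shape reshapes the mask to (batch, 1, seq_len) as the surrounding BERT script does (that shape is not visible in this hunk):

import numpy as np

batch, seq_len = 2, 4
input_mask = np.random.randint(0, 2, size=(batch, seq_len)).astype(np.float32)

# Old path: ones (batch, seq_len, 1) @ mask (batch, 1, seq_len) -> (batch, seq_len, seq_len)
broadcast_ones = np.ones((batch, seq_len, 1), dtype=np.float32)
old_mask = broadcast_ones @ input_mask.reshape(batch, 1, seq_len)

# New path: just reshape; downstream elementwise ops broadcast the (batch, 1, seq_len)
# mask against (batch, seq_len, seq_len) attention scores with the same effect.
new_mask = input_mask.reshape(batch, 1, seq_len)

assert np.allclose(old_mask, np.broadcast_to(new_mask, old_mask.shape))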