Fix bugs in the adam and lamb optimizers
parent 5a63dac00f
commit 759748dc06

@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
    Returns:
        Tensor, the new value of v after updating.
    """
    success = True
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
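
The body that follows this setup computes a standard AdamWeightDecay step. As a point of reference only, here is a minimal NumPy sketch of that arithmetic; the function name `adam_weight_decay_step` is invented for this sketch, and the real `_update_run_op` typically guards the decay term with a decay flag and runs the math through MindSpore primitives such as `P.Mul` and `P.Square`:

```python
import numpy as np

def adam_weight_decay_step(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient):
    """Hypothetical NumPy sketch of one AdamWeightDecay update (reference only)."""
    next_m = beta1 * m + (1.0 - beta1) * gradient              # first moment: EMA of the gradient
    next_v = beta2 * v + (1.0 - beta2) * np.square(gradient)   # second moment: EMA of the squared gradient
    update = next_m / (np.sqrt(next_v) + eps)                  # Adam direction
    update = update + weight_decay * param                     # decoupled weight decay
    next_param = param - lr * update
    return next_param, next_m, next_v
```

Comparing the graph-mode result against this plain version on a few small arrays is an easy sanity check for the fix.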

@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
        success = F.depend(success, next_param)
        return success
        return op_cast(next_param, F.dtype(param))
    return gradient


@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
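
The `F.depend`/`F.assign` chain in this hunk is what commits the new parameter and moment values to the optimizer state before anything is returned. Below is a minimal sketch of that idiom, assuming the usual MindSpore aliases (`F` for `mindspore.ops.functional`, `P` for `mindspore.ops.operations`); the helper name `_apply_state_updates` and its argument list are invented here and are not part of the diff:

```python
from mindspore.ops import functional as F
from mindspore.ops import operations as P

op_cast = P.Cast()

def _apply_state_updates(param, m, v, gradient, next_param, next_m, next_v, optim_filter):
    """Hypothetical helper illustrating the depend/assign idiom in the hunk above."""
    if optim_filter:
        # F.assign writes each Parameter in place; threading the result through
        # F.depend ties that side effect to the returned value, so graph
        # compilation cannot reorder or prune the writes.  Casting back to the
        # Parameter's own dtype avoids a mismatch when the update was computed
        # in float32 for a float16 parameter.
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
        return op_cast(next_param, F.dtype(param))
    # Parameters excluded by optim_filter just pass their gradient through.
    return gradient
```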

@@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        next_param = F.depend(next_param, F.assign(param, next_param))
        next_param = F.depend(next_param, F.assign(m, next_m))
        next_param = F.depend(next_param, F.assign(v, next_v))
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return next_param
        return op_cast(next_param, F.dtype(param))
    return gradient
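
This second `_update_run_op` belongs to the Lamb optimizer, where `update_with_lr` is the Adam-style direction scaled by a layer-wise trust ratio before being subtracted from the fp32 copy of the parameter. A rough NumPy sketch of that update, with the trust-ratio guard and variable names assumed rather than taken from the MindSpore source:

```python
import numpy as np

def lamb_step(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient):
    """Hypothetical NumPy sketch of one Lamb update (reference only)."""
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * np.square(gradient)
    # Bias correction using the current step number.
    m_hat = next_m / (1.0 - beta1 ** global_step)
    v_hat = next_v / (1.0 - beta2 ** global_step)
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Layer-wise trust ratio: ratio of the weight norm to the update norm,
    # falling back to 1.0 when either norm is zero.
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    update_with_lr = lr * trust_ratio * update
    next_param = param - np.reshape(update_with_lr, param.shape)
    return next_param, next_m, next_v
```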