forked from mindspore-Ecosystem/mindspore
Fix dtype bug for loss_scale and weight_decay.
1.Change dtype of scale to dtype of grad in loss_scale.py; 2.Change dtype of weight_decay to dtype of weight in optimizer.py.
This commit is contained in:
parent
930a1fb0a8
commit
6c03542eec
|
@ -84,7 +84,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
|
|||
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
return op_add((gradient, weight * F.scalar_to_array(weight_decay)))
|
||||
return op_add((gradient, weight * weight_decay))
|
||||
return gradient
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ reciprocal = P.Reciprocal()
|
|||
|
||||
@_grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
|
||||
class DynamicLossScaleUpdateCell(Cell):
|
||||
|
|
Loading…
Reference in New Issue