From 6c03542eec0bf591b48a9993e530e0ce1e67fc67 Mon Sep 17 00:00:00 2001
From: seatea
Date: Mon, 30 Mar 2020 11:46:53 +0800
Subject: [PATCH] Fix dtype bug for loss_scale and weight_decay.

1. Change dtype of scale to dtype of grad in loss_scale.py;
2. Change dtype of weight_decay to dtype of weight in optimizer.py.
---
 mindspore/nn/optim/optimizer.py | 2 +-
 mindspore/nn/wrap/loss_scale.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index e2b0cddb711..cd0ed93a101 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -84,7 +84,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((gradient, weight * F.scalar_to_array(weight_decay)))
+        return op_add((gradient, weight * weight_decay))
     return gradient
 
 
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index f7c686f5350..a11c753eda4 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -32,7 +32,7 @@ reciprocal = P.Reciprocal()
 
 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
-    return grad * reciprocal(scale)
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
 
 class DynamicLossScaleUpdateCell(Cell):
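
For illustration, a minimal standalone sketch of the two fixed helpers with the new behavior, exercised eagerly (PyNative-style). The op_add = P.AddN() definition and the eager execution are assumptions made for this sketch only; they are not shown in the patch, and the sketch is not tied to a specific MindSpore version.

import numpy as np
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P

reciprocal = P.Reciprocal()
op_add = P.AddN()  # assumed stand-in for the op_add used in optimizer.py


def tensor_grad_scale(scale, grad):
    # Cast the reciprocal of the loss scale to the gradient's dtype so an
    # fp32 scale does not mix dtypes with an fp16 gradient.
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


def tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    # Multiply the weight by the Python scalar directly; the product keeps
    # the weight's dtype instead of the dtype produced by scalar_to_array.
    if if_apply:
        return op_add((gradient, weight * weight_decay))
    return gradient


grad = Tensor(np.ones((2, 2), np.float16))       # fp16 gradient under mixed precision
scale = Tensor(np.array(1024.0, np.float32))     # fp32 loss scale
print(F.dtype(tensor_grad_scale(scale, grad)))   # stays Float16

weight = Tensor(np.ones((2, 2), np.float32))     # fp32 parameter
full_grad = Tensor(np.ones((2, 2), np.float32))  # gradient matching the weight's dtype
print(F.dtype(tensor_apply_decay(1e-4, True, weight, full_grad)))  # stays Float32

Casting the reciprocal of the scale (a scalar tensor) rather than the gradient keeps the gradient in fp16 and avoids a full-tensor cast, and dropping scalar_to_array lets weight * weight_decay follow the weight's dtype, which is what the commit message describes.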