forked from mindspore-Ecosystem/mindspore
!17755 fix the reduction args of the Loss operator, line-too-long and other warning problems.
From: @wangshuide2020 Reviewed-by: @liangchenghui, @wuxuejian Signed-off-by: @liangchenghui
commit aeda8e9fff
@@ -47,11 +47,9 @@ class Loss(Cell):
     """
     def __init__(self, reduction='mean'):
         super(Loss, self).__init__()
-        if reduction is None:
-            reduction = 'none'
 
         if reduction not in ('mean', 'sum', 'none'):
-            raise ValueError(f"reduction method for {reduction.lower()} is not supported")
+            raise ValueError(f"reduction method for {reduction} is not supported")
 
         self.average = True
         self.reduce = True
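
For context on the hunk above: the Loss base class accepts reduction in ('mean', 'sum', 'none') and rejects anything else with the corrected message. Below is a minimal, self-contained sketch of that contract; the class name SketchLoss, the get_loss helper, and the use of NumPy are illustrative assumptions, not the MindSpore implementation, which presumably adjusts self.average and self.reduce from reduction in code outside the hunk.

import numpy as np

class SketchLoss:
    """Illustrative stand-in for a Loss base class with a validated reduction arg."""

    def __init__(self, reduction='mean'):
        if reduction not in ('mean', 'sum', 'none'):
            # Same message as the fixed line: report the value exactly as passed.
            raise ValueError(f"reduction method for {reduction} is not supported")
        # Encode the three modes in the two flags the hunk initializes.
        self.average = reduction == 'mean'
        self.reduce = reduction != 'none'

    def get_loss(self, x):
        """Apply the configured reduction to an element-wise loss array."""
        if not self.reduce:
            return x
        return np.mean(x) if self.average else np.sum(x)

# Usage: SketchLoss('sum').get_loss(np.array([1.0, 2.0])) returns 3.0.
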
@@ -58,7 +58,8 @@ class LARS(Optimizer):
     .. math::
 
         \begin{array}{ll} \\
-            \lambda = \frac{\theta \text{ * } || \omega || }{|| g_{t} || \text{ + } \delta \text{ * } || \omega || } \\
+            \lambda = \frac{\theta \text{ * } || \omega || } \\
+                {|| g_{t} || \text{ + } \delta \text{ * } || \omega || } \\
             \lambda =
             \begin{cases}
                 \min(\frac{\lambda}{\alpha }, 1)
@@ -70,7 +71,7 @@ class LARS(Optimizer):
         \end{array}
 
     :math:`\theta` represents `coefficient`, :math:`\omega` represents `parameters`, :math:`g` represents `gradients`,
-    :math:`t` represents updateing step, :math:`\delta` represents `weight_decay`,
+    :math:`t` represents updating step, :math:`\delta` represents `weight_decay`,
     :math:`\alpha` represents `learning_rate`, :math:`clip` represents `use_clip`.
 
     Args:
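
The two LARS hunks only rewrap the docstring, but the formula they describe is easy to state in code: lambda = theta * ||w|| / (||g_t|| + delta * ||w||), further bounded by min(lambda / alpha, 1) when clipping is enabled. A small NumPy sketch follows; the function name and the "otherwise leave lambda unchanged" branch (which falls outside the visible hunk) are assumptions.

import numpy as np

def lars_trust_ratio(w, g, coefficient, weight_decay, learning_rate, use_clip=False):
    """Layer-wise scaling factor described by the LARS docstring formula (sketch)."""
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(g)
    # lambda = theta * ||w|| / (||g_t|| + delta * ||w||)
    lam = coefficient * w_norm / (g_norm + weight_decay * w_norm)
    if use_clip:
        # lambda = min(lambda / alpha, 1) when use_clip is True.
        lam = min(lam / learning_rate, 1.0)
    return lam
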
@@ -107,7 +107,8 @@ class LazyAdam(Optimizer):
     r"""
     This optimizer will apply a lazy adam algorithm when gradient is sparse.
 
-    The original adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
+    The original adam algorithm is proposed in
+    `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
 
     The updating formulas are as follows,
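
The LazyAdam hunk only rewraps the citation, but the "lazy" behaviour it refers to is worth spelling out: with a sparse gradient, only the rows that actually appear in the gradient have their moment estimates and parameters updated. The NumPy sketch below illustrates that idea; the function name, argument layout, and exact placement of the bias correction are assumptions rather than the MindSpore kernels, and it assumes the row indices are unique.

import numpy as np

def lazy_adam_step(params, m, v, grad_values, grad_indices, step,
                   lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One sparse update: only rows listed in grad_indices are touched."""
    idx = grad_indices
    # Untouched rows keep their stale m/v values instead of being decayed every step.
    m[idx] = beta1 * m[idx] + (1 - beta1) * grad_values
    v[idx] = beta2 * v[idx] + (1 - beta2) * grad_values ** 2
    m_hat = m[idx] / (1 - beta1 ** step)
    v_hat = v[idx] / (1 - beta2 ** step)
    params[idx] -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params, m, v
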
@@ -138,6 +138,7 @@ def get_bprop_BatchNormFold(self):
 
 @bprop_getters.register(P.BNTrainingReduce)
 def get_bprop_BNTrainingReduce(self):
     """Generate bprop for BNTrainingReduce for Ascend"""
     def bprop(x, out, dout):
         return (zeros_like(x),)
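
The last hunk shows a backward rule that returns zeros for the input of BNTrainingReduce, i.e. no gradient flows back through that primitive along this path. As a usage note, the same getter pattern generalizes to other primitives. The sketch below is meant to live in the same grad file, so bprop_getters, P, and the getter convention are the ones visible in the hunk; P.MyReduceOp is a hypothetical operator and passing dout straight through is an assumption chosen for the example. The closing return bprop line is part of the pattern even though it falls outside the lines visible above.

# Illustrative only: P.MyReduceOp is hypothetical, and routing the output
# gradient back to the input unchanged is an assumption for the example.
@bprop_getters.register(P.MyReduceOp)
def get_bprop_MyReduceOp(self):
    """Generate bprop for the hypothetical MyReduceOp."""
    def bprop(x, out, dout):
        # Pass the incoming gradient through to the single input.
        return (dout,)
    return bprop
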