From dc7cc66b3ddbe3f7ff4cbaf8d6069920c48e0cde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=98=89=E7=90=AA?=
Date: Wed, 29 Jul 2020 15:22:48 +0800
Subject: [PATCH] add optimizer formula and verification

---
 mindspore/nn/optim/momentum.py | 13 +++++++++++++
 mindspore/nn/optim/sgd.py      | 22 +++++++++++++++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index a1730e7c679..5ce7bd28232 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -53,6 +53,19 @@ class Momentum(Optimizer):
 
     To improve parameter groups performance, the customized order of parameters can be supported.
 
+    .. math::
+            v_{t} = v_{t-1} \ast u + gradients
+
+    If use_nesterov is True:
+    .. math::
+            p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
+
+    If use_nesterov is False:
+    .. math::
+            p_{t} = p_{t-1} - lr \ast v_{t}
+
+    where grad, lr, p, v and u denote the gradients, learning_rate, parameters, accum, and momentum respectively.
+
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
             the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index f093200906b..406ad9f48dc 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -46,6 +46,21 @@ class SGD(Optimizer):
 
     To improve parameter groups performance, the customized order of parameters can be supported.
 
+    .. math::
+            v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
+
+    If nesterov is True:
+    .. math::
+            p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
+
+    If nesterov is False:
+    .. math::
+            p_{t+1} = p_{t} - lr \ast v_{t+1}
+
+    Note that for the first step, :math:`v_{t+1} = gradient`.
+
+    where p, v and u denote the parameters, accum, and momentum respectively.
+
     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
             the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params",
@@ -74,7 +89,8 @@ class SGD(Optimizer):
         momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0.
         dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0.
         weight_decay (float): Weight decay (L2 penalty). It should be in range [0.0, 1.0]. Default: 0.0.
-        nesterov (bool): Enables the Nesterov momentum. Default: False.
+        nesterov (bool): Enables the Nesterov momentum. If nesterov is used, momentum must be positive,
+            and dampening must be 0. Default: False.
         loss_scale (float): A floating point value for the loss scale. Should be not less than 1.0. Default: 1.0.
 
     Inputs:
@@ -118,6 +134,10 @@ class SGD(Optimizer):
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
 
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("If nesterov is used, momentum must be positive and dampening must be 0, "
+                             "but got momentum {}, dampening {}".format(momentum, dampening))
+
        if isinstance(dampening, int):
             dampening = float(dampening)
         if not isinstance(dampening, float):
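
A quick way to sanity-check the update rules documented in the two docstrings is to replay them in plain NumPy. The sketch below is illustrative only: the helper names momentum_step and sgd_step, and the standalone NumPy setting, are assumptions made for this example, not the MindSpore kernels (MindSpore applies these updates through its ApplyMomentum and SGD operators).

    # Minimal sketch, assuming plain NumPy; not the MindSpore kernels.
    import numpy as np

    def momentum_step(p, grad, v, lr, u, use_nesterov=False):
        """One step of the documented Momentum update."""
        v = v * u + grad                        # v_t = v_{t-1} * u + gradients
        if use_nesterov:
            p = p - (grad * lr + v * u * lr)    # p_t = p_{t-1} - (grad*lr + v_t*u*lr)
        else:
            p = p - lr * v                      # p_t = p_{t-1} - lr * v_t
        return p, v

    def sgd_step(p, grad, v, lr, u, dampening=0.0, nesterov=False, first_step=False):
        """One step of the documented SGD update."""
        if first_step:
            v = grad                                # first step: v_{t+1} = gradient
        else:
            v = u * v + grad * (1.0 - dampening)    # v_{t+1} = u*v_t + gradient*(1-dampening)
        if nesterov:
            p = p - lr * (grad + u * v)             # p_{t+1} = p_t - lr*(gradient + u*v_{t+1})
        else:
            p = p - lr * v                          # p_{t+1} = p_t - lr*v_{t+1}
        return p, v

    p, v = np.array([1.0]), np.zeros(1)
    p, v = sgd_step(p, np.array([0.5]), v, lr=0.1, u=0.9, first_step=True)
    print(p, v)  # [0.95] [0.5]

Running one step matches the documented first-step behaviour: with lr=0.1 and gradient 0.5, v becomes 0.5 and p moves from 1.0 to 0.95; passing nesterov=True with dampening != 0 would correspond exactly to the case the new ValueError rejects.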