!51220 fix SGD weight decay (master): support group and raise error when dynamic

Merge pull request !51220 from wangnan39/fix_sgd_weightdecay
i-robot 2023-03-25 03:16:47 +00:00 committed by Gitee
commit 786a8523ce
5 changed files with 140 additions and 15 deletions
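In short, this MR makes nn.SGD honour per-group weight_decay values and fail fast when a dynamic (Cell-based) weight decay is supplied. A minimal sketch of the new error path, where the class DynamicWD is an illustrative stand-in for any Cell-based schedule and is not part of the change itself:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor, Parameter

class DynamicWD(nn.Cell):
    """Illustrative Cell that returns a step-dependent weight decay value."""
    def __init__(self):
        super().__init__()
        self.values = Tensor([0.1, 0.01, 0.001], ms.float32)

    def construct(self, global_step):
        return self.values[global_step]

weight = Parameter(Tensor(np.ones([4, 2]).astype(np.float32)), name="weight")

# Per-group float weight decay is now honoured.
nn.SGD([{'params': [weight], 'weight_decay': 0.1}], learning_rate=0.01)

# A dynamic (Cell-based) weight decay is rejected at construction time.
try:
    nn.SGD([weight], learning_rate=0.01, weight_decay=DynamicWD())
except TypeError as err:
    print(err)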

View File

@@ -34,7 +34,7 @@ mindspore.nn.SGD
.. include:: mindspore.nn.optim_group_param.rst
.. include:: mindspore.nn.optim_group_lr.rst
- **weight_decay** - Using different `weight_decay` values through parameter grouping is currently not supported.
- **weight_decay** - Optional. If "weight_decay" is in the keys, the corresponding value is used as the weight decay. If not, the `weight_decay` configured in the optimizer is used. Currently `weight_decay` only supports the float type; dynamic weight decay is not supported.
.. include:: mindspore.nn.optim_group_gc.rst
.. include:: mindspore.nn.optim_group_order.rst

View File

@@ -27,8 +27,8 @@ from mindspore.nn.optim.optimizer import opt_init_args_register
_sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
@_sgd_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Function")
def _tensor_run_opt_ext(momentum, learning_rate, gradient, weight, accum, stat, opt):
"""Apply sgd optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
@@ -76,7 +76,9 @@ class SGD(Optimizer):
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
- weight_decay: Using different `weight_decay` by grouping parameters is currently not supported.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay must be float; dynamic weight decay is currently not supported.
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
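To make the grouped semantics above concrete, here is a minimal usage sketch (the nn.Dense network and the weight/bias split are illustrative only; the group keys 'params', 'lr', 'weight_decay' and 'order_params' follow the grouped-parameter docs referenced above): a group that supplies its own weight_decay overrides the optimizer-level value, while a group that omits it inherits it.

from mindspore import nn

net = nn.Dense(4, 2)  # any small Cell with a weight and a bias will do
decayed = [p for p in net.trainable_params() if 'weight' in p.name]
no_decay = [p for p in net.trainable_params() if 'bias' in p.name]

group_params = [
    {'params': decayed, 'weight_decay': 0.01, 'lr': 0.05},  # uses its own decay and lr
    {'params': no_decay},                                    # inherits weight_decay=0.0 and learning_rate=0.1
    {'order_params': net.trainable_params()},                # keeps the original parameter order
]
opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0, momentum=0.9)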
@@ -164,7 +166,7 @@ class SGD(Optimizer):
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("For 'SGD', the argument 'momentum' must be at least 0.0, "
"but got {}".format(momentum))
"but got {}.".format(momentum))
if isinstance(dampening, int):
dampening = float(dampening)
@@ -177,9 +179,6 @@
"but got 'dampening' {}".format(dampening))
self.dampening = dampening
if isinstance(weight_decay, int):
weight_decay = float(weight_decay)
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
if nesterov and (momentum <= 0.0 or dampening != 0.0):
@@ -187,7 +186,14 @@
"equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
self.nesterov = nesterov
self.opt = P.SGD(dampening, weight_decay, nesterov)
if self.dynamic_weight_decay:
raise TypeError("For 'SGD', dynamic weight decay is currently not supported, the argument 'weight_decay' "
"or 'weight_decay' set in grouped 'params' must be float or int type.")
if hasattr(self, "group_weight_decay") and self.group_weight_decay:
self.opt = tuple(P.SGD(dampening, wd, nesterov) for wd in self.group_weight_decay)
else:
self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.accum = self._parameters.clone(prefix="accum", init='zeros')
@@ -203,9 +209,9 @@
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum),
lr, gradients, params, accum, stat)
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum),
lr, gradients, params, accum, stat, self.opt)
else:
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum, lr),
gradients, params, accum, stat)
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum, lr),
gradients, params, accum, stat, self.opt)
return success

View File

@@ -110,6 +110,8 @@ def build_network(opt_config, net, is_group=False, loss_fn=nn.MSELoss(reduction=
params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
elif opt_config['name'] == 'adamax':
params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}]
elif opt_config['name'] == 'SGD':
params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}]
else:
params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
else:
@@ -126,13 +128,14 @@ def build_network(opt_config, net, is_group=False, loss_fn=nn.MSELoss(reduction=
elif opt_config['name'] == 'adamax':
net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'],
beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0)
elif opt_config['name'] == 'SGD':
net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1)
trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
data, label = make_fake_data()
for i in range(20):
loss = trainonestepcell(data[i], label[i])
losses.append(loss.asnumpy())
if opt_config['name'] == 'ASGD':
if opt_config['name'] == 'ASGD' or opt_config['name'] == 'SGD':
return np.array(losses), net_opt
return np.array(losses)
@@ -227,3 +230,13 @@ no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -
no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32)
no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32)
no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32)
default_fc1_weight_sgd = np.array([[-6.6273242e-02, -3.9511207e-02, -1.0251881e-01, -1.2807587e-01,
-6.9634348e-02, -1.0375493e-01, -1.2083838e-01, -9.7173907e-02],
[-1.8068390e-02, -7.8982085e-02, -1.3175679e-02, 2.0881524e-04,
-6.4472459e-02, 7.9219900e-03, -2.8659783e-02, -6.9297753e-02],
[-2.5218798e-02, -3.6950763e-02, -4.2106784e-03, 2.9642319e-02,
1.0740350e-02, -6.0375791e-02, 5.5906363e-03, 2.0822065e-02],
[-1.1401306e+01, -1.1416125e+01, -1.1386261e+01, -1.1366054e+01,
-1.1324347e+01, -1.1358459e+01, -1.1398650e+01, -1.1339014e+01]], dtype=np.float32)
default_fc2_weight_sgd = np.array([[-0.5055597, -0.5255496, -0.52437556, 1.0779992]], dtype=np.float32)

View File

@@ -0,0 +1,41 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from .optimizer_utils import FakeNet, build_network
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_default_sgd(mode):
"""
Feature: Test SGD optimizer
Description: Test SGD with group weight decay
Expectation: Parameters conform to preset values.
"""
from .optimizer_utils import default_fc1_weight_sgd, default_fc2_weight_sgd
context.set_context(mode=mode)
config = {'name': 'SGD', "weight_decay": 0.1}
_, cells = build_network(config, FakeNet(), is_group=True)
assert np.allclose(cells.accum[0].asnumpy(), default_fc1_weight_sgd, atol=1.e-4)
assert np.allclose(cells.accum[2].asnumpy(), default_fc2_weight_sgd, atol=1.e-4)

View File

@@ -0,0 +1,65 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test SGD """
import numpy as np
import pytest
import mindspore as ms
from mindspore import Tensor, Parameter
import mindspore.ops as ops
import mindspore.nn as nn
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
self.matmul = ops.MatMul()
self.biasadd = ops.BiasAdd()
def construct(self, x):
x = self.biasadd(self.matmul(x, self.weight), self.bias)
return x
class WeightDecaySchedule(nn.Cell):
def __init__(self):
super(WeightDecaySchedule, self).__init__()
self.weight_decay_list = Tensor([0.001, 0.001, 0.1], ms.float32)
def construct(self, global_step):
return self.weight_decay_list[global_step]
def test_sgd_dynamic_weightdecay():
"""
Feature: Test SGD optimizer.
Description: Test if error is raised when weight decay is dynamic.
Expectation: TypeError is raised.
"""
net = Net()
params = net.trainable_params()
group_params = [{'params': [params[0]], 'weight_decay': WeightDecaySchedule()}, {'params': [params[1]]}]
weight_decay_error = "For 'SGD', dynamic weight decay is currently not supported, the argument 'weight_decay' " \
"or 'weight_decay' set in grouped 'params' must be float or int type."
with pytest.raises(TypeError, match=weight_decay_error):
nn.SGD(params, learning_rate=0.1, weight_decay=WeightDecaySchedule())
with pytest.raises(TypeError, match=weight_decay_error):
nn.SGD(group_params, learning_rate=0.1)