!51220 fix SGD weight decay (master): support group and raise error when dynamic
Merge pull request !51220 from wangnan39/fix_sgd_weightdecay
Commit 786a8523ce
@@ -34,7 +34,7 @@ mindspore.nn.SGD
 .. include:: mindspore.nn.optim_group_param.rst

 .. include:: mindspore.nn.optim_group_lr.rst

-    - **weight_decay** - Using different weight decay values through parameter grouping is currently not supported.
+    - **weight_decay** - Optional. If "weight_decay" is in the keys, the corresponding value is used as the weight decay. If not, the `weight_decay` configured in the optimizer is used. Currently `weight_decay` only supports the float type; dynamic weight decay is not supported.

 .. include:: mindspore.nn.optim_group_gc.rst

 .. include:: mindspore.nn.optim_group_order.rst
@@ -27,8 +27,8 @@ from mindspore.nn.optim.optimizer import opt_init_args_register
 _sgd_opt = C.MultitypeFuncGraph("sgd_opt")


-@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
+@_sgd_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Function")
+def _tensor_run_opt_ext(momentum, learning_rate, gradient, weight, accum, stat, opt):
     """Apply sgd optimizer to the weight parameter using Tensor."""
     success = True
     success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
@@ -76,7 +76,9 @@ class SGD(Optimizer):
             - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
               If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.

-            - weight_decay: Using different `weight_decay` by grouping parameters is currently not supported.
+            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
+              will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
+              decay must be float, dynamic weight decay is currently not supported.

             - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
               will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
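For illustration, a minimal usage sketch of the grouped weight decay described above (not part of this patch; the network, parameter split, and values are hypothetical):

import mindspore.nn as nn

# A toy network; any Cell with trainable parameters would do.
net = nn.Dense(4, 2)
weights = [p for p in net.trainable_params() if 'weight' in p.name]
biases = [p for p in net.trainable_params() if 'weight' not in p.name]

# Each group may carry its own float weight_decay; a group without the key
# falls back to the optimizer-level value (0.0 here).
group_params = [{'params': weights, 'weight_decay': 0.01},
                {'params': biases}]
opt = nn.SGD(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)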
@@ -164,7 +166,7 @@ class SGD(Optimizer):
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("For 'SGD', the argument 'momentum' must be at least 0.0, "
-                             "but got {}".format(momentum))
+                             "but got {}.".format(momentum))

         if isinstance(dampening, int):
             dampening = float(dampening)
@@ -177,9 +179,6 @@ class SGD(Optimizer):
                              "but got 'dampening' {}".format(dampening))
         self.dampening = dampening

-        if isinstance(weight_decay, int):
-            weight_decay = float(weight_decay)
-
         validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)

         if nesterov and (momentum <= 0.0 or dampening != 0.0):
@@ -187,7 +186,14 @@ class SGD(Optimizer):
                              "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(momentum, dampening))
         self.nesterov = nesterov

-        self.opt = P.SGD(dampening, weight_decay, nesterov)
+        if self.dynamic_weight_decay:
+            raise TypeError("For 'SGD', dynamic weight decay is currently not supported, the argument 'weight_decay' "
+                            "or 'weight_decay' set in grouped 'params' must be float or int type.")
+
+        if hasattr(self, "group_weight_decay") and self.group_weight_decay:
+            self.opt = tuple(P.SGD(dampening, wd, nesterov) for wd in self.group_weight_decay)
+        else:
+            self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))

         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.accum = self._parameters.clone(prefix="accum", init='zeros')
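The behavioral core of the change above is that `self.opt` becomes a tuple holding one `P.SGD` primitive per parameter, so each parameter can carry its group's weight decay, while a non-float (dynamic) weight decay is rejected up front. A minimal sketch of that fan-out pattern in plain Python (hypothetical helper, not the MindSpore source):

def fan_out_ops(make_op, weight_decay, num_params, group_weight_decay=None):
    """Build one optimizer op per parameter; grouped values win over the global one."""
    if group_weight_decay:
        # Grouped case: each parameter's op carries its group's weight decay.
        return tuple(make_op(wd) for wd in group_weight_decay)
    # Ungrouped case: every parameter shares the same float weight decay.
    return tuple(make_op(float(weight_decay)) for _ in range(num_params))

# Example: three parameters, two of them in a group with weight decay 0.2.
ops_per_param = fan_out_ops(lambda wd: ('SGD', wd), 0.0, 3, group_weight_decay=[0.2, 0.2, 0.0])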
@@ -203,9 +209,9 @@ class SGD(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
-            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum),
-                                             lr, gradients, params, accum, stat)
+            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum),
+                                             lr, gradients, params, accum, stat, self.opt)
         else:
-            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.opt, self.momentum, lr),
-                                             gradients, params, accum, stat)
+            success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum, lr),
+                                             gradients, params, accum, stat, self.opt)
         return success
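Because `self.opt` is now a per-parameter tuple, the `construct` change above stops binding a single op through `F.partial` and instead passes the tuple as one more mapped argument, so element i of the op tuple is applied to parameter i. A plain-Python illustration of that pattern (not MindSpore code; names are made up):

from functools import partial

def apply_one(momentum, lr, grad, param, opt):
    # 'opt' arrives as a mapped (per-element) argument rather than a bound one.
    return opt(param, grad, lr, momentum)

opts = tuple(lambda p, g, lr, m: p - lr * g for _ in range(3))  # one stand-in op per parameter
params, grads = [1.0, 2.0, 3.0], [0.1, 0.2, 0.3]
updated = tuple(map(partial(apply_one, 0.9, 0.01), grads, params, opts))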
@@ -110,6 +110,8 @@ def build_network(opt_config, net, is_group=False, loss_fn=nn.MSELoss(reduction=
             params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
         elif opt_config['name'] == 'adamax':
             params = [{'params': fc1_params, 'lr': 0.0018}, {'params': fc2_params, 'lr': 0.0022}]
+        elif opt_config['name'] == 'SGD':
+            params = [{'params': fc1_params, 'weight_decay': 0.2}, {'params': fc2_params}]
         else:
             params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}]
     else:
@@ -126,13 +128,14 @@ def build_network(opt_config, net, is_group=False, loss_fn=nn.MSELoss(reduction=
     elif opt_config['name'] == 'adamax':
         net_opt = AdaMax(params, learning_rate=opt_config['lr'], beta1=opt_config['beta1'],
                          beta2=opt_config['beta2'], eps=opt_config['eps'], weight_decay=0.0)

+    elif opt_config['name'] == 'SGD':
+        net_opt = nn.SGD(params, weight_decay=opt_config['weight_decay'], dampening=0.3, momentum=0.1)
     trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)
     data, label = make_fake_data()
     for i in range(20):
         loss = trainonestepcell(data[i], label[i])
         losses.append(loss.asnumpy())
-    if opt_config['name'] == 'ASGD':
+    if opt_config['name'] == 'ASGD' or opt_config['name'] == 'SGD':
         return np.array(losses), net_opt
     return np.array(losses)
@@ -227,3 +230,13 @@ no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -
 no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32)
 no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32)
 no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32)
+
+default_fc1_weight_sgd = np.array([[-6.6273242e-02, -3.9511207e-02, -1.0251881e-01, -1.2807587e-01,
+                                    -6.9634348e-02, -1.0375493e-01, -1.2083838e-01, -9.7173907e-02],
+                                   [-1.8068390e-02, -7.8982085e-02, -1.3175679e-02, 2.0881524e-04,
+                                    -6.4472459e-02, 7.9219900e-03, -2.8659783e-02, -6.9297753e-02],
+                                   [-2.5218798e-02, -3.6950763e-02, -4.2106784e-03, 2.9642319e-02,
+                                    1.0740350e-02, -6.0375791e-02, 5.5906363e-03, 2.0822065e-02],
+                                   [-1.1401306e+01, -1.1416125e+01, -1.1386261e+01, -1.1366054e+01,
+                                    -1.1324347e+01, -1.1358459e+01, -1.1398650e+01, -1.1339014e+01]], dtype=np.float32)
+default_fc2_weight_sgd = np.array([[-0.5055597, -0.5255496, -0.52437556, 1.0779992]], dtype=np.float32)
@@ -0,0 +1,41 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from .optimizer_utils import FakeNet, build_network


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_default_asgd(mode):
    """
    Feature: Test SGD optimizer
    Description: Test SGD with group weight decay
    Expectation: Parameters conform to preset values.
    """
    from .optimizer_utils import default_fc1_weight_sgd, default_fc2_weight_sgd
    context.set_context(mode=mode)
    config = {'name': 'SGD', "weight_decay": 0.1}
    _, cells = build_network(config, FakeNet(), is_group=True)
    assert np.allclose(cells.accum[0].asnumpy(), default_fc1_weight_sgd, atol=1.e-4)
    assert np.allclose(cells.accum[2].asnumpy(), default_fc2_weight_sgd, atol=1.e-4)
@@ -0,0 +1,65 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test SGD """
import numpy as np
import pytest

import mindspore as ms
from mindspore import Tensor, Parameter
import mindspore.ops as ops
import mindspore.nn as nn


class Net(nn.Cell):
    """ Net definition """

    def __init__(self):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype((np.float32))), name="bias")
        self.matmul = ops.MatMul()
        self.biasadd = ops.BiasAdd()

    def construct(self, x):
        x = self.biasadd(self.matmul(x, self.weight), self.bias)
        return x


class WeightDecaySchdule(nn.Cell):
    def __init__(self):
        super(WeightDecaySchdule, self).__init__()
        self.weight_decay_list = Tensor([0.001, 0.001, 0.1], ms.float32)

    def construct(self, global_step):
        return self.weight_decay_list[global_step]


def test_sgd_dynamic_weightdecay():
    """
    Feature: Test SGD optimizer.
    Description: Test if error is raised when weight decay is dynamic.
    Expectation: TypeError is raised.
    """
    net = Net()
    params = net.trainable_params()
    group_params = [{'params': [params[0]], 'weight_decay': WeightDecaySchdule()}, {'params': [params[1]]}]

    weight_decay_error = "For 'SGD', dynamic weight decay is currently not supported, the argument 'weight_decay' " \
                         "or 'weight_decay' set in grouped 'params' must be float or int type."
    with pytest.raises(TypeError, match=weight_decay_error):
        nn.SGD(params, learning_rate=0.1, weight_decay=WeightDecaySchdule())

    with pytest.raises(TypeError, match=weight_decay_error):
        nn.SGD(group_params, learning_rate=0.1)