forked from mindspore-Ecosystem/mindspore
modify lars interface
This commit is contained in:
parent
c8f26f799b
commit
41ddc153a6
|
@ -13,22 +13,18 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lars optimizer"""
|
||||
from typing import Iterable
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import Tensor, Parameter, dtype as mstype
|
||||
from .optimizer import _grad_scale, Optimizer
|
||||
|
||||
_lars_opt = C.MultitypeFuncGraph("lars_opt")
|
||||
|
||||
|
||||
@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
|
||||
@_lars_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
|
||||
"""Apply lars optimizer to the weight parameter."""
|
||||
if lars_flag:
|
||||
op_reduce_sum = P.SquareSumAll()
|
||||
|
@ -42,10 +38,12 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
|
|||
|
||||
return gradient
|
||||
|
||||
def _check_param_value(optimizer, epsilon, hyperpara, use_clip, prim_name):
|
||||
def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
|
||||
validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
|
||||
if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name:
|
||||
raise TypeError("LARS can not be used with ", optimizer.cls_name)
|
||||
validator.check_value_type("epsilon", epsilon, [float], prim_name)
|
||||
validator.check_value_type("hyperpara", hyperpara, [float], prim_name)
|
||||
validator.check_value_type("coefficient", coefficient, [float], prim_name)
|
||||
validator.check_value_type("use_clip", use_clip, [bool], prim_name)
|
||||
|
||||
class LARS(Optimizer):
|
||||
|
@ -58,14 +56,10 @@ class LARS(Optimizer):
|
|||
Args:
|
||||
optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients.
|
||||
epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
|
||||
hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
|
||||
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
|
||||
coefficient (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
|
||||
use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
|
||||
decay_filter (Function): A function to determine whether apply weight decay on parameters. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
lars_filter (Function): A function to determine whether apply lars algorithm. Default:
|
||||
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
|
||||
loss_scale (float): A floating point value for the loss scale. It should be greater than 0. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is
|
||||
|
@ -78,51 +72,54 @@ class LARS(Optimizer):
|
|||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
>>> opt_lars = nn.LARS(opt, epsilon=1e-08, hyperpara=0.02)
|
||||
>>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
|
||||
decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
|
||||
lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
|
||||
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")], weight_decay, loss_scale)
|
||||
if optimizer.is_group:
|
||||
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
|
||||
_check_param_value(optimizer, epsilon, hyperpara, use_clip, self.cls_name)
|
||||
def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
|
||||
lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
|
||||
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
|
||||
_check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
|
||||
self.opt = optimizer
|
||||
self.parameters = optimizer.parameters
|
||||
self.learning_rate = optimizer.learning_rate
|
||||
self.lars = P.LARSUpdate(epsilon, hyperpara, use_clip)
|
||||
self.reciprocal_scale = 1.0 / loss_scale
|
||||
self.weight_decay = weight_decay
|
||||
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
|
||||
self.cast = P.Cast()
|
||||
self.decay_flag = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.parameters = optimizer.parameters
|
||||
if use_clip is True:
|
||||
self.learning_rate = optimizer.learning_rate
|
||||
self.dynamic_lr = optimizer.dynamic_lr
|
||||
self.gather = optimizer.gather
|
||||
self.assignadd = optimizer.assignadd
|
||||
self.global_step = optimizer.global_step
|
||||
else:
|
||||
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
|
||||
self.reciprocal_scale = optimizer.reciprocal_scale
|
||||
optimizer.reciprocal_scale = 1.0
|
||||
self.is_group = optimizer.is_group
|
||||
if self.is_group:
|
||||
self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
|
||||
else:
|
||||
self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
|
||||
optimizer.exec_weight_decay = False
|
||||
optimizer.weight_decay = 0.0
|
||||
self.decay_flags = optimizer.decay_flags
|
||||
self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.dynamic_lr = False
|
||||
self.gather = None
|
||||
self.global_step = None
|
||||
self.axis = None
|
||||
if isinstance(self.learning_rate.default_input, Iterable) or \
|
||||
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
|
||||
self.dynamic_lr = True
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.gather = P.GatherV2()
|
||||
self.global_step = Parameter(initializer(0, [1], mstype.int32), name="lars_global_step")
|
||||
self.axis = 0
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.parameters
|
||||
if self.dynamic_lr:
|
||||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
lr = self.gather(self.learning_rate, self.global_step, 0)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||
else:
|
||||
lr = self.learning_rate
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
|
||||
|
||||
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, self.weight_decay, lr),
|
||||
gradients, params, self.decay_flag, self.lars_flag)
|
||||
if self.is_group:
|
||||
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
|
||||
gradients, params, self.decay_flags, self.lars_flag)
|
||||
else:
|
||||
grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
|
||||
gradients, params, self.decay_flags, self.lars_flag)
|
||||
success = self.opt(grad_t)
|
||||
|
||||
return success
|
||||
|
|
|
@ -182,13 +182,11 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
|
|||
{'order_params': net.trainable_params()}]
|
||||
|
||||
if config.use_lars:
|
||||
momentum = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
momentum = nn.Momentum(group_params, lr, config.momentum,
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale,
|
||||
use_nesterov=config.use_nesterov)
|
||||
opt = nn.LARS(momentum, epsilon=config.lars_epsilon, hyperpara=config.lars_coefficient,
|
||||
weight_decay=config.weight_decay,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name,
|
||||
lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name,
|
||||
loss_scale=config.loss_scale)
|
||||
opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
|
||||
lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
|
||||
|
||||
else:
|
||||
opt = nn.Momentum(group_params, lr, config.momentum,
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_lars_multi_step_lr():
|
|||
|
||||
lr = multisteplr(10, [2, 6])
|
||||
SGD = Momentum(net.trainable_params(), lr, 0.9)
|
||||
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
|
||||
optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02, use_clip=True,
|
||||
lars_filter=lambda x: 'bn' not in x.name)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
|
@ -73,7 +73,7 @@ def test_lars_float_lr():
|
|||
|
||||
lr = 0.1
|
||||
SGD = Momentum(net.trainable_params(), lr, 0.9)
|
||||
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
|
||||
optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02,
|
||||
lars_filter=lambda x: 'bn' not in x.name)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
|
|
|
@ -205,7 +205,7 @@ def test_lars():
|
|||
|
||||
lr = Tensor(np.ones([6]), dtype=ms.float32)
|
||||
sgd = Momentum(net.trainable_params(), lr, 0.9)
|
||||
optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
|
||||
optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02,
|
||||
lars_filter=lambda x: 'bn' not in x.name)
|
||||
net_with_loss = NetWithLoss(net, strategy3)
|
||||
train_net = TrainOneStepCell(net_with_loss, optimizer)
|
||||
|
|
Loading…
Reference in New Issue