modify lars interface

Ziyan 2020-06-18 20:15:03 +08:00
parent c8f26f799b
commit 41ddc153a6
4 changed files with 47 additions and 52 deletions

File 1 of 4: the LARS optimizer module

@@ -13,22 +13,18 @@
 # limitations under the License.
 # ============================================================================
 """lars optimizer"""
-from typing import Iterable
-from mindspore.common import dtype as mstype
-from mindspore.common import Tensor
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
+from mindspore.common import Tensor, Parameter, dtype as mstype
 from .optimizer import _grad_scale, Optimizer


 _lars_opt = C.MultitypeFuncGraph("lars_opt")

-@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
-def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
+@_lars_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Bool", "Bool")
+def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
     """Apply lars optimizer to the weight parameter."""
     if lars_flag:
         op_reduce_sum = P.SquareSumAll()
@@ -42,10 +38,12 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
     return gradient


-def _check_param_value(optimizer, epsilon, hyperpara, use_clip, prim_name):
+def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
     validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
+    if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name:
+        raise TypeError("LARS can not be used with ", optimizer.cls_name)
     validator.check_value_type("epsilon", epsilon, [float], prim_name)
-    validator.check_value_type("hyperpara", hyperpara, [float], prim_name)
+    validator.check_value_type("coefficient", coefficient, [float], prim_name)
     validator.check_value_type("use_clip", use_clip, [bool], prim_name)

 class LARS(Optimizer):
@@ -58,14 +56,10 @@ class LARS(Optimizer):
     Args:
         optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients.
         epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
-        hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
-        weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
+        coefficient (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
         use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
-        decay_filter (Function): A function to determine whether apply weight decay on parameters. Default:
-            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
         lars_filter (Function): A function to determine whether apply lars algorithm. Default:
             lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
-        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. Default: 1.0.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is
@@ -78,51 +72,54 @@ class LARS(Optimizer):
         >>> net = Net()
         >>> loss = nn.SoftmaxCrossEntropyWithLogits()
         >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
-        >>> opt_lars = nn.LARS(opt, epsilon=1e-08, hyperpara=0.02)
+        >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
         >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
     """

-    def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
-                 decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
-                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")], weight_decay, loss_scale)
-        if optimizer.is_group:
-            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param_value(optimizer, epsilon, hyperpara, use_clip, self.cls_name)
+    def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
+                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
+        _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
         self.opt = optimizer
-        self.parameters = optimizer.parameters
-        self.learning_rate = optimizer.learning_rate
-        self.lars = P.LARSUpdate(epsilon, hyperpara, use_clip)
-        self.reciprocal_scale = 1.0 / loss_scale
-        self.weight_decay = weight_decay
+        self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
         self.cast = P.Cast()
-        self.decay_flag = tuple(decay_filter(x) for x in self.parameters)
+        self.parameters = optimizer.parameters
+        if use_clip is True:
+            self.learning_rate = optimizer.learning_rate
+            self.dynamic_lr = optimizer.dynamic_lr
+            self.gather = optimizer.gather
+            self.assignadd = optimizer.assignadd
+            self.global_step = optimizer.global_step
+        else:
+            self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
+        self.reciprocal_scale = optimizer.reciprocal_scale
+        optimizer.reciprocal_scale = 1.0
+        self.is_group = optimizer.is_group
+        if self.is_group:
+            self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
+        else:
+            self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
+        optimizer.exec_weight_decay = False
+        optimizer.weight_decay = 0.0
+        self.decay_flags = optimizer.decay_flags
         self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
-        self.dynamic_lr = False
-        self.gather = None
-        self.global_step = None
-        self.axis = None
-        if isinstance(self.learning_rate.default_input, Iterable) or \
-                (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
-            self.dynamic_lr = True
-            self.assignadd = P.AssignAdd()
-            self.gather = P.GatherV2()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="lars_global_step")
-            self.axis = 0

     def construct(self, gradients):
         params = self.parameters
         if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
+            lr = self.gather(self.learning_rate, self.global_step, 0)
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         else:
             lr = self.learning_rate
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
-        grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, self.weight_decay, lr),
-                                gradients, params, self.decay_flag, self.lars_flag)
+        if self.is_group:
+            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
+                                    gradients, params, self.decay_flags, self.lars_flag)
+        else:
+            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
+                                    gradients, params, self.decay_flags, self.lars_flag)
         success = self.opt(grad_t)
         return success
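For callers migrating to the reworked interface, the sketch below contrasts the old and new construction of the wrapper: hyperpara is renamed coefficient, while weight_decay, decay_filter and loss_scale are no longer accepted by LARS and are instead taken from the wrapped optimizer (its weight_decay, decay_flags and loss scale). This is a minimal illustration only; Net, the loss, and the 1e-4 / 128.0 values are placeholders, not part of this commit.

# Old interface (before this commit):
#   opt_lars = nn.LARS(opt, epsilon=1e-08, hyperpara=0.02, weight_decay=1e-4,
#                      decay_filter=lambda x: 'bias' not in x.name, loss_scale=128.0)
# New interface: configure weight decay and loss scale on the inner optimizer.
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), 0.1, 0.9,
                  weight_decay=1e-4, loss_scale=128.0)      # placeholder values
opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02,    # 'hyperpara' is now 'coefficient'
                   lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name)
model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)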

File 2 of 4: training script (train_process)

@@ -182,13 +182,11 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
                     {'order_params': net.trainable_params()}]
     if config.use_lars:
-        momentum = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
-                               weight_decay=config.weight_decay, loss_scale=config.loss_scale,
+        momentum = nn.Momentum(group_params, lr, config.momentum,
                                use_nesterov=config.use_nesterov)
-        opt = nn.LARS(momentum, epsilon=config.lars_epsilon, hyperpara=config.lars_coefficient,
-                      weight_decay=config.weight_decay,
-                      decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name,
-                      lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name,
-                      loss_scale=config.loss_scale)
+        opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
+                      lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
     else:
         opt = nn.Momentum(group_params, lr, config.momentum,
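One behavioural consequence worth noting: the old wrapper rejected grouped parameters (if optimizer.is_group it raised a RuntimeError), whereas the reworked one reads is_group, per-group weight_decay and decay_flags from the wrapped optimizer, which is why the script above can now pass group_params straight into the Momentum that LARS wraps. A rough sketch of that pattern; the group split below is assumed rather than copied from the script, and lr plus the config fields come from the surrounding training code:

# Assumed grouping; the real script's split may differ.
decayed_params = [p for p in net.trainable_params()
                  if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name]
no_decayed_params = [p for p in net.trainable_params()
                     if 'beta' in p.name or 'gamma' in p.name or 'bias' in p.name]
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                {'params': no_decayed_params},
                {'order_params': net.trainable_params()}]
momentum = nn.Momentum(group_params, lr, config.momentum, use_nesterov=config.use_nesterov)
# LARS inherits each group's weight decay and decay flags from `momentum`.
opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
              lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)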

File 3 of 4: LARS optimizer tests (test_lars_multi_step_lr, test_lars_float_lr)

@@ -56,7 +56,7 @@ def test_lars_multi_step_lr():
     lr = multisteplr(10, [2, 6])
     SGD = Momentum(net.trainable_params(), lr, 0.9)
-    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+    optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02, use_clip=True,
                      lars_filter=lambda x: 'bn' not in x.name)
     net_with_loss = WithLossCell(net, loss)

@@ -73,7 +73,7 @@ def test_lars_float_lr():
     lr = 0.1
     SGD = Momentum(net.trainable_params(), lr, 0.9)
-    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+    optimizer = LARS(SGD, epsilon=1e-08, coefficient=0.02,
                      lars_filter=lambda x: 'bn' not in x.name)
     net_with_loss = WithLossCell(net, loss)
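The use_clip=True added in the first test above exercises the path where LARS still needs the global learning rate: per the new __init__, with use_clip=True the wrapper borrows learning_rate, dynamic_lr, gather, assignadd and global_step from the wrapped optimizer, while with the default use_clip=False it installs a placeholder fake_lr parameter. A small sketch of the clipped variant, reusing the test helpers above (illustrative only):

lr = multisteplr(10, [2, 6])                     # multi-step schedule from the test helpers
sgd = Momentum(net.trainable_params(), lr, 0.9)
# use_clip=True: the wrapper reuses sgd's dynamic-LR state (learning_rate,
# global_step, gather) when clipping the local learning rate.
optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02, use_clip=True,
                 lars_filter=lambda x: 'bn' not in x.name)
# use_clip=False (default): the global LR is not consumed by the local-rate
# computation, so LARS holds a placeholder "fake_lr" parameter instead.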

File 4 of 4: parallel test (test_lars)

@@ -205,7 +205,7 @@ def test_lars():
     lr = Tensor(np.ones([6]), dtype=ms.float32)
     sgd = Momentum(net.trainable_params(), lr, 0.9)
-    optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+    optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02,
                      lars_filter=lambda x: 'bn' not in x.name)
     net_with_loss = NetWithLoss(net, strategy3)
     train_net = TrainOneStepCell(net_with_loss, optimizer)