!33333 Refactor optimizers to support flattened parameters

Merge pull request !33333 from hewei/flatten_weights
i-robot 2022-04-21 12:14:30 +00:00 committed by Gitee
commit 5524b9d4e0
16 changed files with 110 additions and 82 deletions
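
The change is mechanical across the optimizer files below: each optimizer now reads self._parameters (the flattened parameter chunks when flattening is enabled, otherwise an alias of self.parameters) and regroups the incoming gradients with self.flatten_gradients(...) at the top of construct. The snippet below is a minimal NumPy sketch of the idea only, not MindSpore code: one contiguous buffer with per-parameter views, so a single fused update touches every parameter.

import numpy as np

# One contiguous float32 buffer (a "flattened chunk") with a view per logical parameter.
shapes = [(2, 3), (3,), (4,)]
sizes = [int(np.prod(s)) for s in shapes]
flat = np.zeros(sum(sizes), dtype=np.float32)

views, offset = [], 0
for shape, size in zip(shapes, sizes):
    views.append(flat[offset:offset + size].reshape(shape))  # shares memory with flat
    offset += size

views[0][0, 0] = 6.0               # writing through a per-parameter view ...
assert flat[0] == 6.0              # ... is visible in the flattened buffer
flat -= 0.1 * np.ones_like(flat)   # one fused update adjusts every parameter at once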

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -181,12 +181,13 @@ class Adagrad(Optimizer):
update_slots=True, loss_scale=1.0, weight_decay=0.0):
super(Adagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(accum, update_slots, self.cls_name)
self.accum = self.parameters.clone(prefix="accum", init=accum)
self.accum = self._parameters.clone(prefix="accum", init=accum)
self.opt = P.ApplyAdagrad(update_slots=update_slots)
def construct(self, grads):
params = self.parameters
params = self._parameters
accum = self.accum
grads = self.flatten_gradients(grads)
grads = self.decay_weight(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)
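
Adagrad above shows the pattern that repeats in every optimizer in this merge request: optimizer state is cloned from self._parameters instead of self.parameters, and construct regroups the incoming gradients with flatten_gradients before weight decay, gradient centralization and loss scaling. A plain-Python mock of that prologue, illustrative only (no MindSpore required; flatten_gradients here is a stand-in for the real regrouping):

class MockOptimizer:
    def __init__(self, parameters, flat_parameters=None):
        self.parameters = tuple(parameters)  # user-visible parameters
        self._parameters = tuple(flat_parameters) if flat_parameters else self.parameters

    def flatten_gradients(self, gradients):
        # Stand-in: the real method regroups per-parameter gradients so they
        # line up with the flattened chunks in self._parameters.
        return gradients

    def construct(self, gradients):
        params = self._parameters
        gradients = self.flatten_gradients(gradients)
        # ... decay_weight / gradients_centralization / scale_grad would run here,
        # followed by the optimizer-specific update over (params, gradients) ...
        return [p - 0.1 * g for p, g in zip(params, gradients)]

opt = MockOptimizer(parameters=[1.0, 2.0], flat_parameters=[3.0])  # pretend 3.0 is the fused chunk
print(opt.construct([0.5]))  # the update runs over the single flattened "parameter"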

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -345,15 +345,15 @@ class AdaFactor(Optimizer):
"""init adafactor variables"""
if beta1 > 0:
self.use_first_moment = True
self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
self.exp_avg = self._parameters.clone(prefix="exp_avg", init='zeros')
else:
self.use_first_moment = False
self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self.parameters))
self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self._parameters))
self.exp_avg_sq = []
self.exp_avg_sq_col = []
self.exp_avg_sq_row = []
for param in self.parameters:
for param in self._parameters:
param_dtype = param.dtype
param_shape = param.shape
param_name = param.name
@ -398,6 +398,7 @@ class AdaFactor(Optimizer):
return False
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
lr = self.get_lr()
step = F.assign_add(self.step, 1)
if self.scale_lr and self.relative_step:
@ -411,13 +412,13 @@ class AdaFactor(Optimizer):
if self.use_fused_ada_factor:
success = self.hyper_map(F.partial(_adafactor_opt, self.fused_ada_factor, self.eps, self.clip_threshold,
self.beta1, beta2t, self.weight_decay, lr),
gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
self.exp_avg_sq_col, self.exp_avg_sq)
else:
success = self.hyper_map(F.partial(_adafactor_opt, self.eps, self.clip_threshold, self.beta1, beta2t,
self.weight_decay, self.scale_parameter, self.compression,
self.use_first_moment, self.weight_decay_flag, lr),
gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
self.exp_avg_sq_col, self.exp_avg_sq)
return success

View File

@ -372,8 +372,8 @@ class Adam(Optimizer):
self.eps = Tensor(eps, mstype.float32)
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self._is_device = True
self.opt = P.Adam(use_locking, use_nesterov)
@ -384,7 +384,7 @@ class Adam(Optimizer):
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
def construct(self, gradients):
params = self.parameters
params = self._parameters
moment1 = self.moment1
moment2 = self.moment2
gradients = self.flatten_gradients(gradients)
@ -572,24 +572,25 @@ class AdamWeightDecay(Optimizer):
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.moments1 = self._parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self._parameters.clone(prefix="adam_v", init='zeros')
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
weight_decay = self.get_weight_decay()
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, weight_decay, self.parameters, self.moments1,
lr, weight_decay, self._parameters, self.moments1,
self.moments2, gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
weight_decay, self.parameters, self.moments1, self.moments2,
weight_decay, self._parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
self.parameters, self.moments1, self.moments2,
self._parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
@ -747,15 +748,16 @@ class AdamOffload(Optimizer):
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = Tensor(eps, mstype.float32)
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
self.opt.add_prim_attr("primitive_target", "CPU")
def construct(self, gradients):
params = self.parameters
params = self._parameters
moment1 = self.moment1
moment2 = self.moment2
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -157,13 +157,13 @@ class ASGD(Optimizer):
self.alpha = alpha
self.t0 = Tensor([t0], dtype=mstype.float32)
mu, eta = [], []
for param in self.parameters:
for param in self._parameters:
mu.append(Parameter(Tensor(1., dtype=mstype.float32), name='%s%s' % ("mu_", param.name)))
eta.append(Parameter(Tensor(0., dtype=mstype.float32), name='%s%s' % ("eta_", param.name)))
self.lens = len(self.parameters)
self.lens = len(self._parameters)
self.mu = mindspore.ParameterTuple(mu)
self.eta = mindspore.ParameterTuple(eta)
self.ax = self.parameters.clone(prefix="ax_", init='zeros')
self.ax = self._parameters.clone(prefix="ax_", init='zeros')
self.pow = P.Pow()
self.maximum = P.Maximum()
self.assign = P.Assign()
@ -173,6 +173,7 @@ class ASGD(Optimizer):
self.squeeze = P.Squeeze()
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
@ -180,8 +181,8 @@ class ASGD(Optimizer):
if not self.is_dynamic_lr_or_weight_decay():
self.assignadd(self.global_step, self.global_step_increase_tensor)
success = True
for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)):
params = self._parameters
for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, params, self.mu, self.eta, self.ax)):
lr = lrs[index] if self.is_group_lr else lrs
lr = self.squeeze(lr)

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -212,14 +212,14 @@ class FTRL(Optimizer):
f"in FTRL, they should all be false, but got dynamic learning rate {self.dynamic_lr} and"
f" group learning rate {self.is_group_lr}.")
_check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.moments = self._parameters.clone(prefix="moments", init=initial_accum)
self.linear = self._parameters.clone(prefix="linear", init='zeros')
self.l1 = l1
self.l2 = l2
self.lr = learning_rate
self.lr_power = lr_power
if not self.is_group:
self.decay_flags = tuple((lambda: True)() for x in self.parameters)
self.decay_flags = tuple((lambda: True)() for x in self._parameters)
self.opt = P.ApplyFtrl(use_locking=use_locking)
self.use_locking = use_locking
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
@ -232,9 +232,10 @@ class FTRL(Optimizer):
self._ps_push.add_prim_attr("lr_power", lr_power)
def construct(self, grads):
params = self.parameters
params = self._parameters
moments = self.moments
linear = self.linear
grads = self.flatten_gradients(grads)
grads = self.decay_weight(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -347,7 +347,7 @@ class Lamb(Optimizer):
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.params = self.parameters
self.params = self._parameters
self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
self.device_ascend = context.get_context("device_target") == "Ascend"
@ -358,6 +358,7 @@ class Lamb(Optimizer):
if not self.is_dynamic_lr_or_weight_decay():
self.assignadd(self.global_step, self.global_step_increase_tensor)
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
gradients = self.flatten_gradients(gradients)
gradients = self.gradients_centralization(gradients)
if self.is_group:
if self.is_group_lr:

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -122,9 +122,13 @@ class LARS(Optimizer):
self.weight_decay = optimizer.weight_decay
self.global_step = optimizer.global_step
self.parameters = optimizer.parameters
self._user_parameters += [param.name for param in self.parameters]
self._parameters = optimizer._parameters # pylint: disable=W0212
self._use_flattened_params = optimizer._use_flattened_params # pylint: disable=W0212
if self._use_flattened_params:
self.opt._use_flattened_params = False # pylint: disable=W0212
self._user_parameters += [param.name for param in self._parameters]
self.use_clip = use_clip
self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
self.lars_flag = tuple(lars_filter(x) for x in self._parameters)
self.is_group = optimizer.is_group
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
self.decay_flags = optimizer.decay_flags
@ -166,7 +170,8 @@ class LARS(Optimizer):
return lr
def construct(self, gradients):
params = self.parameters
params = self._parameters
gradients = self.flatten_gradients(gradients)
if self.use_clip:
lr = self._get_lr()
else:
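
LARS is a wrapper cell around another optimizer, so besides switching to _parameters it copies _use_flattened_params from the wrapped optimizer and clears the flag on the inner instance: LARS now calls flatten_gradients once in its own construct, presumably so the wrapped optimizer does not regroup the gradients a second time. A hedged usage sketch (constructor defaults may differ across MindSpore versions):

import mindspore.nn as nn

net = nn.Dense(16, 8)
inner = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
opt = nn.LARS(inner)  # opt._parameters mirrors inner's; flattening inside inner is turned off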

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -267,8 +267,8 @@ class LazyAdam(Optimizer):
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self._is_device = True
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
@ -277,6 +277,7 @@ class LazyAdam(Optimizer):
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
@ -292,14 +293,14 @@ class LazyAdam(Optimizer):
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
lr, gradients, self._parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
else:
success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps,
lr),
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
gradients, self._parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
return success

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -175,7 +175,7 @@ class Momentum(Optimizer):
raise ValueError("For 'Momentum', the argument 'momentum' should be at least 0.0, "
"but got {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.params = self._parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
@ -183,6 +183,7 @@ class Momentum(Optimizer):
def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)

View File

@ -188,6 +188,7 @@ class Optimizer(Cell):
if self.is_group:
self.parameters = ParameterTuple(self.group_params)
self._parameters = self.parameters
decay_filter = lambda x: isinstance(x, Cell) or x > 0
dynamic_decay_filter = lambda x: isinstance(x, Cell)
self.decay_flags = tuple(decay_filter(x) for x in self.group_weight_decay)
@ -197,29 +198,33 @@ class Optimizer(Cell):
self.exec_weight_decay = any(self.decay_flags)
self.grad_centralization_flags = tuple(self.group_grad_centralization)
else:
parameters = self._get_flattened_params(parameters)
self.parameters = ParameterTuple(parameters)
flat_params = self._get_flattened_params(parameters)
if self._use_flattened_params:
self._parameters = ParameterTuple(flat_params)
else:
self._parameters = self.parameters
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
self.dynamic_decay_flags = isinstance(weight_decay, Cell)
self.exec_weight_decay = isinstance(weight_decay, Cell) or weight_decay > 0
self.weight_decay = Tensor(weight_decay, mstype.float32) if not self.dynamic_decay_flags else weight_decay
# when a parameter has been unique, there is no need do another unique in optimizer.
for param in self.parameters:
for param in self._parameters:
if param.unique:
self._unique = False
break
# set user's parameters as local parameters
for param in self.parameters:
for param in self._parameters:
self._user_parameters.append(param.name)
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
self.ps_parameters = tuple(ps_filter(x) for x in self._parameters)
cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
self.cache_enable = tuple(cache_filter(x) for x in self._parameters)
self.reciprocal_scale = Tensor(1.0 / self.loss_scale, mstype.float32)
self.need_scale = self.loss_scale != 1.0
self.global_step_increase_tensor = Tensor(1, mstype.int32)
self.param_length = len(self.parameters)
self.param_length = len(self._parameters)
self.map_ = C.Map()
self.map_reverse = C.Map(None, True)
self.hyper_map = C.HyperMap()
@ -273,7 +278,7 @@ class Optimizer(Cell):
self.param_rank = self._get_parameter_group_id()
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
self.param_names = []
for param in self.parameters:
for param in self._parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
@ -366,7 +371,7 @@ class Optimizer(Cell):
tuple[Tensor], The gradients after weight decay.
"""
if self.exec_weight_decay:
params = self.parameters
params = self._parameters
weight_decay = self.get_weight_decay()
if self.is_group:
gradients = self.map_(F.partial(_apply_decay), weight_decay, self.decay_flags, params, gradients)
@ -722,7 +727,7 @@ class Optimizer(Cell):
f"but got {type(param)}.")
lr = []
ids = [id(p) for p in self.parameters]
ids = [id(p) for p in self._parameters]
for p in param_list:
validator.check_value_type("parameter", p, [Parameter], self.cls_name)
if id(p) not in ids:
@ -777,7 +782,7 @@ class Optimizer(Cell):
for _ in range(self.dev_num):
param_group.append(F.make_tuple())
for i in range(self.param_length):
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self._parameters[i],)
new_param_group = []
for root in range(self.dev_num):
if root > 0:
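
The base class is where the split is made: self.parameters always remains the ParameterTuple of the user's parameters, while self._parameters becomes the ParameterTuple of flattened chunks when _get_flattened_params enables _use_flattened_params, and otherwise falls back to aliasing self.parameters. All per-parameter bookkeeping (decay_flags, ps_parameters, cache_enable, param_length, learning-rate lookup, broadcast groups) is now sized against _parameters. A distilled restatement of that control flow, plain Python and illustrative only (strings stand in for parameters, and the flattened-chunk discovery is faked):

def init_param_views(user_params, flat_chunks):
    """user_params: the parameters the caller passed in; flat_chunks: flattened
    chunks discovered for them (empty when flattening is not in effect)."""
    parameters = tuple(user_params)                      # stays user-visible
    use_flattened_params = len(flat_chunks) > 0
    _parameters = tuple(flat_chunks) if use_flattened_params else parameters
    decay_flags = tuple('beta' not in name and 'gamma' not in name
                        for name in _parameters)         # sized to _parameters
    return parameters, _parameters, decay_flags

print(init_param_views(['conv.weight', 'bn.beta', 'bn.gamma'], []))              # 3 decay flags
print(init_param_views(['conv.weight', 'bn.beta', 'bn.gamma'], ['flat_chunk']))  # 1 decay flag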

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -188,7 +188,7 @@ class ProximalAdagrad(Optimizer):
use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(ProximalAdagrad, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param_value(accum, l1, l2, use_locking, self.cls_name)
self.accum = self.parameters.clone(prefix="accum", init=accum)
self.accum = self._parameters.clone(prefix="accum", init=accum)
self.l1 = Tensor(l1, mstype.float32)
self.l2 = Tensor(l2, mstype.float32)
self.use_locking = use_locking
@ -196,8 +196,9 @@ class ProximalAdagrad(Optimizer):
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
def construct(self, grads):
params = self.parameters
params = self._parameters
accum = self.accum
grads = self.flatten_gradients(grads)
grads = self.decay_weight(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -211,18 +211,19 @@ class RMSProp(Optimizer):
self.centered = centered
if centered:
self.opt = P.ApplyCenteredRMSProp(use_locking)
self.mg = self.parameters.clone(prefix="mean_grad", init='zeros')
self.mg = self._parameters.clone(prefix="mean_grad", init='zeros')
else:
self.opt = P.ApplyRMSProp(use_locking)
self.momentum = momentum
self.ms = self.parameters.clone(prefix="mean_square", init='ones')
self.moment = self.parameters.clone(prefix="moment", init='zeros')
self.ms = self._parameters.clone(prefix="mean_square", init='ones')
self.moment = self._parameters.clone(prefix="moment", init='zeros')
self.epsilon = epsilon
self.decay = decay
def construct(self, gradients):
params = self.parameters
params = self._parameters
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -176,8 +176,8 @@ class Rprop(Optimizer):
self.etaminus, self.etaplus = etas
self.step_size_min, self.step_size_max = step_sizes
self.prev = self.parameters.clone(prefix="prev", init='zeros')
self.step_size = self.parameters.clone(prefix="step_size", init='zeros')
self.prev = self._parameters.clone(prefix="prev", init='zeros')
self.step_size = self._parameters.clone(prefix="step_size", init='zeros')
self.fill = P.Fill()
self.sign = P.Sign()
@ -188,6 +188,7 @@ class Rprop(Optimizer):
self.ones_like = P.OnesLike()
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
@ -196,7 +197,7 @@ class Rprop(Optimizer):
self.assignadd(self.global_step, self.global_step_increase_tensor)
success = True
for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters,
for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self._parameters,
self.prev, self.step_size)):
lr = lrs[index] if self.is_group_lr else lrs

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -180,13 +180,14 @@ class SGD(Optimizer):
self.opt = P.SGD(dampening, weight_decay, nesterov)
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.accum = self.parameters.clone(prefix="accum", init='zeros')
self.stat = self.parameters.clone(prefix="stat", init='ones')
self.accum = self._parameters.clone(prefix="accum", init='zeros')
self.stat = self._parameters.clone(prefix="stat", init='ones')
def construct(self, gradients):
params = self.parameters
params = self._parameters
accum = self.accum
stat = self.stat
gradients = self.flatten_gradients(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -379,7 +379,7 @@ class ThorGpu(Optimizer):
super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param(momentum, frequency, learning_rate, self.__class__.__name__)
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.params = self._parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
@ -408,7 +408,7 @@ class ThorGpu(Optimizer):
self.matrix_a = ParameterTuple(self.matrix_a)
self.matrix_g = ParameterTuple(self.matrix_g)
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
self.enable_clip_grad = enable_clip_grad
self.frequency = frequency
@ -586,6 +586,7 @@ class ThorGpu(Optimizer):
def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.flatten_gradients(gradients)
gradients = self.scale_grad(gradients)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
@ -667,7 +668,7 @@ class ThorAscend(Optimizer):
super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
_check_param(momentum, frequency, learning_rate, self.__class__.__name__)
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.params = self._parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
@ -704,7 +705,7 @@ class ThorAscend(Optimizer):
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.thor = True
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.decay_flags = tuple(decay_filter(x) for x in self._parameters)
self.damping = damping
self.batch_size = Tensor(batch_size, mstype.float32)
self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
@ -1246,6 +1247,7 @@ class ThorAscend(Optimizer):
def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.flatten_gradients(gradients)
gradients = self.scale_grad(gradients)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)

View File

@ -118,6 +118,7 @@ def test_not_flattened_params():
assert not opt._use_flattened_params # pylint: disable=W0212
assert len(opt.parameters) == 3
assert len(opt.cache_enable) == 3
assert id(opt.parameters) == id(opt._parameters) # pylint: disable=W0212
def test_with_flattened_params():
@ -133,16 +134,18 @@ def test_with_flattened_params():
Tensor._flatten_tensors(paras) # pylint: disable=W0212
opt = Optimizer(0.1, paras)
assert opt._use_flattened_params # pylint: disable=W0212
assert len(opt.parameters) == 1
assert len(opt.parameters) == 3
assert len(opt._parameters) == 1 # pylint: disable=W0212
assert len(opt.cache_enable) == 1
assert opt.parameters[0].dtype == ms.float32
assert opt.parameters[0].shape == [3]
assert opt.parameters[0]._size == 3 # pylint: disable=W0212
assert np.allclose(opt.parameters[0].asnumpy(), np.array([1, 2, 3]))
flat_param = opt._parameters[0] # pylint: disable=W0212
assert flat_param.dtype == ms.float32
assert flat_param.shape == [3]
assert flat_param._size == 3 # pylint: disable=W0212
assert np.allclose(flat_param.asnumpy(), np.array([1, 2, 3]))
p1.asnumpy()[0] = 6
p2.asnumpy()[0] = 6
p3.asnumpy()[0] = 6
assert np.allclose(opt.parameters[0].asnumpy(), np.array([6, 6, 6]))
assert np.allclose(flat_param.asnumpy(), np.array([6, 6, 6]))
def test_adam_with_flattened_params():
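
The tests above pin down the observable contract: without flattening, parameters and _parameters are the same object; with flattening, parameters still lists the three original parameters while _parameters holds a single float32 chunk of size 3 that shares memory with them (writing 6 into each original is visible through the chunk). A hedged end-to-end sketch of how a user would opt in, assuming the private Tensor._flatten_tensors helper exercised by the tests (non-public, so the exact call may change):

import mindspore as ms
import mindspore.nn as nn

net = nn.Dense(16, 8)
params = list(net.trainable_params())
ms.Tensor._flatten_tensors(params)  # fuse the parameters into contiguous chunks (private API)
opt = nn.Momentum(params, learning_rate=0.1, momentum=0.9)
# opt.parameters still exposes the original parameters for checkpoints and
# user-facing APIs, while opt._parameters holds the fused chunks that the
# update in construct() actually iterates over.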