forked from mindspore-Ecosystem/mindspore
add dynamic_decay
This commit is contained in:
parent 156e8d2e49
commit a124ec4de7
|
@ -87,7 +87,11 @@ class Adagrad(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
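As an illustration of the dynamic weight decay described in the docstring above (editorial sketch, not part of this commit): a schedule is any Cell whose construct takes the global step and returns the decay value for that step. The class name and the exponential curve below are assumptions; only the construct(global_step) contract matters.

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P

class ExponentialWeightDecay(nn.Cell):
    """Hypothetical schedule: weight decay shrinks exponentially with the global step."""

    def __init__(self, base_decay=0.01, decay_rate=0.99):
        super(ExponentialWeightDecay, self).__init__()
        self.base_decay = Tensor(base_decay, mstype.float32)
        self.decay_rate = Tensor(decay_rate, mstype.float32)
        self.pow = P.Pow()
        self.cast = P.Cast()

    def construct(self, global_step):
        # The optimizer calls this once per step with the current global step (an int32 Tensor).
        step = self.cast(global_step, mstype.float32)
        return self.base_decay * self.pow(self.decay_rate, step)

Such an instance is passed wherever a constant decay is accepted, for example nn.Adagrad(net.trainable_params(), weight_decay=ExponentialWeightDecay()); the call site here is only illustrative.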
@ -120,8 +124,14 @@ class Adagrad(Optimizer):
|
|||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
|
||||
Default: 1.0.
|
||||
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
|
||||
Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
|
||||
|
|
|
@ -32,7 +32,7 @@ _scaler_one = Tensor(1, mstype.int32)
|
|||
_scaler_ten = Tensor(10, mstype.float32)
|
||||
|
||||
|
||||
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
|
||||
"""
|
||||
|
@ -231,7 +231,11 @@ class Adam(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -269,7 +273,16 @@ class Adam(Optimizer):
|
|||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If true, update the gradients using NAG.
|
||||
If false, update the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
|
||||
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
|
||||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
|
@ -427,7 +440,11 @@ class AdamWeightDecay(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
|
||||
parameters that appeared in the network to improve performance. The value should be parameters whose
|
||||
|
@ -455,7 +472,15 @@ class AdamWeightDecay(Optimizer):
|
|||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -506,23 +531,24 @@ class AdamWeightDecay(Optimizer):
|
|||
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||
|
||||
def construct(self, gradients):
|
||||
weight_decay = self.get_weight_decay()
|
||||
lr = self.get_lr()
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||
lr, self.weight_decay, self.parameters, self.moments1,
|
||||
lr, weight_decay, self.parameters, self.moments1,
|
||||
self.moments2, gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
|
||||
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
weight_decay, self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
|
||||
self.weight_decay),
|
||||
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
|
||||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
self.broadcast_params(optim_result)
|
||||
|
||||
return optim_result
|
||||
|
||||
|
||||
|
@ -569,7 +595,11 @@ class AdamOffload(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- order_params: Optional. When parameters is grouped, this usually is used to maintain the order of
|
||||
parameters that appeared in the network to improve performance. The value should be parameters whose
|
||||
|
@ -603,7 +633,16 @@ class AdamOffload(Optimizer):
|
|||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If true, update the gradients using NAG.
|
||||
If false, update the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
|
||||
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
|
||||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""asgd"""
|
||||
from mindspore.ops import functional as F, operations as P
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -62,7 +62,11 @@ class ASGD(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -91,7 +95,14 @@ class ASGD(Optimizer):
|
|||
lambd (float): The decay term. Default: 1e-4.
|
||||
alpha (float): The power for eta update. Default: 0.75.
|
||||
t0 (float): The point of starting averaging. Default: 1e6.
|
||||
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -152,7 +163,6 @@ class ASGD(Optimizer):
|
|||
self.lens = len(self.parameters)
|
||||
self.mu = mindspore.ParameterTuple(mu)
|
||||
self.eta = mindspore.ParameterTuple(eta)
|
||||
self.step = Parameter(Tensor(1., dtype=mstype.float32), name='step')
|
||||
self.ax = self.parameters.clone(prefix="ax_", init='zeros')
|
||||
self.pow = P.Pow()
|
||||
self.maximum = P.Maximum()
|
||||
|
@ -173,7 +183,7 @@ class ASGD(Optimizer):
|
|||
lr = lrs[index] if self.is_group_lr else lrs
|
||||
lr = self.squeeze(lr)
|
||||
|
||||
if self.step == 1.:
|
||||
if self.global_step == 1:
|
||||
self.assign(eta, lr)
|
||||
|
||||
param_fp32 = self.cast(param, mstype.float32)
|
||||
|
@ -188,8 +198,7 @@ class ASGD(Optimizer):
|
|||
else:
|
||||
self.assign(ax, param)
|
||||
|
||||
self.assign(eta, lr / (self.pow((1. + (self.lambd * lr * self.step)), self.alpha)))
|
||||
self.assign(mu, 1. / self.squeeze(self.maximum(1., self.step - self.t0)))
|
||||
|
||||
success = F.depend(success, self.assignadd(self.step, 1.))
|
||||
self.assign(eta, lr / (self.pow((1. + (self.lambd * lr * self.cast(self.global_step, mstype.float32))),
|
||||
self.alpha)))
|
||||
self.assign(mu, 1. / self.squeeze(self.maximum(1., self.cast(self.global_step, mstype.float32) - self.t0)))
|
||||
return success
|
||||
|
|
|
@ -122,7 +122,11 @@ class FTRL(Optimizer):
|
|||
- lr: Using different learning rate by grouping parameters is currently not supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -147,8 +151,14 @@ class FTRL(Optimizer):
|
|||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
|
||||
Default: 1.0.
|
||||
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
|
||||
Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
|
||||
|
|
|
@ -16,11 +16,9 @@
|
|||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
@ -34,7 +32,7 @@ num_one = Tensor(np.ones([1]), mstype.float32)
|
|||
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
|
||||
|
||||
|
||||
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
|
||||
"""
|
||||
|
@ -82,9 +80,9 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
|
|||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
|
||||
|
||||
next_mm = next_m / (op_cast(num_one, mstype.float32)
|
||||
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
|
||||
- op_pow(beta1, op_cast(global_step, mstype.float32)))
|
||||
next_vv = next_v / (op_cast(num_one, mstype.float32) -
|
||||
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
|
||||
op_pow(beta2, op_cast(global_step, mstype.float32)))
|
||||
w_norm = op_norm(param_fp32)
|
||||
g_norm = op_norm(gradient_fp32)
|
||||
|
||||
|
@ -116,7 +114,7 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
|
|||
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
|
||||
|
||||
|
||||
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
|
||||
optim_filter):
|
||||
|
@ -148,7 +146,7 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
|
|||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
new_global_step = op_cast(global_step + num_one, mstype.float32)
|
||||
new_global_step = op_cast(global_step, mstype.float32)
|
||||
weight_decay_flag = op_cast(decay_flag, mstype.float32)
|
||||
|
||||
update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
|
||||
|
@ -219,7 +217,11 @@ class Lamb(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -251,7 +253,15 @@ class Lamb(Optimizer):
|
|||
Should be in range (0.0, 1.0).
|
||||
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||
Should be greater than 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -310,13 +320,10 @@ class Lamb(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
|
||||
self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
|
||||
|
||||
if not self.dynamic_lr:
|
||||
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.device_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
def construct(self, gradients):
|
||||
weight_decay = self.get_weight_decay()
|
||||
lr = self.get_lr()
|
||||
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
|
||||
gradients = self.gradients_centralization(gradients)
|
||||
|
@ -324,23 +331,20 @@ class Lamb(Optimizer):
|
|||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step),
|
||||
lr, self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
lr, weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr),
|
||||
self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr, self.weight_decay),
|
||||
self.global_step, lr, weight_decay),
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flags, self.optim_filter)
|
||||
|
||||
if self.use_parallel:
|
||||
optim_result = F.depend(optim_result, self.broadcast_params(optim_result))
|
||||
|
||||
if not self.dynamic_lr:
|
||||
optim_result = F.depend(optim_result, self.assignadd(self.global_step, 1))
|
||||
|
||||
return optim_result
|
||||
|
|
|
@ -24,14 +24,14 @@ from .optimizer import opt_init_args_register
|
|||
_lars_opt = C.MultitypeFuncGraph("lars_opt")
|
||||
|
||||
|
||||
@_lars_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
|
||||
@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt(lars, loss_scale, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
|
||||
"""Apply lars optimizer to the weight parameter."""
|
||||
if lars_flag:
|
||||
op_reduce_sum = P.SquareSumAll()
|
||||
w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
|
||||
if decay_flag:
|
||||
grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
|
||||
grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay / loss_scale, learning_rate)
|
||||
else:
|
||||
num_zero = 0.0
|
||||
grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
|
||||
|
@ -109,6 +109,10 @@ class LARS(Optimizer):
|
|||
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
|
||||
_check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
|
||||
self.opt = optimizer
|
||||
self.dynamic_decay_flags = optimizer.dynamic_decay_flags
|
||||
self.dynamic_weight_decay = optimizer.dynamic_weight_decay
|
||||
self.weight_decay = optimizer.weight_decay
|
||||
self.global_step = optimizer.global_step
|
||||
self.parameters = optimizer.parameters
|
||||
self.use_clip = use_clip
|
||||
self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
|
||||
|
@ -119,24 +123,22 @@ class LARS(Optimizer):
|
|||
self.need_scale = optimizer.need_scale
|
||||
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
|
||||
self.cast = P.Cast()
|
||||
self.loss_scale = optimizer.loss_scale
|
||||
|
||||
if use_clip:
|
||||
self.is_group_lr = optimizer.is_group_lr
|
||||
self.dynamic_lr = optimizer.dynamic_lr
|
||||
self.origin_learning_rate = optimizer.learning_rate
|
||||
self.global_step = optimizer.global_step
|
||||
if self.is_group_lr and self.dynamic_lr:
|
||||
raise ValueError("For 'LARS', if the argument 'use_clip' is set to True, then the dynamic "
|
||||
"learning rate and group learning rate cannot both be true.")
|
||||
|
||||
if self.is_group:
|
||||
self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
|
||||
optimizer.weight_decay = tuple(map(lambda x: 0.0, optimizer.weight_decay))
|
||||
optimizer.dynamic_decay_flags = tuple(map(lambda x: False, self.dynamic_decay_flags))
|
||||
else:
|
||||
self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
|
||||
optimizer.weight_decay = 0.0
|
||||
|
||||
optimizer.dynamic_decay_flags = False
|
||||
optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags))
|
||||
optimizer.dynamic_weight_decay = False
|
||||
optimizer.reciprocal_scale = 1.0
|
||||
optimizer.exec_weight_decay = False
|
||||
|
||||
|
@ -160,20 +162,20 @@ class LARS(Optimizer):
|
|||
lr = self._get_lr()
|
||||
else:
|
||||
lr = self.learning_rate
|
||||
weight_decay = self.get_weight_decay()
|
||||
|
||||
if self.need_scale:
|
||||
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
|
||||
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars), lr, self.weight_decay,
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale), lr, weight_decay,
|
||||
gradients, params, self.decay_flags, self.lars_flag)
|
||||
else:
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr), weight_decay,
|
||||
gradients, params, self.decay_flags, self.lars_flag)
|
||||
else:
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
|
||||
gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
|
||||
gradients, params, self.decay_flags, self.lars_flag)
|
||||
success = self.opt(gradients)
|
||||
|
||||
return success
|
||||
|
|
|
@ -150,7 +150,11 @@ class LazyAdam(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -188,7 +192,16 @@ class LazyAdam(Optimizer):
|
|||
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
|
||||
If true, update the gradients using NAG.
|
||||
If false, update the gradients without using NAG. Default: False.
|
||||
weight_decay (Union[float, int]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In general,
|
||||
use the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update`
|
||||
in `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
|
|
|
@ -77,7 +77,11 @@ class Momentum(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -105,7 +109,16 @@ class Momentum(Optimizer):
|
|||
|
||||
momentum (float): Hyperparameter of type float, means momentum for the moving average.
|
||||
It must be at least 0.0.
|
||||
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
|
||||
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. In general, use the
|
||||
default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
|
||||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
|
|
|
@ -140,15 +140,15 @@ class Optimizer(Cell):
|
|||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
|
||||
self.loss_scale = loss_scale
|
||||
|
||||
self.dynamic_weight_decay = False
|
||||
weight_decay = self._preprocess_weight_decay(weight_decay)
|
||||
self.grad_centralization = False
|
||||
|
||||
self._unique = True
|
||||
self._target = context.get_context("device_target")
|
||||
self.dynamic_lr = False
|
||||
self.assignadd = None
|
||||
self.global_step = None
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
|
||||
self.is_group = False
|
||||
self.is_group_lr = False
|
||||
self.is_group_params_ordered = False
|
||||
|
@ -161,11 +161,10 @@ class Optimizer(Cell):
|
|||
self.group_grad_centralization = []
|
||||
self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
|
||||
|
||||
# The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
|
||||
if self.dynamic_lr:
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
|
||||
self._init_opt_attrs(learning_rate, parameters, weight_decay)
|
||||
|
||||
def _init_opt_attrs(self, learning_rate, parameters, weight_decay):
|
||||
"""initialize optimizer attributions"""
|
||||
if self.is_group_lr:
|
||||
self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
|
||||
else ParameterTuple(self.group_lr)
|
||||
|
@ -174,19 +173,21 @@ class Optimizer(Cell):
|
|||
|
||||
if self.is_group:
|
||||
self.parameters = ParameterTuple(self.group_params)
|
||||
self.weight_decay = tuple(self.group_weight_decay)
|
||||
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
|
||||
decay_filter = lambda x: x > 0
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
|
||||
decay_filter = lambda x: isinstance(x, Cell) or x > 0
|
||||
dynamic_decay_filter = lambda x: isinstance(x, Cell)
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.group_weight_decay)
|
||||
self.dynamic_decay_flags = tuple(dynamic_decay_filter(x) for x in self.group_weight_decay)
|
||||
self.weight_decay = tuple(x if flag else Tensor(x, mstype.float32)
|
||||
for x, flag in zip(self.group_weight_decay, self.dynamic_decay_flags))
|
||||
self.exec_weight_decay = any(self.decay_flags)
|
||||
self.grad_centralization_flags = tuple(self.group_grad_centralization)
|
||||
else:
|
||||
self.parameters = ParameterTuple(parameters)
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
|
||||
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.exec_weight_decay = self.weight_decay > 0
|
||||
self.dynamic_decay_flags = isinstance(weight_decay, Cell)
|
||||
self.exec_weight_decay = isinstance(weight_decay, Cell) or weight_decay > 0
|
||||
self.weight_decay = Tensor(weight_decay, mstype.float32) if not self.dynamic_decay_flags else weight_decay
|
||||
# when a parameter has been unique, there is no need do another unique in optimizer.
|
||||
for param in self.parameters:
|
||||
if param.unique:
|
||||
|
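Editorial note on the grouped branch in the hunk above: the two filters classify each group's decay once at construction time. The following standalone mimic (not optimizer code; schedule stands for any user-defined Cell) shows what the new flags look like for a mixed per-group list.

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell

def mimic_group_decay_flags(group_weight_decay):
    """Reproduce the per-group bookkeeping: which groups decay at all, and which are dynamic."""
    decay_filter = lambda x: isinstance(x, Cell) or x > 0
    dynamic_decay_filter = lambda x: isinstance(x, Cell)
    decay_flags = tuple(decay_filter(x) for x in group_weight_decay)
    dynamic_decay_flags = tuple(dynamic_decay_filter(x) for x in group_weight_decay)
    # Constant decays are frozen into float32 Tensors; schedule Cells stay callable.
    weight_decay = tuple(x if dyn else Tensor(x, mstype.float32)
                         for x, dyn in zip(group_weight_decay, dynamic_decay_flags))
    return decay_flags, dynamic_decay_flags, weight_decay

# For [0.0, 0.01, schedule] this yields decay_flags=(False, True, True)
# and dynamic_decay_flags=(False, False, True).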
@ -196,8 +197,8 @@ class Optimizer(Cell):
|
|||
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
|
||||
cache_filter = lambda x: x.cache_enable
|
||||
self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
|
||||
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
|
||||
self.need_scale = loss_scale != 1.0
|
||||
self.reciprocal_scale = Tensor(1.0 / self.loss_scale, mstype.float32)
|
||||
self.need_scale = self.loss_scale != 1.0
|
||||
self.global_step_increase_tensor = Tensor(1, mstype.int32)
|
||||
self.param_length = len(self.parameters)
|
||||
self.map_ = C.Map()
|
||||
|
@ -316,12 +317,11 @@ class Optimizer(Cell):
|
|||
"""
|
||||
if self.exec_weight_decay:
|
||||
params = self.parameters
|
||||
weight_decay = self.get_weight_decay()
|
||||
if self.is_group:
|
||||
gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.map_(F.partial(_apply_decay), weight_decay, self.decay_flags, params, gradients)
|
||||
else:
|
||||
gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.map_(F.partial(_apply_decay, weight_decay), self.decay_flags, params, gradients)
|
||||
|
||||
return gradients
|
||||
|
||||
|
@ -370,12 +370,19 @@ class Optimizer(Cell):
|
|||
return gradients
|
||||
|
||||
def _preprocess_weight_decay(self, weight_decay):
|
||||
"""Check weight decay, and convert int to float."""
|
||||
"""preprocess weight decay"""
|
||||
if isinstance(weight_decay, (float, int)):
|
||||
weight_decay = float(weight_decay)
|
||||
validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
|
||||
return weight_decay
|
||||
raise TypeError("Weight decay should be int or float.")
|
||||
weight_decay = weight_decay * self.loss_scale
|
||||
elif isinstance(weight_decay, Cell):
|
||||
self.dynamic_weight_decay = True
|
||||
weight_decay = _WrappedWeightDecay(weight_decay, self.loss_scale)
|
||||
elif isinstance(weight_decay, Tensor):
|
||||
weight_decay = weight_decay
|
||||
else:
|
||||
raise TypeError("Weight decay should be int, float or Cell.")
|
||||
return weight_decay
|
||||
|
||||
def _preprocess_grad_centralization(self, grad_centralization):
|
||||
if not isinstance(grad_centralization, bool):
|
||||
|
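Editorial note on the reworked _preprocess_weight_decay above: a numeric decay is multiplied by loss_scale once, while a Cell is wrapped (via _WrappedWeightDecay, defined further down in this diff) so that the same scaling is applied to whatever the schedule returns at each step. A rough standalone check of that invariant, assuming a loss scale of 1024 and a trivial constant schedule, neither of which appears in the patch:

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor

class ConstantSchedule(nn.Cell):
    """Trivial stand-in schedule: ignores the step and always returns 0.01."""
    def construct(self, global_step):
        return Tensor(0.01, mstype.float32)

loss_scale = 1024.0
step = Tensor(0, mstype.int32)

static_path = 0.01 * loss_scale                        # float branch: scaled once, up front
dynamic_path = ConstantSchedule()(step) * loss_scale   # Cell branch: scaled at every step by the wrapper
print(static_path, dynamic_path.asnumpy())             # both 10.24, so the two branches stay consistent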
@ -513,10 +520,9 @@ class Optimizer(Cell):
|
|||
lr = default_lr
|
||||
|
||||
if 'weight_decay' in group_param.keys():
|
||||
cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
|
||||
weight_decay_ = cur_weight_decay * self.loss_scale
|
||||
weight_decay_ = self._preprocess_weight_decay(group_param['weight_decay'])
|
||||
else:
|
||||
weight_decay_ = weight_decay * self.loss_scale
|
||||
weight_decay_ = self._preprocess_weight_decay(weight_decay)
|
||||
|
||||
if 'grad_centralization' in group_param.keys():
|
||||
self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
|
||||
|
@ -575,6 +581,27 @@ class Optimizer(Cell):
|
|||
self.group_weight_decay = ordered_weight_decay
|
||||
self.group_grad_centralization = ordered_grad_centralization
|
||||
|
||||
|
||||
def get_weight_decay(self):
|
||||
"""
|
||||
The optimizer calls this interface to get the weight decay value for the current step.
|
||||
User-defined optimizers based on :class:`mindspore.nn.Optimizer` can also call this interface
|
||||
before updating the parameters.
|
||||
|
||||
Returns:
|
||||
float, the weight decay value of current step.
|
||||
"""
|
||||
if self.dynamic_weight_decay:
|
||||
if self.is_group:
|
||||
weight_decay = ()
|
||||
for weight_decay_, flag_ in zip(self.weight_decay, self.dynamic_decay_flags):
|
||||
current_weight_decay = weight_decay_(self.global_step) if flag_ else weight_decay_
|
||||
weight_decay += (current_weight_decay,)
|
||||
return weight_decay
|
||||
return self.weight_decay(self.global_step)
|
||||
return self.weight_decay
|
||||
|
||||
|
||||
def get_lr(self):
|
||||
"""
|
||||
The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers based
|
||||
|
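The new get_weight_decay above mirrors get_lr, so user-defined optimizers can pick up fixed or dynamic decay with a single call. A minimal sketch follows; SimpleSGD is not part of this commit and assumes ungrouped parameters and no loss scaling.

import mindspore.nn as nn
from mindspore.ops import functional as F, operations as P

class SimpleSGD(nn.Optimizer):
    """Hypothetical custom optimizer: plain SGD that honours fixed or dynamic weight decay."""

    def __init__(self, params, learning_rate=0.01, weight_decay=0.0):
        super(SimpleSGD, self).__init__(learning_rate, params, weight_decay)
        self.assign_sub = P.AssignSub()

    def construct(self, gradients):
        weight_decay = self.get_weight_decay()   # decay value resolved for the current step
        lr = self.get_lr()                       # also advances the global step
        success = True
        for param, gradient, decay_flag in zip(self.parameters, gradients, self.decay_flags):
            if decay_flag:
                gradient = gradient + weight_decay * param
            success = F.depend(success, self.assign_sub(param, lr * gradient))
        return success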
@ -592,10 +619,10 @@ class Optimizer(Cell):
|
|||
lr += (current_dynamic_lr,)
|
||||
else:
|
||||
lr = self.learning_rate(self.global_step)
|
||||
|
||||
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
||||
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
||||
return lr
|
||||
|
||||
|
||||
def get_lr_parameter(self, param):
|
||||
"""
|
||||
When parameters is grouped and learning rate is different for each group. Get the learning rate of the specified
|
||||
|
@ -835,3 +862,14 @@ class _IteratorLearningRate(LearningRateSchedule):
|
|||
|
||||
def construct(self, global_step):
|
||||
return self.gather(self.learning_rate, global_step, 0)
|
||||
|
||||
|
||||
class _WrappedWeightDecay(Cell):
|
||||
"""Inner api, a combination of dynamic or non-dynamic weight decay"""
|
||||
def __init__(self, weight_decay, loss_scale=1.0):
|
||||
super(_WrappedWeightDecay, self).__init__()
|
||||
self.weight_decay = weight_decay
|
||||
self.loss_scale = Tensor(loss_scale, mstype.float32)
|
||||
|
||||
def construct(self, global_step):
|
||||
return self.weight_decay(global_step) * self.loss_scale
|
||||
|
|
|
@ -93,7 +93,11 @@ class ProximalAdagrad(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -128,8 +132,14 @@ class ProximalAdagrad(Optimizer):
|
|||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
|
||||
Default: 1.0.
|
||||
weight_decay (Union[float, int]): Weight decay value to multiply weight, must be zero or positive value.
|
||||
Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
|
||||
|
|
|
@ -101,7 +101,11 @@ class RMSProp(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -140,7 +144,14 @@ class RMSProp(Optimizer):
|
|||
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
||||
`FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
|
||||
Default: 1.0.
|
||||
weight_decay (Union[float, int]): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
|
|
@ -14,10 +14,8 @@
|
|||
# ============================================================================
|
||||
"""rprop"""
|
||||
from mindspore import ops
|
||||
from mindspore.ops import functional as F, operations as P
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
|
@ -63,7 +61,11 @@ class Rprop(Optimizer):
|
|||
If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.
|
||||
|
||||
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used.
|
||||
will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
weight decay is similar to dynamic learning rate: users customize a weight decay schedule that takes only
the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule
to get the weight decay value of the current step.
|
||||
|
||||
- grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
|
||||
will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
|
||||
|
@ -92,7 +94,14 @@ class Rprop(Optimizer):
|
|||
etas (tuple[float, float]): The factor of multiplicative increasing or
|
||||
decreasing (etaminus, etaplus).
|
||||
step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
|
||||
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
||||
|
||||
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
||||
|
||||
- int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.
|
||||
|
||||
- Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
|
||||
the Cell with step as the input to get the weight decay value of the current step.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -169,7 +178,6 @@ class Rprop(Optimizer):
|
|||
self.step_size_min, self.step_size_max = step_sizes
|
||||
self.prev = self.parameters.clone(prefix="prev", init='zeros')
|
||||
self.step_size = self.parameters.clone(prefix="step_size", init='zeros')
|
||||
self.step = Parameter(Tensor(0., dtype=mstype.float32), name='step')
|
||||
|
||||
self.fill = P.Fill()
|
||||
self.sign = P.Sign()
|
||||
|
@ -190,7 +198,7 @@ class Rprop(Optimizer):
|
|||
self.prev, self.step_size)):
|
||||
lr = lrs[index] if self.is_group_lr else lrs
|
||||
|
||||
if self.step == 0.:
|
||||
if self.global_step == 1:
|
||||
step_size_fp32 = self.ones_like(step_size) * lr
|
||||
else:
|
||||
step_size_fp32 = self.cast(step_size, mstype.float32)
|
||||
|
@ -213,6 +221,4 @@ class Rprop(Optimizer):
|
|||
self.assign(prev, self.cast(gradient_update, prev.dtype))
|
||||
self.assign(step_size, self.cast(step_size_fp32, step_size.dtype))
|
||||
|
||||
success = F.depend(success, self.assignadd(self.step, 1.))
|
||||
|
||||
return success
|
||||
|
|
|
@ -138,7 +138,7 @@ def test_ascend_not_cell_dump():
|
|||
check_dump_structure(dump_path, dump_config_path, 1, 1, 1)
|
||||
|
||||
# make sure set_dump is ignored and all cell layer are dumped
|
||||
assert len(os.listdir(dump_file_path)) == 10
|
||||
assert len(os.listdir(dump_file_path)) == 11
|
||||
del os.environ['MINDSPORE_DUMP_CONFIG']
|
||||
|
||||
|
||||
|
|
|
@ -229,7 +229,7 @@ def test_bert_performance():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [11.324663, 11.283459, 11.283258]
|
||||
expect_loss_value = [11.325571, 11.284833, 11.284736]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from .weight_decay_utils import dynamic_weight_decay_cmp, WeightDecaySchdule, Net
|
||||
|
||||
|
||||
def test_momentum_dynamic_weight_decay_pynative():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for Momentum
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
net1, net2 = Net(), Net()
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
|
||||
weight_decay=weight_decay_schedule)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
def test_momentum_dynamic_weight_decay_graph():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for Momentum
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net1, net2 = Net(), Net()
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
|
||||
weight_decay=weight_decay_schedule)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
def test_momentum_dynamic_weight_decay_graph_group():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for Momentum
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
net1, net2 = Net(), Net()
|
||||
|
||||
net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
|
||||
net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))
|
||||
|
||||
net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
|
||||
net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))
|
||||
|
||||
params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
{'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]
|
||||
|
||||
params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
{'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]
|
||||
|
||||
optimizer1 = nn.Momentum(params1, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.Momentum(params2, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
def test_adamweightdecay_dynamic_weight_decay_pynative():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for AdamWeightDecay
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
net1, net2 = Net(), Net()
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
def test_adamweightdecay_dynamic_weight_decay_graph():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for AdamWeightDecay
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net1, net2 = Net(), Net()
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
def test_adamweightdecay_dynamic_weight_decay_graph_group():
|
||||
"""
|
||||
Feature: Dynamic weight decay
|
||||
Description: Test dynamic weight decay for Momentum
|
||||
Expectation: The value of decay changes according to preset weight decay schedule
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
weight_decay_schedule = WeightDecaySchdule()
|
||||
net1, net2 = Net(), Net()
|
||||
|
||||
net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
|
||||
net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))
|
||||
|
||||
net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
|
||||
net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))
|
||||
|
||||
params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
{'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]
|
||||
|
||||
params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
|
||||
{'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]
|
||||
|
||||
optimizer1 = nn.AdamWeightDecay(params1, learning_rate=0.001, weight_decay=0.001)
|
||||
optimizer2 = nn.AdamWeightDecay(params2, learning_rate=0.001, weight_decay=0.001)
|
||||
dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)
|
||||
|
||||
|
||||
|
||||
def test_lamb_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lamb_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lamb_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.Lamb(params1, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(params2, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lars_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for LARS
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()

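    # nn.LARS wraps a base Momentum optimizer here; the weight decay (a fixed value for
    # opt1, the schedule Cell for opt2) is configured on the wrapped Momentum, and
    # lars_filter keeps parameters whose names contain 'LayerNorm' out of the LARS scaling.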
    opt1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    opt2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=weight_decay_schedule)
    optimizer1 = nn.LARS(opt1, lars_filter=lambda x: 'LayerNorm' not in x.name)
    optimizer2 = nn.LARS(opt2, lars_filter=lambda x: 'LayerNorm' not in x.name)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lars_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for LARS
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()

    opt1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    opt2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=weight_decay_schedule)
    optimizer1 = nn.LARS(opt1, lars_filter=lambda x: 'LayerNorm' not in x.name)
    optimizer2 = nn.LARS(opt2, lars_filter=lambda x: 'LayerNorm' not in x.name)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lars_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for LARS
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    opt1 = nn.Momentum(params1, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    opt2 = nn.Momentum(params2, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer1 = nn.LARS(opt1, lars_filter=lambda x: 'LayerNorm' not in x.name)
    optimizer2 = nn.LARS(opt2, lars_filter=lambda x: 'LayerNorm' not in x.name)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)

@@ -0,0 +1,129 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from .weight_decay_utils import dynamic_weight_decay_cmp, WeightDecaySchdule, Net


def test_momentum_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
                             weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_momentum_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
                             weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_momentum_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.Momentum(params1, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(params2, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.AdamWeightDecay(params1, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(params2, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)

@@ -0,0 +1,184 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from .weight_decay_utils import dynamic_weight_decay_cmp, WeightDecaySchdule, Net


def test_momentum_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
                             weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_momentum_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Momentum(net1.trainable_params(), momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(net2.trainable_params(), momentum=0.001, learning_rate=0.001,
                             weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_momentum_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Momentum
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.Momentum(params1, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Momentum(params2, momentum=0.001, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.AdamWeightDecay(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_adamweightdecay_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for AdamWeightDecay
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.AdamWeightDecay(params1, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.AdamWeightDecay(params2, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lamb_dynamic_weight_decay_pynative():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lamb_dynamic_weight_decay_graph():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net1, net2 = Net(), Net()
    weight_decay_schedule = WeightDecaySchdule()
    optimizer1 = nn.Lamb(net1.trainable_params(), learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(net2.trainable_params(), learning_rate=0.001, weight_decay=weight_decay_schedule)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)


def test_lamb_dynamic_weight_decay_graph_group():
    """
    Feature: Dynamic weight decay
    Description: Test dynamic weight decay for Lamb
    Expectation: The value of decay changes according to preset weight decay schedule
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    weight_decay_schedule = WeightDecaySchdule()
    net1, net2 = Net(), Net()

    net1_fc1_params = list(filter(lambda x: 'fc1' in x.name, net1.trainable_params()))
    net1_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net1.trainable_params()))

    net2_fc1_params = list(filter(lambda x: 'fc1' in x.name, net2.trainable_params()))
    net2_fc2_params = list(filter(lambda x: 'fc1' not in x.name, net2.trainable_params()))

    params1 = [{'params': net1_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net1_fc2_params, 'weight_decay': 0.001, 'lr': 0.001}]

    params2 = [{'params': net2_fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
               {'params': net2_fc2_params, 'weight_decay': weight_decay_schedule, 'lr': 0.001}]

    optimizer1 = nn.Lamb(params1, learning_rate=0.001, weight_decay=0.001)
    optimizer2 = nn.Lamb(params2, learning_rate=0.001, weight_decay=0.001)
    dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2)

@@ -0,0 +1,81 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore import Tensor
from mindspore.nn import Dense, ReLU
from mindspore.ops import operations as P


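# Toy dynamic weight decay schedule used by the tests: the decay for the current step is
# looked up from a fixed three-element Tensor indexed by the global step.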
class WeightDecaySchdule(nn.Cell):
    def __init__(self):
        super(WeightDecaySchdule, self).__init__()
        self.weight_decay_list = Tensor([0.001, 0.001, 0.1], mstype.float32)

    def construct(self, global_step):
        return self.weight_decay_list[global_step]


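# Small three-layer MLP; the fc1/fc2/fc3 layer names let the group-parameter tests split
# the trainable parameters by name.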
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.batch_size = 1
        self.reshape = P.Reshape()
        self.relu = ReLU()
        weight1 = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
        weight2 = Tensor(np.ones([10, 10]).astype(np.float32) * 0.02)
        weight3 = Tensor(np.ones([10, 10]).astype(np.float32) * 0.03)
        bias1 = Tensor(np.zeros(10).astype(np.float32))
        bias2 = Tensor(np.ones(10).astype(np.float32))
        bias3 = Tensor(np.ones(10).astype(np.float32))
        self.fc1 = Dense(16, 10, weight_init=weight1, bias_init=bias1)
        self.fc2 = Dense(10, 10, weight_init=weight2, bias_init=bias2)
        self.fc3 = Dense(10, 10, weight_init=weight3, bias_init=bias3)

    def construct(self, input_x):
        output = self.reshape(input_x, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


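# Trains two identical networks side by side: net1 with a fixed weight decay, net2 with the
# schedule above. While the scheduled decay equals the fixed value the per-step losses must
# match; once it diverges (0.1 vs. 0.001) the parameters drift apart, so the final losses
# must differ.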
def dynamic_weight_decay_cmp(net1, net2, optimizer1, optimizer2):
    epoch = 3
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion1 = WithLossCell(net1, criterion)
    net_with_criterion2 = WithLossCell(net2, criterion)
    train_network1 = TrainOneStepCell(net_with_criterion1, optimizer1)
    train_network2 = TrainOneStepCell(net_with_criterion2, optimizer2)
    train_network1.set_train()
    train_network2.set_train()

    for _ in range(epoch):
        data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
        label = Tensor(np.array([0]).astype(np.int32))
        loss1 = train_network1(data, label)
        loss2 = train_network2(data, label)
        assert abs(loss1.asnumpy() - loss2.asnumpy()) < 1.e-8

    data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
    label = Tensor(np.array([0]).astype(np.int32))
    loss1 = net_with_criterion1(data, label)
    loss2 = net_with_criterion2(data, label)
    assert abs(loss1.asnumpy() - loss2.asnumpy()) > 1.e-8

@@ -204,10 +204,10 @@ def test_weight_decay():
    for weight_decay, decay_flags, param, order_param in zip(
            opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()):
        if 'conv' in param.name:
            assert weight_decay == conv_weight_decay
            assert abs(weight_decay.asnumpy() - conv_weight_decay) < 1.e-6
            assert decay_flags is True
        else:
            assert weight_decay == default_weight_decay
            assert abs(weight_decay.asnumpy() - default_weight_decay) < 1.e-6
            assert decay_flags is False

        assert param.name == order_param.name

@@ -305,15 +305,15 @@ def test_order_params_1():
            opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, bias_params+conv_params):
        if 'conv' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy())
            assert weight_decay == 0.01
            assert abs(weight_decay.asnumpy() - 0.01) < 1.e-6
            assert decay_flags is True
        elif 'bias' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(0.01, mstype.float32).asnumpy())
            assert weight_decay == 0.0
            assert abs(weight_decay.asnumpy()) < 1.e-6
            assert decay_flags is False
        else:
            assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy())
            assert weight_decay == 0.0
            assert abs(weight_decay.asnumpy()) < 1.e-6
            assert decay_flags is False

        assert param.name == order_param.name

@@ -344,15 +344,15 @@ def test_order_params_2():
            opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params):
        if 'conv' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == conv_weight_decay
            assert abs(weight_decay.asnumpy() - conv_weight_decay) < 1.e-6
            assert decay_flags is True
        elif 'fc1' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert abs(weight_decay.asnumpy() - default_wd) < 1.e-6
            assert decay_flags is False
        else:
            assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert abs(weight_decay.asnumpy() - default_wd) < 1.e-6
            assert decay_flags is False

        assert param.name == order_param.name

@@ -419,7 +419,7 @@ class TestPipelineSplitWithNoOptimizer:
        self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator',
                              target_count=2)
        self.cat_fp16_from_ir(pattern='Cast(',
                              target_count=14)
                              target_count=15)

    def test_pipeline_with_micro_batch_no_parallel_optimizer(self):
        """

@@ -438,7 +438,7 @@ class TestPipelineSplitWithNoOptimizer:
        self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator',
                              target_count=2)
        self.cat_fp16_from_ir(pattern='Cast(',
                              target_count=26)
                              target_count=27)

def test_pipeline_split_stage0_device_num_48():
    """

@@ -30,6 +30,16 @@ def vm_impl_assign(self):
        return x
    return vm_impl


@vm_impl_getters.register(P.AssignAdd)
def vm_impl_assignadd(self):
    """Generate vm_impl function for AssignAdd"""
    def vm_impl(x, value, u=None):
        # Simplified VM stub: like Assign, it stores `value` into `x`.
        x.assign_value(value)
        return x
    return vm_impl


@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Generate vm_impl function for ExpandDims"""