add fused adafactor primitive
commit 1b37d4302d (parent 45bc08fd8f)
@@ -50,8 +50,8 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
  bool enable_weight_decay_{false};
  bool need_factor_{false};
  size_t elem_num_{0};
  size_t last_row_dim_size_{0};
  size_t last_col_dim_size_{0};
  size_t last_row_dim_size_{1};
  size_t last_col_dim_size_{1};
  TypeId param_dtype_{kTypeUnknown};

  enum InputEnum {
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""adafactor"""
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.log import logging
from mindspore.common.initializer import initializer
@@ -27,22 +28,6 @@ from mindspore.nn.optim.optimizer import opt_init_args_register
from .optimizer import Optimizer


def _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps):
    """update optimizer learning rate"""
    rel_step_sz = learning_rate
    if relative_step:
        if warmup_init:
            min_step = 1e-6 * step * 1.0
        else:
            min_step = 1e-2 * 1.0

        rel_step_sz = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
    param_scale = 1.0
    if scale_parameter:
        param_scale = P.Maximum()(eps[1], rms)
    return rel_step_sz * param_scale * F.ones_like(rms)


def _rms(update_tensor):
    """calculate rms"""
    return F.sqrt(P.ReduceMean(False)(F.square(update_tensor)))
@@ -59,18 +44,14 @@ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    return P.Mul()(r_factor, c_factor)


_adam_opt = C.MultitypeFuncGraph("adam_opt")
_adafactor_opt = C.MultitypeFuncGraph("adafactor_opt")


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool",
                    "Bool", "Bool", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
                             weight_decay, scale_lr, scale_parameter, relative_step,
                             warmup_init, compression, use_first_moment, weight_decay_flag,
                             learning_rate, step, grad, param,
                             exp_avg, exp_avg_sq_row,
                             exp_avg_sq_col, exp_avg_sq):
@_adafactor_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool", "Bool", "Bool", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, scale_parameter,
                             compression, use_first_moment, weight_decay_flag, learning_rate,
                             grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
    """Apply ada factor optimizer to the weight parameter using Tensor."""
    success = True
    grad_dtype = F.dtype(grad)
@@ -84,38 +65,24 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,

    factored = len(grad_shape) >= 2

    # State Initialization
    exp_avg_update = exp_avg
    exp_avg_sq_update = exp_avg_sq
    exp_avg_sq_row_update = exp_avg_sq_row
    exp_avg_sq_col_update = exp_avg_sq_col

    if use_first_moment:
        if compression:
            exp_avg_update = F.cast(exp_avg, mstype.float16)

    if factored:
        exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
    else:
        exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)

    if scale_lr:
    if scale_parameter:
        rms = _rms(p_data_fp32)
        learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
        param_scale = P.Maximum()(eps[1], rms)
        learning_rate_update = learning_rate * param_scale * F.ones_like(rms)
        learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
    else:
        learning_rate_update = learning_rate * 1.0
        learning_rate_update = learning_rate

    beta2t = 1.0 - P.Pow()(step, decay_rate)
    update = (grad ** 2) + eps[0]

    if factored:
        exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
        update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
        exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
        exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))

        exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
        exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
        update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
        exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
@@ -124,6 +91,7 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
        update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
        update = P.Mul()(update, grad)
    else:
        exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
        update = update * (1.0 - beta2t)
        exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
        exp_avg_sq_update = F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
@@ -135,8 +103,9 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
    update = P.Mul()(P.Div()(update, update_coff), learning_rate_update)

    if use_first_moment:
        exp_avg_update = exp_avg
        if compression:
            exp_avg_update = F.cast(exp_avg_update, grad_dtype)
        exp_avg_update = F.cast(exp_avg, grad_dtype)
        exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
        update = F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))
@@ -144,18 +113,27 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
    p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
    p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
    p_data_fp32 = P.Sub()(p_data_fp32, update)
    P.Assign()(param, F.cast(p_data_fp32, F.dtype(param)))
    return success
    return F.depend(success, P.Assign()(param, F.cast(p_data_fp32, F.dtype(param))))


def trans_to_tensor(paras, is_tuple=False, fp32=True):
    if paras is None or isinstance(paras, bool):
        return paras
@_adafactor_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor")
def _run_fused_ada_factor(fused_ada_factor, eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                          grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
    success = True
    ret = fused_ada_factor(eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                           grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
    return F.depend(success, ret)


def trans_to_tensor(param, is_tuple=False, fp32=True):
    if param is None or isinstance(param, bool):
        return param
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        new_paras = [Tensor(ele, data_type) for ele in paras]
        return tuple(new_paras)
    return Tensor(paras, data_type)
        new_param = [Tensor(ele, data_type) for ele in param]
        return tuple(new_param)
    return Tensor(param, data_type)


class AdaFactor(Optimizer):
@@ -344,9 +322,17 @@ class AdaFactor(Optimizer):
        self.relative_step = relative_step
        self.warmup_init = warmup_init
        self.compression = compression

        if not self.scale_lr:
            self.scale_parameter = False
        self.init_ada_factor_state(beta1)
        self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
        self.fused_ada_factor = P.FusedAdaFactor(enable_scale_parameter=self.scale_parameter,
                                                 enable_first_moment=self.use_first_moment,
                                                 enable_weight_decay=self.weight_decay_flag)
        if context.get_context("device_target") == "CPU":
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False
        print("AdaFactor init completed", self.learning_rate)

    def init_ada_factor_state(self, beta1):
@@ -361,35 +347,31 @@ class AdaFactor(Optimizer):
        self.exp_avg_sq = []
        self.exp_avg_sq_col = []
        self.exp_avg_sq_row = []
        for paras in self.parameters:
            paras_dtype = paras.dtype
            paras_shape = paras.shape
            paras_name = paras.name
            if len(paras_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
                                                                 dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))
        for param in self.parameters:
            param_dtype = param.dtype
            param_shape = param.shape
            param_name = param.name
            if len(param_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
                                                                 dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))
                self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                 name="exp_avg_sq_{}".format(param_name)))

            else:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))

                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(param_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=param_dtype),
                                                     name="exp_avg_sq_{}".format(param_name)))

        self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
        self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
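For a concrete example of the factored state layout built here: a parameter of shape [2, 3, 2] gets an exp_avg_sq_row of shape [2, 3] (shape[:-1]) and an exp_avg_sq_col of shape [2, 2] (shape[:-2] + shape[-1:]), which is exactly the layout used by the FusedAdaFactor test added later in this commit.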
@@ -406,13 +388,25 @@ class AdaFactor(Optimizer):
    def construct(self, gradients):
        lr = self.get_lr()
        step = F.assign_add(self.step, 1)
        success = self.hyper_map(F.partial(_adam_opt, self.eps, self.clip_threshold, self.decay_rate,
                                           self.beta1, self.weight_decay, self.scale_lr,
                                           self.scale_parameter, self.relative_step,
                                           self.warmup_init, self.compression, self.use_first_moment,
                                           self.weight_decay_flag, lr, step),
                                 gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
                                 self.exp_avg_sq_col, self.exp_avg_sq)
        if self.scale_lr and self.relative_step:
            if self.warmup_init:
                min_step = 1e-6 * step
            else:
                min_step = 1e-2
            lr = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
        beta2t = 1.0 - P.Pow()(step, self.decay_rate)

        if self.use_fused_ada_factor:
            success = self.hyper_map(F.partial(_adafactor_opt, self.fused_ada_factor, self.eps, self.clip_threshold,
                                               self.beta1, beta2t, self.weight_decay, lr),
                                     gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq)
        else:
            success = self.hyper_map(F.partial(_adafactor_opt, self.eps, self.clip_threshold, self.beta1, beta2t,
                                               self.weight_decay, self.scale_parameter, self.compression,
                                               self.use_first_moment, self.weight_decay_flag, lr),
                                     gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq)

        return success
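For intuition, the relative-step schedule that `construct` computes above is just a scalar function of the step count. Here is a minimal NumPy sketch of it (purely illustrative, not part of the commit; the helper name is made up):

```python
import numpy as np

def relative_step_lr(step, warmup_init=False):
    # Mirrors: lr = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
    min_step = 1e-6 * step if warmup_init else 1e-2
    return min(min_step, 1.0 / np.sqrt(step))

# Without warmup the rate is capped at 1e-2 and decays as 1/sqrt(step)
# once step exceeds 1e4; with warmup_init it first ramps up linearly.
print([relative_step_lr(s) for s in (1, 100, 10_000, 1_000_000)])
```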
@@ -423,3 +417,8 @@ class AdaFactor(Optimizer):
            optimizer operation.
        """
        self._set_base_target(value)
        if value == 'CPU':
            self.fused_ada_factor.add_prim_attr("primitive_target", "CPU")
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False
@@ -44,7 +44,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
                        MakeRefKey,
                        FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)
                        FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor)

from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                       BitwiseAnd, BitwiseOr, Ger,
@@ -177,6 +177,7 @@ __all__ = [
    'FusedSparseAdam',
    'FusedSparseLazyAdam',
    'AdamNoUpdateParam',
    'FusedAdaFactor',
    'Softplus',
    'Softmax',
    'Softsign',
@@ -254,6 +254,7 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyOptimizerAssign"""
@@ -316,6 +317,7 @@ class LambApplyWeightAssign(PrimitiveWithInfer):
    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyWeightAssign"""
@@ -558,3 +560,132 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
                "decay": decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype


class FusedAdaFactor(PrimitiveWithInfer):
    r"""
    Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.

    The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
    Cost <https://arxiv.org/abs/1804.04235>`_.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    The Adafactor update for a weight vector is as follows,

    .. math::
        \begin{array}{l} \\
            \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
            G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
            \hat{V}_{t}=\hat{\beta}_{2} \hat{V}_{t-1}+\left(1-\hat{\beta}_{2_{t}}\right)\left(G_{t}^{2}+ \\
            \epsilon_{1} 1_{n}\right) \\
            U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
            \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
            X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    The Adafactor update for a weight matrix is as follows,

    .. math::
        \begin{array}{l} \\
            \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
            G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
            R_{t}=\hat{\beta}_{2 t} R_{t-1}+\left(1-\hat{\beta}_{2 t}\right)\left(G_{t}^{2}+ \\
            \epsilon_{1} 1_{n} 1_{m}^{\top}\right) 1_{m} \\
            C_{t}=\hat{\beta}_{2 t} C_{t-1}+\left(1-\hat{\beta}_{2 t}\right) 1_{n}^{\top}\left(G_{t}^{2}+ \\
            \epsilon_{1} 1_{n} 1_{m}^{\top}\right) \\
            \hat{V}_{t}=R_{t} C_{t} / 1_{n}^{\top} R_{t} \\
            U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
            \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
            X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Where RMS is:

    .. math::
        \operatorname{RMS}\left(U_{t}\right)=\operatorname{RMS}_{x \in X}\left(u_{x t}\right)= \\
        \sqrt{\operatorname{Mean}_{x \in X}\left(\frac{\left(g_{x t}\right)^{2}}{\hat{v}_{x t}}\right)}

    :math:`x` is each individual parameter,
    :math:`t` is the current number of steps,
    :math:`\alpha_{t}` is the learning rate,
    :math:`f(X)` is the loss function,
    :math:`\epsilon_1` and :math:`\epsilon_2` are small positive numbers that prevent numerical errors,
    :math:`d` is the clipping threshold,
    :math:`\beta_{2}` is the moment decay,
    :math:`\rho` is the relative step size,
    :math:`R` is the running average of the row sums of the squared gradient,
    :math:`C` is the running average of the column sums of the squared gradient.

    Args:
        enable_weight_decay (bool): If True, enable weight decay. Default: False.
        enable_first_moment (bool): If True, enable the first-moment (exp_avg) update. Default: False.
        enable_scale_parameter (bool): If True, scale the learning rate by the RMS of the parameter. Default: False.

    Inputs:
        - **epsilon** (Tensor) - The pair of regularization constants :math:`(\epsilon_1, \epsilon_2)`.
        - **clip_threshold** (float) - The threshold on the root mean square of the final gradient update.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimation.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimation.
        - **weight_decay** (float) - The weight decay value, must be a scalar tensor with float data type.
        - **learning_rate** (float) - The learning rate value.
        - **gradient** (Tensor) - Gradient.
        - **param** (Tensor) - Weights to be updated.
        - **exp_avg** (Tensor) - The exponential moving average of the gradient (1st moment optimizer state).
        - **exp_avg_sq_row** (Tensor) - The exponential moving average of the row factor of the squared gradient.
        - **exp_avg_sq_col** (Tensor) - The exponential moving average of the column factor of the squared gradient.
        - **exp_avg_sq** (Tensor) - The exponential moving average of the squared gradient.

    Outputs:
        - **dummy_param** (Tensor) - The same shape and data type as `param`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.context as context
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> param_shape = [2, 3, 2]
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.opt = ops.FusedAdaFactor()
        ...         self.param = Parameter(Tensor(np.ones(param_shape), mstype.float32), name="param")
        ...         self.exp_avg = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg")
        ...         self.exp_avg_sq = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg_sq")
        ...         self.exp_avg_sq_row = Parameter(Tensor(np.zeros([2, 3]), mstype.float32), name="exp_avg_sq_row")
        ...         self.exp_avg_sq_col = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="exp_avg_sq_col")
        ...
        ...     def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad):
        ...         out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param,
        ...                        self.exp_avg, self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq)
        ...         return out
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        >>> net = Net()
        >>> gradient = Tensor(np.ones(param_shape), mstype.float32)
        >>> net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
        >>> print(net.param.asnumpy())
    """

    @prim_attr_register
    def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False):
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("enable_scale_parameter", enable_scale_parameter, [bool], self.name)
        validator.check_value_type("enable_first_moment", enable_first_moment, [bool], self.name)
        validator.check_value_type("enable_weight_decay", enable_weight_decay, [bool], self.name)

    def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape,
                    learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape,
                    exp_avg_sq_col_shape, exp_avg_sq_shape):
        validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ, self.name)
        return param_shape

    def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type,
                    learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
                    exp_avg_sq_col_type, exp_avg_sq_type):
        return param_type
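To make the factored second-moment bookkeeping above concrete, here is a minimal NumPy sketch of one update step for a 2-D weight, written directly from the docstring formulas (illustrative only; the function name and default values are invented for the example, and this is not the CPU kernel's code):

```python
import numpy as np

def adafactor_matrix_step(param, grad, row_avg, col_avg,
                          lr=0.03, beta2t=0.8, eps1=1e-30, clip_threshold=1.0):
    """One factored Adafactor step for a 2-D weight, following the docstring math."""
    sq = grad ** 2 + eps1
    row_avg = beta2t * row_avg + (1 - beta2t) * sq.mean(axis=-1)   # R_t (row statistics)
    col_avg = beta2t * col_avg + (1 - beta2t) * sq.mean(axis=-2)   # C_t (column statistics)
    # Rank-1 reconstruction of the second moment: V_hat ~= R C / sum(R)
    v_hat = np.outer(row_avg / row_avg.mean(), col_avg)
    update = grad / np.sqrt(v_hat)                                 # U_t
    rms = np.sqrt(np.mean(update ** 2))
    update = update / max(1.0, rms / clip_threshold)               # clipped U_t
    return param - lr * update, row_avg, col_avg

# Example: a 4x3 weight with unit gradients and zero-initialized statistics.
w = np.ones((4, 3), dtype=np.float32)
g = np.ones_like(w)
w, r, c = adafactor_matrix_step(w, g, np.zeros(4), np.zeros(3))
print(w)   # roughly 0.97 everywhere for these inputs
```

Using row and column means instead of the paper's sums gives the same reconstruction, since the normalization by the total of R cancels the extra factor.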
@@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype

param_shape = [2, 3, 2]


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.opt = ops.FusedAdaFactor()
        self.param = Parameter(Tensor(np.ones(param_shape), mstype.float32), name="param")
        self.exp_avg = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg")
        self.exp_avg_sq = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg_sq")
        self.exp_avg_sq_row = Parameter(Tensor(np.zeros([2, 3]), mstype.float32), name="exp_avg_sq_row")
        self.exp_avg_sq_col = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="exp_avg_sq_col")

    def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad):
        out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param, self.exp_avg,
                       self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adafactor():
    '''
    Feature: AdaFactor
    Description: Test AdaFactor
    Expectation: Run success
    '''
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = Net()
    gradient = Tensor(np.ones(param_shape), mstype.float32)
    net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
    diff = net.param.asnumpy() - np.ones(param_shape) * 0.97
    assert np.all(diff < 1e-3)
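A rough hand check of that expectation, assuming the kernel follows the docstring math: with all-ones gradients, zero-initialized statistics, and beta2 = 0.8, the factored second moment is about (1 - 0.8) * 1 = 0.2 in every position, so the raw update is 1/sqrt(0.2) ≈ 2.236; its RMS is also 2.236, the clip threshold of 1.0 scales it back to exactly 1, and the parameter moves from 1.0 to 1.0 - 0.03 * 1 = 0.97, which is the value the assertion compares against.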
@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from tests.st.networks.models.lenet import LeNet

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lenet():
    '''
    Feature: AdaFactor
    Description: Test AdaFactor
    Expectation: Run lenet success
    '''
    data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    net.batch_size = 32
    learning_rate = 0.01
    optimizer = nn.AdaFactor(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate,
                             scale_parameter=False, relative_step=False, beta1=0)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    loss = []
    for _ in range(10):
        res = train_network(data, label)
        loss.append(res.asnumpy())
    assert np.all(loss[-1] < 0.1)