forked from mindspore-Ecosystem/mindspore
add fused adafactor primitive
This commit is contained in:
@ -50,8 +50,8 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
bool enable_weight_decay_{false};
bool need_factor_{false};
size_t elem_num_{0};
size_t last_row_dim_size_{0};
size_t last_col_dim_size_{0};
size_t last_row_dim_size_{1};
size_t last_col_dim_size_{1};
TypeId param_dtype_{kTypeUnknown};
enum InputEnum {
@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.log import logging
from mindspore.common.initializer import initializer
@ -27,22 +28,6 @@ from mindspore.nn.optim.optimizer import opt_init_args_register
from .optimizer import Optimizer
def _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps):
    """Compute the learning rate applied to one parameter.

    Args:
        step: Current optimizer step; it is multiplied by floats and passed to ``P.Sqrt``.
        rms: Root mean square of the parameter. It bounds the scale factor and fixes the
            shape of the result through ``F.ones_like``.
        learning_rate: Learning rate used unchanged when `relative_step` is False.
        relative_step (bool): If True, replace `learning_rate` with a step-dependent value.
        warmup_init (bool): If True, the step-dependent cap grows as ``1e-6 * step``;
            otherwise the cap is the constant ``1e-2``. Only read when `relative_step` is True.
        scale_parameter (bool): If True, multiply the result by ``max(eps[1], rms)``.
        eps: Indexable pair of epsilons; only ``eps[1]`` is read here.

    Returns:
        The learning rate multiplied by the parameter scale, broadcast to the shape of `rms`.
    """
    rel_step_sz = learning_rate
    if relative_step:
        # The two caps are alternatives: assigning both in sequence would silently
        # discard the warm-up schedule and always use the constant cap.
        if warmup_init:
            min_step = 1e-6 * step * 1.0
        else:
            min_step = 1e-2 * 1.0
        rel_step_sz = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
    param_scale = 1.0
    if scale_parameter:
        param_scale = P.Maximum()(eps[1], rms)
    return rel_step_sz * param_scale * F.ones_like(rms)
def _rms(update_tensor):
    """Return the root mean square of `update_tensor`.

    The tensor is squared element-wise, averaged with ``P.ReduceMean(False)``,
    and the square root of that mean is returned.
    """
    reduce_mean = P.ReduceMean(False)
    squared = F.square(update_tensor)
    mean_of_squares = reduce_mean(squared)
    return F.sqrt(mean_of_squares)
@ -59,18 +44,14 @@ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
return P.Mul()(r_factor, c_factor)
_adam_opt = C.MultitypeFuncGraph("adam_opt")
_adafactor_opt = C.MultitypeFuncGraph("adafactor_opt")
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool",
"Bool", "Bool", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
weight_decay, scale_lr, scale_parameter, relative_step,
warmup_init, compression, use_first_moment, weight_decay_flag,
learning_rate, step, grad, param,
exp_avg, exp_avg_sq_row,
exp_avg_sq_col, exp_avg_sq):
@_adafactor_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool", "Bool", "Bool", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, scale_parameter,
compression, use_first_moment, weight_decay_flag, learning_rate,
grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
"""Apply ada factor optimizer to the weight parameter using Tensor."""
success = True
grad_dtype = F.dtype(grad)
@ -84,38 +65,24 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
factored = len(grad_shape) >= 2
# State Initialization
exp_avg_update = exp_avg
exp_avg_sq_update = exp_avg_sq
exp_avg_sq_row_update = exp_avg_sq_row
exp_avg_sq_col_update = exp_avg_sq_col
if use_first_moment:
if compression:
exp_avg_update = F.cast(exp_avg, mstype.float16)
if factored:
exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
if scale_lr:
if scale_parameter:
rms = _rms(p_data_fp32)
learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
param_scale = P.Maximum()(eps[1], rms)
learning_rate_update = learning_rate * param_scale * F.ones_like(rms)
learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
learning_rate_update = learning_rate * 1.0
learning_rate_update = learning_rate
beta2t = 1.0 - P.Pow()(step, decay_rate)
update = (grad ** 2) + eps[0]
if factored:
exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
@ -124,6 +91,7 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
update = P.Mul()(update, grad)
exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
update = update * (1.0 - beta2t)
exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
exp_avg_sq_update = F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
@ -135,8 +103,9 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
update = P.Mul()(P.Div()(update, update_coff), learning_rate_update)
if use_first_moment:
exp_avg_update = exp_avg
if compression:
exp_avg_update = F.cast(exp_avg_update, grad_dtype)
exp_avg_update = F.cast(exp_avg, grad_dtype)
exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
update = F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))
@ -144,18 +113,27 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
p_data_fp32 = P.Sub()(p_data_fp32, update)
P.Assign()(param, F.cast(p_data_fp32, F.dtype(param)))
return success
return F.depend(success, P.Assign()(param, F.cast(p_data_fp32, F.dtype(param))))
def trans_to_tensor(paras, is_tuple=False, fp32=True):
if paras is None or isinstance(paras, bool):
return paras
@_adafactor_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor")
def _run_fused_ada_factor(fused_ada_factor, eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                          grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
    """Apply the fused AdaFactor operator to a single parameter.

    All tensor arguments are forwarded to `fused_ada_factor` unchanged and in the
    same order. The returned flag is tied to the operator's output with
    ``F.depend``, so the flag depends on the fused call.

    Returns:
        ``True``, made dependent on the output of `fused_ada_factor`.
    """
    fused_output = fused_ada_factor(
        eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
        grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
    return F.depend(True, fused_output)
def trans_to_tensor(param, is_tuple=False, fp32=True):
    """Convert a value, or each element of a sequence, into a ``Tensor``.

    Args:
        param: Value to convert. ``None`` and ``bool`` values are returned unchanged.
        is_tuple (bool): If True, iterate over `param` and convert every element,
            returning the results as a tuple. Default: False.
        fp32 (bool): If True, build tensors as ``mstype.float32``; otherwise as
            ``mstype.float16``. Default: True.

    Returns:
        `param` itself when it is ``None`` or a ``bool``; a tuple of tensors when
        `is_tuple` is True; otherwise a single tensor.
    """
    # Flags and unset options pass through untouched so callers can convert
    # every hyper-parameter uniformly.
    if param is None or isinstance(param, bool):
        return param
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        return tuple(Tensor(ele, data_type) for ele in param)
    return Tensor(param, data_type)
class AdaFactor(Optimizer):
@ -344,9 +322,17 @@ class AdaFactor(Optimizer):
self.relative_step = relative_step
self.warmup_init = warmup_init
self.compression = compression
if not self.scale_lr:
self.scale_parameter = False
self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
self.fused_ada_factor = P.FusedAdaFactor(enable_scale_parameter=self.scale_parameter,
if context.get_context("device_target") == "CPU":
self.use_fused_ada_factor = True
self.use_fused_ada_factor = False
print("AdaFactor init completed", self.learning_rate)
def init_ada_factor_state(self, beta1):
@ -361,35 +347,31 @@ class AdaFactor(Optimizer):
self.exp_avg_sq = []
self.exp_avg_sq_col = []
self.exp_avg_sq_row = []
for paras in self.parameters:
paras_dtype = paras.dtype
paras_shape = paras.shape
paras_name =
if len(paras_shape) > 1:
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
if self.compression:
self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
for param in self.parameters:
param_dtype = param.dtype
param_shape = param.shape
param_name =
if len(param_shape) > 1:
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
if self.compression:
self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=mstype.float16),
self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=param_dtype),
self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
@ -406,13 +388,25 @@ class AdaFactor(Optimizer):
def construct(self, gradients):
lr = self.get_lr()
step = F.assign_add(self.step, 1)
success = self.hyper_map(F.partial(_adam_opt, self.eps, self.clip_threshold, self.decay_rate,
self.beta1, self.weight_decay, self.scale_lr,
self.scale_parameter, self.relative_step,
self.warmup_init, self.compression, self.use_first_moment,
self.weight_decay_flag, lr, step),
gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
self.exp_avg_sq_col, self.exp_avg_sq)
if self.scale_lr and self.relative_step:
if self.warmup_init:
min_step = 1e-6 * step
min_step = 1e-2
lr = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
beta2t = 1.0 - P.Pow()(step, self.decay_rate)
if self.use_fused_ada_factor:
success = self.hyper_map(F.partial(_adafactor_opt, self.fused_ada_factor, self.eps, self.clip_threshold,
self.beta1, beta2t, self.weight_decay, lr),
gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
self.exp_avg_sq_col, self.exp_avg_sq)
success = self.hyper_map(F.partial(_adafactor_opt, self.eps, self.clip_threshold, self.beta1, beta2t,
self.weight_decay, self.scale_parameter, self.compression,
self.use_first_moment, self.weight_decay_flag, lr),
gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
self.exp_avg_sq_col, self.exp_avg_sq)
return success
@ -423,3 +417,8 @@ class AdaFactor(Optimizer):
optimizer operation.
if value == 'CPU':
self.fused_ada_factor.add_prim_attr("primitive_target", "CPU")
self.use_fused_ada_factor = True
self.use_fused_ada_factor = False
@ -44,7 +44,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr, Ger,
@ -177,6 +177,7 @@ __all__ = [
@ -254,6 +254,7 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
Supported Platforms:
def __init__(self):
"""Initialize LambApplyOptimizerAssign"""
@ -316,6 +317,7 @@ class LambApplyWeightAssign(PrimitiveWithInfer):
Supported Platforms:
def __init__(self):
"""Initialize LambApplyWeightAssign"""
@ -558,3 +560,132 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
"decay": decay_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32],, True)
return var_dtype, m_dtype, v_dtype
class FusedAdaFactor(PrimitiveWithInfer):
Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.
The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
Cost <https://arxiv.org/abs/1804.04235>`_.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Adafactor for weight vector are as follows,
.. math::
\begin{array}{l} \\
\alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
\hat{V}_{t}=\hat{\beta}_{2 t} \hat{V}_{t-1}+\left(1-\hat{\beta}_{2 t}\right)\left(G_{t}^{2}+\epsilon_{1} 1_{n}\right) \\
U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
\hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
\end{array}
Adafactor for weight matrices are as follows,
.. math::
\begin{array}{l} \\
\alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
R_{t}=\hat{\beta}_{2 t} R_{t-1}+\left(1-\hat{\beta}_{2 t}\right)\left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) 1_{m} \\
C_{t}=\hat{\beta}_{2 t} C_{t-1}+\left(1-\hat{\beta}_{2 t}\right) 1_{n}^{\top}\left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) \\
\hat{V}_{t}=R_{t} C_{t} / 1_{n}^{\top} R_{t} \\
U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
\hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
\end{array}
Where RMS is:
.. math::
\operatorname{RMS}\left(U_{t}\right)=\operatorname{RMS}_{x \in X}\left(u_{x t}\right)= \\
\sqrt{\operatorname{Mean}_{x \in X}\left(\frac{\left(g_{x t}\right)^{2}}{\hat{v}_{x t}}\right)}
:math:`x` is each individual parameter,
:math:`t` is assumed to be the current number of steps,
:math:`\alpha_{t}` is the learning rate,
:math:`f(X)` is the loss function,
:math:`\epsilon_{1}` and :math:`\epsilon_{2}` are small positive numbers to prevent errors,
:math:`d` is the clipping threshold,
:math:`\hat{\beta}_{2 t}` is the moment decay,
:math:`\rho` is the relative step size,
:math:`R` is the running averages of the row sums of the squared gradient,
:math:`C` is the running averages of the column sums of the squared gradient.
enable_weight_decay (bool): If True, enable weight decay. default: False
enable_first_moment (bool): If True, enable first moment. default: False
enable_scale_parameter (bool): If True, enable scale learning rate using parameter. default: False
- **epsilon** (Tensor) - input epsilon pair.
- **clip_threshold** (float) - The threshold of root mean square of final gradient update.
- **beta1** (float) - The exponential decay rate for the 1st moment estimations.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
- **weight_decay** (float) - The weight decay value, must be a scalar tensor with float data type.
- **learning_rate** (float) - The learning rate value.
- **gradient** (Tensor) - Gradient.
- **param** (Tensor) - Weights to be updated.
- **exp_avg** (Tensor) - The exponential moving average of 1st moment optimizer state.
- **exp_avg_sq_row** (Tensor) - The exponential moving average of the row factor of the squared gradient.
- **exp_avg_sq_col** (Tensor) - The exponential moving average of the column factor of the squared gradient.
- **exp_avg_sq** (Tensor) - The exponential moving average of the squared gradient.
- **dummy_param** (Tensor) - The same shape and data type as `param`.
Supported Platforms:
>>> import numpy as np
>>> import mindspore.context as context
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> from mindspore import Tensor, Parameter
>>> from mindspore import dtype as mstype
>>> param_shape = [2, 3, 2]
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.opt = ops.FusedAdaFactor()
... self.param = Parameter(Tensor(np.ones(param_shape), mstype.float32), name="param")
... self.exp_avg = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg")
... self.exp_avg_sq = Parameter(Tensor(np.zeros(param_shape), mstype.float32), name="exp_avg_sq")
... self.exp_avg_sq_row = Parameter(Tensor(np.zeros([2, 3]), mstype.float32), name="exp_avg_sq_row")
... self.exp_avg_sq_col = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="exp_avg_sq_col")
... def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad):
... out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param,
... self.exp_avg, self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq)
... return out
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> net = Net()
>>> gradient = Tensor(np.ones(param_shape), mstype.float32)
>>> net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
>>> print(net.param.asnumpy())
def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False):
self.add_prim_attr('side_effect_mem', True)
validator.check_value_type("enable_scale_parameter", enable_scale_parameter, [bool],
validator.check_value_type("enable_first_moment", enable_first_moment, [bool],
validator.check_value_type("enable_weight_decay", enable_weight_decay, [bool],
def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape,
learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape,
exp_avg_sq_col_shape, exp_avg_sq_shape):
validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ,
return param_shape
def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type,
learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
exp_avg_sq_col_type, exp_avg_sq_type):
return param_type
@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == ==
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
param_shape = [2, 3, 2]
class Net(nn.Cell):
    """Cell that owns one parameter plus its AdaFactor state and runs ``FusedAdaFactor`` on them."""

    def __init__(self):
        super().__init__()
        self.opt = ops.FusedAdaFactor()

        def _float32_parameter(values, name):
            # Every tensor held by this cell is a float32 Parameter.
            return Parameter(Tensor(values, mstype.float32), name=name)

        self.param = _float32_parameter(np.ones(param_shape), "param")
        self.exp_avg = _float32_parameter(np.zeros(param_shape), "exp_avg")
        self.exp_avg_sq = _float32_parameter(np.zeros(param_shape), "exp_avg_sq")
        self.exp_avg_sq_row = _float32_parameter(np.zeros([2, 3]), "exp_avg_sq_row")
        self.exp_avg_sq_col = _float32_parameter(np.zeros([2, 2]), "exp_avg_sq_col")

    def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad):
        """Forward the hyper-parameters, the gradient and the stored state to the fused operator."""
        return self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad,
                        self.param, self.exp_avg, self.exp_avg_sq_row, self.exp_avg_sq_col,
                        self.exp_avg_sq)
def test_adafactor():
    """
    Feature: AdaFactor
    Description: Test AdaFactor
    Expectation: Run success
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = Net()
    gradient = Tensor(np.ones(param_shape), mstype.float32)
    net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
    # Compare the absolute error: with a signed difference, any parameter value
    # below the expectation would satisfy the tolerance check.
    diff = np.abs(net.param.asnumpy() - np.ones(param_shape) * 0.97)
    assert np.all(diff < 1e-3)
@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from import LeNet
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def test_lenet():
    """
    Feature: AdaFactor
    Description: Test AdaFactor
    Expectation: Run lenet success
    """
    data = Tensor(np.ones([32, 3, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    net.batch_size = 32
    learning_rate = 0.01
    optimizer = nn.AdaFactor(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate,
                             scale_parameter=False, relative_step=False, beta1=0)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    loss = []
    for _ in range(10):
        res = train_network(data, label)
        # Record each step's loss; the final assertion reads the last entry and
        # would raise IndexError on an empty list.
        loss.append(res.asnumpy())
    assert np.all(loss[-1] < 0.1)
Reference in New Issue