sparse optimizer

Jiaqi 2020-10-09 17:00:07 +08:00
parent 479cb89e8b
commit a30ccea62c
13 changed files with 360 additions and 41 deletions

View File

@ -27,6 +27,8 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@ -85,31 +87,80 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
return gradient
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
gradient, params, moment1, moment2, ps_parameter):
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
op_shape = P.Shape()
shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
shapes = (op_shape(params), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices), shapes), params))
else:
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
return success
if not target:
success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices))
else:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
scatter_add = P.ScatterAdd(use_locking)
assign_m = F.assign(m, op_mul(beta1, m))
assign_v = F.assign(v, op_mul(beta2, v))
grad_indices = gradient.indices
grad_value = gradient.values
next_m = scatter_add(m,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
next_v = scatter_add(v,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
if use_nesterov:
m_temp = next_m * _scaler_ten
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
div_value = scatter_add(m,
op_mul(grad_indices, _scaler_one),
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
param_update = div_value / (op_sqrt(next_v) + eps)
m_recover = F.assign(m, m_temp / _scaler_ten)
F.control_depend(m_temp, assign_m_nesterov)
F.control_depend(assign_m_nesterov, div_value)
F.control_depend(param_update, m_recover)
else:
param_update = next_m / (op_sqrt(next_v) + eps)
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
next_param = params - lr_t * param_update
F.control_depend(assign_m, next_m)
F.control_depend(assign_v, next_v)
success = F.depend(success, F.assign(params, next_param))
success = F.depend(success, F.assign(m, next_m))
success = F.depend(success, F.assign(v, next_v))
return success
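For reference, the composed-operator branch above runs when `target` (i.e. `_is_device`) is true, so the sparse update stays on the device instead of going through the fused CPU kernel. It is the standard Adam update on the dense parameter, with the sparse gradient accumulated into the moments by ScatterAdd; the `_scaler_ten` multiply/divide only protects `m` while the Nesterov numerator is built. A minimal NumPy sketch of the same math, reference only (the function name and dense-array calling convention are assumptions, and graph-mode dependency handling is omitted):

import numpy as np

def adam_sparse_reference(param, m, v, indices, values, beta1, beta2, eps,
                          lr, beta1_power, beta2_power, use_nesterov=False):
    """Reference-only sketch of the composed sparse Adam update above."""
    m *= beta1                                         # assign_m: m <- beta1 * m
    v *= beta2                                         # assign_v: v <- beta2 * v
    np.add.at(m, indices, (1.0 - beta1) * values)          # ScatterAdd into m
    np.add.at(v, indices, (1.0 - beta2) * values ** 2)     # ScatterAdd into v
    if use_nesterov:
        numerator = beta1 * m                          # built without clobbering m;
        np.add.at(numerator, indices, (1.0 - beta1) * values)  # the graph code stashes and restores m instead
    else:
        numerator = m
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    param -= lr_t * numerator / (np.sqrt(v) + eps)     # next_param, assigned back to params
    return param, m, v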
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient,
params, moment1, moment2, ps_parameter):
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter:
@ -161,8 +212,8 @@ class Adam(Optimizer):
To improve the performance of parameter groups, a customized order of parameters is supported.
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
The sparse feature is under continuous development. To execute the sparse strategy on the host,
set the target to the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@ -242,14 +293,16 @@ class Adam(Optimizer):
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = Tensor(eps, mstype.float32)
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self._is_device = True
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
self.sparse_opt.add_prim_attr("primitive", "CPU")
self._ps_pull = P.Pull()
self._ps_push = P.Push("Adam", [0, 1, 2])
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
@ -260,6 +313,7 @@ class Adam(Optimizer):
moment2 = self.moment2
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
lr = self.get_lr()
beta1_power = self.beta1_power * self.beta1
@ -268,14 +322,26 @@ class Adam(Optimizer):
self.beta2_power = beta2_power
if self.is_group_lr:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, params, moment1, moment2, self.ps_parameters)
else:
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, params, moment1, moment2, self.ps_parameters)
return success
@Optimizer.target.setter
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
self._is_device = (value != 'CPU')
self._target = value
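A usage sketch of the new setter (not part of the diff), mirroring the Adam tests added later in this commit; `NetWithSparseGatherV2` is the small sparse test network defined in the new test file at the end of the commit:

import numpy as np
from mindspore import Tensor, context
from mindspore.nn import Adam, TrainOneStepCell

context.set_context(enable_sparse=True)          # as in the new test file

net = NetWithSparseGatherV2()                    # forward pass uses SparseGatherV2
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1,
                 loss_scale=1024.0, weight_decay=0.9)
optimizer.target = 'CPU'        # run the sparse update on the host via FusedSparseAdam
train_network = TrainOneStepCell(net, optimizer)

indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
# the new tests below only compile this graph (via _executor.compile)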
class AdamWeightDecay(Optimizer):
"""

View File

@ -89,7 +89,8 @@ class FTRL(Optimizer):
To improve the performance of parameter groups, a customized order of parameters is supported.
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU.
The sparse feature is under continuous development. To execute the sparse strategy on the host,
set the target to the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@ -154,12 +155,14 @@ class FTRL(Optimizer):
self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.l1 = l1
self.l2 = l2
self.lr = learning_rate
self.lr_power = lr_power
if not self.is_group:
self.decay_flags = tuple((lambda: True)() for x in self.parameters)
self.hyper_map = C.HyperMap()
self.opt = P.ApplyFtrl(use_locking=use_locking)
self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
self.use_locking = use_locking
self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
self._ps_pull = P.Pull()
self._ps_push = P.Push("Ftrl", [0, 1, 2])
self._ps_push.add_prim_attr("init_accum", initial_accum)
@ -174,9 +177,26 @@ class FTRL(Optimizer):
linear = self.linear
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
lr = self.get_lr()
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.l1, self.l2, self.lr_power, lr),
linear, grads, params, moments, self.ps_parameters)
return success
@Optimizer.target.setter
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
if value == 'CPU':
self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
self.sparse_opt.add_prim_attr("primitive", "CPU")
else:
self.sparse_opt = P.SparseApplyFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
self._target = value
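Unlike Adam, the FTRL setter swaps between two primitives: 'CPU' rebuilds `sparse_opt` as FusedSparseFtrl pinned to the host, while 'Ascend' rebuilds it as SparseApplyFtrl for the device. A short usage sketch (not part of the diff), based on the FTRL tests later in this commit:

from mindspore.nn import FTRL, TrainOneStepCell

net = NetWithSparseGatherV2()                 # sparse test network from this commit
net.set_train()
optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
optimizer.target = 'CPU'       # sparse_opt -> FusedSparseFtrl, executed on the host
# optimizer.target = 'Ascend'  # sparse_opt -> SparseApplyFtrl, executed on the device
train_network = TrainOneStepCell(net, optimizer)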

View File

@ -27,31 +27,57 @@ from .optimizer import Optimizer
_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps,
lr, gradient, params, moment1, moment2, ps_parameter):
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
op_shape = P.Shape()
shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
shapes = (op_shape(params), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices), shapes), params))
else:
success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
return success
if not target:
success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices))
else:
op_gather = P.GatherV2()
op_sqrt = P.Sqrt()
scatter_add = P.ScatterAdd(use_locking)
scatter_update = P.ScatterUpdate(use_locking)
m_slice = op_gather(m, indices, 0)
v_slice = op_gather(v, indices, 0)
next_m = m_slice * beta1 + values * (1 - beta1)
next_v = v_slice * beta2 + values * values * (1 - beta2)
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
if use_nesterov:
m_temp = beta1 * next_m + values * (1 - beta1)
param_update = m_temp / (op_sqrt(next_v) + eps)
else:
param_update = next_m / (op_sqrt(next_v) + eps)
success = F.depend(success, scatter_add(params, indices, - lr_t * param_update))
success = F.depend(success, scatter_update(m, indices, next_m))
success = F.depend(success, scatter_update(v, indices, next_v))
return success
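For reference, the composed branch above (taken when the target is a device) is the lazy variant: the moments and the parameter are read and written only at the gathered rows. A NumPy sketch of the row-wise math, reference only (the function name is an assumption, and the indices are assumed to be already deduplicated by `_grad_sparse_indices_deduplicate`):

import numpy as np

def lazy_adam_rows_reference(param, m, v, indices, values, beta1, beta2, eps,
                             lr, beta1_power, beta2_power, use_nesterov=False):
    """Reference-only sketch: only the rows addressed by `indices` change."""
    m_slice = beta1 * m[indices] + (1.0 - beta1) * values           # GatherV2 + blend
    v_slice = beta2 * v[indices] + (1.0 - beta2) * values * values
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    if use_nesterov:
        numerator = beta1 * m_slice + (1.0 - beta1) * values
    else:
        numerator = m_slice
    param[indices] -= lr_t * numerator / (np.sqrt(v_slice) + eps)   # ScatterAdd of -lr_t * update
    m[indices] = m_slice                                            # ScatterUpdate
    v[indices] = v_slice                                            # ScatterUpdate
    return param, m, v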
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps,
lr, gradient, params, moment1, moment2, ps_parameter):
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
"""Apply lazy adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter:
@ -108,7 +134,7 @@ class LazyAdam(Optimizer):
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
Note that the sparse behavior is not equivalent to the
original Adam algorithm, as only the parameters at the current indices are updated. The sparse feature is under
continuous development. The sparse behavior is currently performed on the CPU.
continuous development. To execute the sparse strategy on the host, set the target to the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@ -191,14 +217,14 @@ class LazyAdam(Optimizer):
self.eps = Tensor(eps, mstype.float32)
self.use_nesterov = use_nesterov
self.use_locking = use_locking
self._is_device = True
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
self.sparse_opt.add_prim_attr("primitive", "CPU")
self._ps_pull = P.Pull()
self._ps_push = P.Push("Adam", [0, 1, 2])
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
@ -206,6 +232,7 @@ class LazyAdam(Optimizer):
def construct(self, gradients):
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
lr = self.get_lr()
self.beta1_power = self.beta1_power * self.beta1
@ -213,10 +240,22 @@ class LazyAdam(Optimizer):
if self.is_group_lr:
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
else:
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
return success
@Optimizer.target.setter
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
self._is_device = (value != 'CPU')
self._target = value

View File

@ -26,7 +26,6 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.context import ParallelMode
@ -105,6 +104,8 @@ class Optimizer(Cell):
weight_decay = self._preprocess_weight_decay(weight_decay)
self._unique = True
self._target = 'Ascend'
self.dynamic_lr = False
self.assignadd = None
self.global_step = None
@ -173,6 +174,30 @@ class Optimizer(Cell):
else:
self.optim_filter = (True,) * self.param_length
@property
def unique(self):
"""This method is to see whether to make uniqueThis method is read-only."""
return self._unique
@unique.setter
def unique(self, value):
"""Set whether the input value is unique."""
if not isinstance(value, bool):
raise TypeError("The value type must be bool, but got value type is {}".format(type(value)))
self._unique = value
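The flag feeds `_grad_sparse_indices_deduplicate` below: the Unique/UnsortedSegmentSum pass is applied only when the target is not 'CPU' and `unique` is True. A hypothetical use, for an already-constructed optimizer `opt` whose data pipeline guarantees unique indices:

opt.unique = False   # skip indices deduplication; only safe if the sparse
                     # gradient indices are already unique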
@property
def target(self):
"""This method is used to determine the value of target and whether the parameter update is performed on
the host or device. This method is read-only."""
return self._target
@target.setter
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
raise NotImplementedError
def decay_weight(self, gradients):
"""
Weight decay.
@ -217,6 +242,12 @@ class Optimizer(Cell):
return gradients
def _grad_sparse_indices_deduplicate(self, gradients):
""" In the case of using big operators, de duplicate the 'indexes' in gradients."""
if self._target != 'CPU' and self._unique:
gradients = self.map_(F.partial(_indices_deduplicate), gradients)
return gradients
def _preprocess_weight_decay(self, weight_decay):
"""Check weight decay, and convert int to float."""
if isinstance(weight_decay, (float, int)):
@ -514,7 +545,7 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")
@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
@ -532,6 +563,24 @@ def tensor_grad_scale_with_sparse(scale, grad):
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
@_indices_deduplicate.register("RowTensor")
def rowtensor_deduplicate_indices_slices(grad):
"""Unique the indices and sums the 'values' corresponding to the duplicate indices."""
indices = grad.indices
values = grad.values
unique_indices, index_position = P.Unique()(indices)
summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])
return RowTensor(unique_indices, summed_values, grad.dense_shape)
@_indices_deduplicate.register("Tensor")
def tensor_deduplicate_indice_slices(grad):
"""Return the input gradient directly in the dense sences."""
return grad
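A small NumPy illustration (not part of the diff) of what the RowTensor branch computes, i.e. Unique followed by UnsortedSegmentSum over the values. Note that np.unique sorts the indices, which P.Unique need not do; the summed rows still line up with `unique_indices`:

import numpy as np

indices = np.array([0, 3, 0, 2], dtype=np.int32)
values = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)

unique_indices, inverse = np.unique(indices, return_inverse=True)
summed = np.zeros((unique_indices.shape[0],) + values.shape[1:], dtype=values.dtype)
np.add.at(summed, inverse, values)   # rows sharing an index are summed
# unique_indices == [0, 2, 3]
# summed         == [[4., 4.], [4., 4.], [2., 2.]]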
class _ConvertToCell(LearningRateSchedule):
"""Inner api, convert learning rate of scalar to LearningRateSchedule."""
def __init__(self, learning_rate):

View File

@ -17,7 +17,6 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
@ -66,8 +65,8 @@ class ProximalAdagrad(Optimizer):
To improve the performance of parameter groups, a customized order of parameters is supported.
The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU.
The sparse feature is under continuous development. To execute the sparse strategy on the host,
set the target to the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
@ -136,14 +135,16 @@ class ProximalAdagrad(Optimizer):
self.l1 = Tensor(l1, mstype.float32)
self.l2 = Tensor(l2, mstype.float32)
self.hyper_map = C.HyperMap()
self.use_locking = use_locking
self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.FusedSparseProximalAdagrad(use_locking=use_locking)
self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
def construct(self, grads):
params = self.parameters
accum = self.accum
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@ -152,3 +153,18 @@ class ProximalAdagrad(Optimizer):
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2, lr),
grads, params, accum)
return success
@Optimizer.target.setter
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
if value == 'CPU':
self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive", "CPU")
else:
self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking)
self._target = value

View File

@ -345,8 +345,8 @@ class TrainStepWrap(nn.Cell):
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
self.optimizer_w.target = "CPU"
self.optimizer_d.target = "CPU"
else:
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)

View File

@ -74,7 +74,7 @@ def do_sparse_embedding(ps=False):
net.embedding.embedding_table.set_param_ps()
optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
optimizer.target = 'CPU'
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer)

View File

@ -465,7 +465,7 @@ def test_embedding_lookup_with_mix_precision():
criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
optimizer.target = 'CPU'
train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
train_network.set_train()
for _ in range(2):

View File

@ -109,6 +109,19 @@ def test_sparse_adam_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
optimizer.target = 'CPU'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_sparse_adam():
""" test_sparse_adam """
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)

View File

@ -72,5 +72,19 @@ def test_spares_ftrl_compile():
net.set_train()
optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
optimizer.target = 'CPU'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_spares_ftrl():
""" test sparse ftrl"""
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
optimizer.target = 'Ascend'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)

View File

@ -76,6 +76,20 @@ def test_spares_lazy_adam_compile():
net.set_train()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
optimizer.target = 'CPU'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_spares_lazy_adam():
""" test sparse adam"""
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
optimizer.target = 'Ascend'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)

View File

@ -71,6 +71,19 @@ def test_spares_proximal_ada_grad_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
optimizer.target = 'CPU'
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
def test_spares_proximal_ada_grad():
""" test sparse proximal_ada_grad """
indices = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)

View File

@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test lazy adam """
import numpy as np
from mindspore.nn.optim import LazyAdam, FTRL, Adam, ProximalAdagrad
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0
self.gather = P.SparseGatherV2()
def construct(self, indices, label):
return self.gather(self.weight1, indices, self.axis) + self.weight2
def test_ftrl_target():
""" test_ftrl_target """
net = NetWithSparseGatherV2()
net.set_train()
optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
if optimizer.target not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target))
def test_lazyadam_target():
""" test_lazyadam_target """
net = NetWithSparseGatherV2()
net.set_train()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
if optimizer.target not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target))
def test_adam_target():
""" test_adam_target """
net = NetWithSparseGatherV2()
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
if optimizer.target not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target))
def test_proximal_target():
""" test_proximal_target """
net = NetWithSparseGatherV2()
net.set_train()
optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
if optimizer.target not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(optimizer.target))