diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index 87aedfbaceb..591a314e5a5 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -23,6 +23,8 @@ from .momentum import Momentum from .adam import Adam, AdamWeightDecay, AdamOffload from .lamb import Lamb from .sgd import SGD +from .asgd import ASGD +from .rprop import Rprop from .lars import LARS from .ftrl import FTRL from .rmsprop import RMSProp @@ -33,4 +35,4 @@ from .thor import thor from .adafactor import AdaFactor __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload', - 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor'] + 'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor'] diff --git a/mindspore/nn/optim/asgd.py b/mindspore/nn/optim/asgd.py new file mode 100755 index 00000000000..55879642ef4 --- /dev/null +++ b/mindspore/nn/optim/asgd.py @@ -0,0 +1,190 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""asgd""" +from mindspore.ops import functional as F, operations as P +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +import mindspore +from mindspore._checkparam import Validator as validator +from .optimizer import Optimizer +from .optimizer import opt_init_args_register + +class ASGD(Optimizer): + r""" + Implements Average Stochastic Gradient Descent. + + Introduction to ASGD can be found at `Acceleration of stochastic approximation by average + `_. + + The updating formulas are as follows: + + \begin{gather*} + w_{t} = w_{t-1} * (1 - \lambda * \eta_{t-1}) - \eta_{t-1} * g_{t} \\ + ax_{t} = (w_t - ax_{t-1}) * \mu_{t-1} \\ + \eta_{t} = \frac{1.}{(1 + \lambda * lr * t)^\alpha} \\ + \mu_{t} = \frac{1}{\max(1, t - t0)} + \end{gather*} + + :math:`\lambda` represents the decay term, :math:`\mu` and :math:`\eta` are tracked to + update :math:`ax` and :math:`w`, :math:`t0` represents the point of starting averaging, + :math:`\alpha` represents the power for eta update, :math:`ax` represents the averaged + parameter value, :math:`t` represents the current step, :math:`g` represents `gradients`, + :math:`w` represents `params`. + + Note: + If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta' + or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters + are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied. + + Args: + params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the + `parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and + "order_params" are the keys can be parsed. + + - params: Required. Parameters in current group. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the optimizer will be used. + + - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value + will be used. If not, the `grad_centralization` is False by default. This configuration only works on the + convolution layer. + + - order_params: Optional. When parameters is grouped, this usually is used to maintain the order of + parameters that appeared in the network to improve performance. The value should be parameters whose + order will be followed in optimizer. + If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in + one group of `params`. + + learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): + + - float: The fixed learning rate value. Must be equal to or greater than 0. + + - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float. + + - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied. + For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate. + + - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate. + + - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of + LearningRateSchedule with step as the input to get the learning rate of current step. + + lambd (float): The decay term. Default: 1e-4. + alpha (float): The power for eta update. Default: 0.75. + t0 (float): The point of starting averaging. Default: 1e6. + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + + Outputs: + Tensor[bool], the value is True. + + Raises: + TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule. + TypeError: If element of `parameters` is neither Parameter nor dict. + TypeError: If `lambd`, `alpha` or `t0` is not a float. + TypeError: If `weight_decay` is neither float nor int. + ValueError: If `weight_decay` is less than 0. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.ASGD(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params,'grad_centralization':True}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = nn.ASGD(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 default weight decay of 0.0 and grad + >>> # centralization of True. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad + >>> # centralization of False. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + + @opt_init_args_register + def __init__(self, params, learning_rate=0.1, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0.): + + super(ASGD, self).__init__(learning_rate, params, weight_decay) + + validator.check_value_type("lambd", lambd, [float], self.cls_name) + validator.check_value_type("alpha", alpha, [float], self.cls_name) + validator.check_value_type("t0", t0, [float], self.cls_name) + + self.lambd = lambd + self.alpha = alpha + self.t0 = Tensor([t0], dtype=mstype.float32) + mu, eta = [], [] + for param in self.parameters: + mu.append(Parameter(Tensor(1., dtype=mstype.float32), name='mu_'+param.name)) + eta.append(Parameter(Tensor(0., dtype=mstype.float32), name='eta_'+param.name)) + self.lens = len(self.parameters) + self.mu = mindspore.ParameterTuple(mu) + self.eta = mindspore.ParameterTuple(eta) + self.step = Parameter(Tensor(1., dtype=mstype.float32), name='step') + self.ax = self.parameters.clone(prefix="ax_", init='zeros') + self.pow = P.Pow() + self.maximum = P.Maximum() + self.assign = P.Assign() + self.assignadd = P.AssignAdd() + self.assignsub = P.AssignSub() + self.cast = P.Cast() + self.squeeze = P.Squeeze() + + def construct(self, gradients): + gradients = self.decay_weight(gradients) + gradients = self.gradients_centralization(gradients) + gradients = self.scale_grad(gradients) + lrs = self.get_lr() + success = True + + for index, (grad, param, mu, eta, ax) in enumerate(zip(gradients, self.parameters, self.mu, self.eta, self.ax)): + lr = lrs[index] if self.is_group_lr else lrs + + if self.step == 1.: + self.assign(eta, lr) + + param_fp32 = self.cast(param, mstype.float32) + gradient_fp32 = self.cast(grad, mstype.float32) + ax_fp32 = self.cast(ax, mstype.float32) + param_fp32 = param_fp32 * (1. - self.lambd * eta) - eta * gradient_fp32 + + self.assign(param, self.cast(param_fp32, param.dtype)) + + if mu != 1: + self.assignadd(ax, self.cast((param_fp32 - ax_fp32) * mu, ax.dtype)) + else: + self.assign(ax, param) + + self.assign(eta, lr / (self.pow((1. + (self.lambd * lr * self.step)), self.alpha))) + self.assign(mu, 1. / self.squeeze(self.maximum(1., self.step - self.t0))) + + success = F.depend(success, self.assignadd(self.step, 1.)) + return success diff --git a/mindspore/nn/optim/rprop.py b/mindspore/nn/optim/rprop.py new file mode 100755 index 00000000000..e159e2dc0f5 --- /dev/null +++ b/mindspore/nn/optim/rprop.py @@ -0,0 +1,215 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""rprop""" +from mindspore import ops +from mindspore.ops import functional as F, operations as P +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from .optimizer import Optimizer +from .optimizer import opt_init_args_register + +class Rprop(Optimizer): + r""" + Implements Resilient backpropagation. + + Further information about this implementation can be found at `A Direct Adaptive Method for Faster Backpropagation + Learning: The RPROP Algorithm `_. + + The updating formulas are as follows: + + .. math:: + \begin{gather*} + &\hspace{0mm} \textbf{if} \: g_{t-1} g_t > 0 \\ + &\hspace{5mm} \Delta_t \leftarrow \mathrm{min}(\Delta_{t-1} \eta_{+}, \Delta_{max}) \\ + &\hspace{0mm} \textbf{else if} \: g_{t-1} g_t < 0 \\ + &\hspace{5mm} \Delta_t \leftarrow \mathrm{max}(\Delta_{t-1} \eta_{-}, \Delta_{min}) \\ + &\hspace{mm} \textbf{else} \: \\ + &\hspace{5mm} \Delta_t \leftarrow \Delta_{t-1} \\ + &\hspace{0mm} w_{t} \leftarrow w_{t-1}- \Delta_{t} \mathrm{sign}(g_t) \\ + \end{gather*} + + :math:`\Delta_{min/max}` represents the min/max step size, :math:`\eta_{+/-}` represents the factors of + etaminus and etaplus, :math:`g` represents `gradients`, :math:`w` represents `parameters`. + + Note: + If parameters are not grouped, the `weight_decay` in optimizer will be applied on the parameters without 'beta' + or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When parameters + are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be applied. + + Args: + params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the + `parameters` is a list of `dict`, the "params", "lr", "weight_decay", "grad_centralization" and + "order_params" are the keys can be parsed. + + - params: Required. Parameters in current group. The value must be a list of `Parameter`. + + - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used. + If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported. + + - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay + will be used. If not, the `weight_decay` in the optimizer will be used. + + - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value + will be used. If not, the `grad_centralization` is False by default. This configuration only works on the + convolution layer. + + - order_params: Optional. When parameters is grouped, this usually is used to maintain the order of + parameters that appeared in the network to improve performance. The value should be parameters whose + order will be followed in optimizer. + If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in + one group of `params`. + + learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): + + - float: The fixed learning rate value. Must be equal to or greater than 0. + + - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float. + + - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied. + For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate. + + - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate. + + - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of + LearningRateSchedule with step as the input to get the learning rate of current step. + + etas (tuple[float, float]): The factor of multiplicative increasing or + descreasing(etaminus, etaplus). + step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size). + weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. + + Outputs: + Tensor[bool], the value is True. + + Raises: + TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule. + TypeError: If element of `parameters` is neither Parameter nor dict. + TypeError: If `step_sizes` or `etas` is not a tuple. + ValueError: If maximal step size is less than minimal step size. + ValueError: If the length of `step_sizes` or `ets` is not equal to 2. + TypeError: If the element in `etas` or `step_sizes` is not a float. + ValueError: If `etaminus` is not in the range of (0, 1) or `etaplus` is not greater than 1. + TypeError: If `weight_decay` is neither float nor int. + ValueError: If `weight_decay` is less than 0. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> net = Net() + >>> #1) All parameters use the same learning rate and weight decay + >>> optim = nn.Rprop(params=net.trainable_params()) + >>> + >>> #2) Use parameter groups and set different values + >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params,'grad_centralization':True}, + ... {'params': no_conv_params, 'lr': 0.01}, + ... {'order_params': net.trainable_params()}] + >>> optim = nn.Rprop(group_params, learning_rate=0.1, weight_decay=0.0) + >>> # The conv_params's parameters will use default learning rate of 0.1 default weight decay of 0.0 and grad + >>> # centralization of True. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad + >>> # centralization of False. + >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. + >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + + @opt_init_args_register + def __init__(self, params, learning_rate=0.1, etas=(0.5, 1.2), step_sizes=(1e-6, 50.), weight_decay=0.1): + + super(Rprop, self).__init__(learning_rate, params, weight_decay) + if not isinstance(etas, tuple): + raise TypeError("For Rprop, etas should be a tuple, but got {}.".format(type(etas))) + if len(etas) != 2: + raise ValueError("For Rprop, etas should be a tuple with the size of 2, but got {}.".format(len(etas))) + + if not isinstance(step_sizes, tuple): + raise TypeError("For Rprop, step_sizes should be a tuple, but got {}.".format(type(etas))) + if len(step_sizes) != 2: + raise ValueError("For Rprop, step_sizes should be a tuple with the size of 2, " + "but got {}.".format(len(step_sizes))) + + if step_sizes[0] > step_sizes[1]: + raise ValueError("For Rprop, maximal step size should not be less than minimal step size, " + "but got {} > {}.".format(step_sizes[0], step_sizes[1])) + + validator.check_float_range(etas[0], 0.0, 1.0, Rel.INC_NEITHER, "etaminus", self.cls_name) + validator.check_value_type("etaplus", etas[1], [float], self.cls_name) + if etas[1] <= 1.0: + raise ValueError("For Rprop, etaplus should be greater than 1.0, but got etaplus {}.".format(etas[1])) + + validator.check_value_type("min_step_sizes", step_sizes[0], [float], self.cls_name) + validator.check_value_type("max_step_sizes", step_sizes[1], [float], self.cls_name) + + self.etaminus, self.etaplus = etas + self.step_size_min, self.step_size_max = step_sizes + self.prev = self.parameters.clone(prefix="prev", init='zeros') + self.step_size = self.parameters.clone(prefix="step_size", init='zeros') + self.step = Parameter(Tensor(0., dtype=mstype.float32), name='step') + + self.fill = P.Fill() + self.sign = P.Sign() + self.assign = P.Assign() + self.assignadd = P.AssignAdd() + self.cast = P.Cast() + self.select = P.Select() + self.ones_like = P.OnesLike() + + def construct(self, gradients): + gradients = self.decay_weight(gradients) + gradients = self.gradients_centralization(gradients) + gradients = self.scale_grad(gradients) + lrs = self.get_lr() + success = True + + for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self.parameters, + self.prev, self.step_size)): + lr = lrs[index] if self.is_group_lr else lrs + + if self.step == 0.: + step_size_fp32 = self.ones_like(step_size) * lr + else: + step_size_fp32 = self.cast(step_size, mstype.float32) + + gradient_fp32 = self.cast(grad, mstype.float32) + param_fp32 = self.cast(param, mstype.float32) + + sign = self.sign(gradient_fp32 * prev) + sign = self.select(sign > 0, self.fill(mstype.float32, sign.shape, self.etaplus), sign) + sign = self.select(sign < 0, self.fill(mstype.float32, sign.shape, self.etaminus), sign) + sign = self.select(sign == 0, self.fill(mstype.float32, sign.shape, 1.), sign) + + step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign, self.step_size_min, self.step_size_max) + + gradient_update = self.select(sign == self.etaminus, self.fill(mstype.float32, sign.shape, 0.), + gradient_fp32) + next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32 + + self.assign(param, self.cast(next_param, param.dtype)) + self.assign(prev, self.cast(gradient_update, prev.dtype)) + self.assign(step_size, self.cast(step_size_fp32, step_size.dtype)) + + success = F.depend(success, self.assignadd(self.step, 1.)) + + return success diff --git a/tests/st/optimizer/__init__.py b/tests/st/optimizer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/optimizer/optimizer_utils.py b/tests/st/optimizer/optimizer_utils.py new file mode 100644 index 00000000000..218e8fff51b --- /dev/null +++ b/tests/st/optimizer/optimizer_utils.py @@ -0,0 +1,208 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore +from mindspore import nn, Tensor +from mindspore.ops import operations as P +from mindspore.nn.optim import ASGD +from mindspore.nn.optim import Rprop +np.random.seed(1024) + +fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149, + 0.6942514, 0.39767185, 0.24918061, 0.4548748], + [0.7203382, 0.19086994, 0.76286614, 0.87920564, + 0.3169892, 0.9462494, 0.62827677, 0.27504718], + [0.3544535, 0.2524781, 0.5370583, 0.8313121, + 0.6670143, 0.0488653, 0.62225235, 0.7546456], + [0.17985944, 0.05106374, 0.31064633, 0.4863033, + 0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32") + +fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32") + +fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32") + +fc2_bias = np.array([0.09996348]).astype("float32") + + +def make_fake_data(): + """ + make fake data + """ + data, label = [], [] + for i in range(20): + data.append(mindspore.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32))) + label.append(mindspore.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32))) + return data, label + + +class NetWithLoss(nn.Cell): + """ + build net with loss + """ + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.network = network + self.loss = nn.MSELoss(reduction='sum') + + def construct(self, x, label): + out = self.network(x) + loss = self.loss(out, label) + return loss + + +class FakeNet(nn.Cell): + """ + build fake net + """ + def __init__(self): + super(FakeNet, self).__init__() + self.fc1 = nn.Dense(in_channels=8, out_channels=4, weight_init=Tensor(fc1_weight), bias_init=Tensor(fc1_bias)) + self.fc2 = nn.Dense(in_channels=4, out_channels=1, weight_init=Tensor(fc2_weight), bias_init=Tensor(fc2_bias)) + self.relu = nn.ReLU() + self.reducemean = P.ReduceMean() + + def construct(self, x): + x = self.relu(self.fc1(x)) + x = self.fc2(x) + return x + + def _initialize_weights(self): + """ + parameter initialization + """ + self.init_parameters_data() + for name, m in self.cells_and_names(): + if name == 'fc1': + m.weight.set_data(Tensor(fc1_weight)) + m.bias.set_data(Tensor(fc1_bias)) + elif name == 'fc2': + m.weight.set_data(Tensor(fc2_weight)) + m.bias.set_data(Tensor(fc2_bias)) + + +def build_network(opt_config, is_group=False): + """ + Construct training + """ + losses = [] + net = FakeNet() + + networkwithloss = NetWithLoss(net) + networkwithloss.set_train() + + if is_group: + fc1_params = list(filter(lambda x: 'fc1' in x.name, networkwithloss.trainable_params())) + fc2_params = list(filter(lambda x: 'fc1' not in x.name, networkwithloss.trainable_params())) + if opt_config['name'] == 'ASGD': + params = [{'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}] + else: + params = [{'params': fc1_params, 'lr': 0.001}, {'params': fc2_params, 'lr': 0.1}] + else: + params = networkwithloss.trainable_params() + + if opt_config['name'] == 'ASGD': + net_opt = ASGD(params, learning_rate=opt_config['lr'], lambd=opt_config['lambd'], alpha=opt_config['alpha'], + t0=opt_config['t0'], weight_decay=opt_config['weight_decay']) + + elif opt_config['name'] == 'Rprop': + net_opt = Rprop(params, learning_rate=opt_config['lr'], etas=opt_config['etas'], + step_sizes=opt_config['step_sizes'], weight_decay=0.0) + + trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt) + data, label = make_fake_data() + for i in range(20): + loss = trainonestepcell(data[i], label[i]) + losses.append(loss.asnumpy()) + if opt_config['name'] == 'ASGD': + return np.array(losses), net_opt + return np.array(losses) + + +loss_default_asgd = np.array([3.01246792e-01, 1.20041794e+02, 1.38681079e+03, 2.01250820e+01, + 3.27283554e+01, 4.76963005e+01, 6.47094269e+01, 8.34786530e+01, + 1.03742706e+02, 1.25265739e+02, 1.47835190e+02, 1.71259613e+02, + 1.95367035e+02, 2.20003204e+02, 2.45029831e+02, 2.70323456e+02, + 2.95774048e+02, 3.21283752e+02, 3.46765594e+02, 3.72143097e+02], dtype=np.float32) + +loss_not_default_asgd = np.array([3.01246792e-01, 1.26019104e+02, 1.90600449e+02, 9.70605755e+00, + 2.98419113e+01, 3.68430023e+02, 1.06318066e+04, 1.35017746e+02, + 1.68673813e+02, 2.05914215e+02, 2.46694992e+02, 2.90972443e+02, + 3.38703430e+02, 3.89845123e+02, 4.44355103e+02, 5.02191406e+02, + 5.63312500e+02, 6.27676941e+02, 6.95244202e+02, 7.65973816e+02], dtype=np.float32) + +loss_group_asgd = np.array([3.01246792e-01, 7.26708527e+01, 2.84905312e+05, 4.17499258e+04, + 1.46797949e+04, 5.07966602e+03, 1.70935132e+03, 5.47094910e+02, + 1.59216995e+02, 3.78818207e+01, 5.18196869e+00, 2.62275129e-03, + 2.09768796e+00, 5.23108435e+00, 7.78943682e+00, 9.57108879e+00, + 1.07310610e+01, 1.14618425e+01, 1.19147835e+01, 1.21936722e+01], dtype=np.float32) + + +loss_default_rprop = np.array([3.01246792e-01, 1.19871742e+02, 4.13467163e+02, 8.09146179e+02, + 1.22364807e+03, 1.56787573e+03, 1.75733594e+03, 1.72866272e+03, + 1.46183936e+03, 1.00406335e+03, 4.84076874e+02, 9.49734650e+01, + 2.00592804e+01, 1.87920704e+01, 1.53733969e+01, 1.85836582e+01, + 5.21527790e-02, 2.01522671e-02, 7.19913816e+00, 8.52459526e+00], dtype=np.float32) + +loss_not_default_rprop = np.array([3.0124679e-01, 1.2600269e+02, 4.7351608e+02, 1.0220379e+03, + 1.7181555e+03, 2.4367019e+03, 2.9170872e+03, 2.7243464e+03, + 1.4999669e+03, 7.5820435e+01, 1.0590715e+03, 5.4336096e+02, + 7.0162407e+01, 8.2754419e+02, 9.6329260e+02, 3.4475109e+01, + 5.3843134e+02, 6.0064526e+02, 1.1046149e+02, 3.5530117e+03], dtype=np.float32) + +loss_group_rprop = np.array([3.0124679e-01, 7.1360558e+01, 4.8910957e+01, 2.1730331e+02, + 3.0747052e+02, 5.2734237e+00, 5.6865869e+00, 1.7116127e+02, + 2.0539343e+02, 2.2993685e+01, 2.6194101e+02, 2.8772815e+02, + 2.4236647e+01, 3.9299741e+02, 3.5600668e+02, 1.4759110e+01, + 7.2244568e+02, 8.1952783e+02, 9.8913864e+01, 1.1141744e+03], dtype=np.float32) + + +default_fc1_weight_asgd = np.array([[-0.9451941, -0.71258026, -1.2602371, -1.4823773, + -0.974408, -1.2709816, -1.4194703, -1.2137808], + [-1.5341775, -2.0636342, -1.4916497, -1.3753126, + -1.9375193, -1.308271, -1.6262367, -1.9794592], + [-1.9886293, -2.0906024, -1.8060291, -1.5117803, + -1.6760755, -2.2942104, -1.7208353, -1.5884445], + [-2.071215, -2.2000103, -1.9404325, -1.7647781, + -1.4022746, -1.6987679, -2.0481179, -1.5297506]], dtype=np.float32) +default_fc1_bias_asgd = np.array([-0.17978168, -1.0764512, -0.578816, -0.2928958], dtype=np.float32) +default_fc2_weight_asgd = np.array([[4.097412, 6.2694297, 5.9203916, 5.3845487]], dtype=np.float32) +default_fc2_bias_asgd = np.array([6.904814], dtype=np.float32) + + +no_default_fc1_weight_asgd = np.array([[-1.3406217, -1.1080127, -1.655658, -1.8777936, + -1.3698348, -1.6664025, -1.8148884, -1.6092018], + [-1.1475986, -1.6770473, -1.1050745, -0.98873824, + -1.5509329, -0.9216978, -1.2396574, -1.5928726], + [-1.2329121, -1.334883, -1.050313, -0.756071, + -0.92036265, -1.5384867, -0.96512324, -0.8327349], + [-1.0685704, -1.1973612, -0.9377885, -0.7621386, + -0.39964262, -0.69612867, -1.0454736, -0.52711576]], dtype=np.float32) +no_default_fc1_bias_asgd = np.array([0.41264832, -0.19961096, 0.37743938, 0.65807366], dtype=np.float32) +no_default_fc2_weight_asgd = np.array([[-5.660916, -5.9415145, -5.1402636, -4.199707]], dtype=np.float32) +no_default_fc2_bias_asgd = np.array([0.5082278], dtype=np.float32) + + +no_default_group_fc1_weight_asgd = np.array([[-32.526627, -32.29401, -32.8416, -33.06367, -32.55584, + -32.852345, -33.000767, -32.795143], + [-33.164936, -33.69432, -33.12241, -33.006073, -33.568207, + -32.9391, -33.256996, -33.61015], + [-33.118973, -33.220943, -32.936436, -32.642193, -32.806488, + -33.424484, -32.85125, -32.718857], + [-30.155754, -30.284513, -30.025005, -29.849358, -29.486917, + -29.783375, -30.132658, -29.614393]], dtype=np.float32) +no_default_group_fc1_bias_asgd = np.array([-15.838092, -16.811989, -16.078112, -14.289094], dtype=np.float32) +no_default_group_fc2_weight_asgd = np.array([[1288.7146, 1399.3041, 1292.8445, 1121.4629]], dtype=np.float32) +no_default_group_fc2_bias_asgd = np.array([18.513494], dtype=np.float32) diff --git a/tests/st/optimizer/test_asgd_ascend.py b/tests/st/optimizer/test_asgd_ascend.py new file mode 100644 index 00000000000..dda14721830 --- /dev/null +++ b/tests/st/optimizer/test_asgd_ascend.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd + + +def test_default_asgd_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_default_asgd, loss, atol=1.e-5) + + +def test_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_default_asgd, loss, atol=1.e-5) + + +def test_no_default_asgd_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_no_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_default_asgd_group_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_default_asgd_group_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3) diff --git a/tests/st/optimizer/test_asgd_cpu.py b/tests/st/optimizer/test_asgd_cpu.py new file mode 100644 index 00000000000..604f2e5bf2c --- /dev/null +++ b/tests/st/optimizer/test_asgd_cpu.py @@ -0,0 +1,72 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd + + +def test_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_default_asgd, loss, atol=1.e-5) + + +def test_no_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3) + + + +def test_default_asgd_group_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3) diff --git a/tests/st/optimizer/test_asgd_gpu.py b/tests/st/optimizer/test_asgd_gpu.py new file mode 100644 index 00000000000..78b3531ef83 --- /dev/null +++ b/tests/st/optimizer/test_asgd_gpu.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_not_default_asgd, loss_default_asgd, loss_group_asgd + + +def test_default_asgd_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_default_asgd, loss, atol=1.e-5) + + +def test_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import default_fc1_weight_asgd, \ + default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_default_asgd, loss, atol=1.e-5) + + +def test_no_default_asgd_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_no_default_asgd_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_fc1_weight_asgd, \ + no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config) + assert np.allclose(cells.ax[0].asnumpy(), no_default_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_not_default_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_default_asgd_group_pynative(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Pynative mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3) + + +def test_default_asgd_group_graph(): + """ + Feature: Test ASGD optimizer + Description: Test ASGD in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \ + no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001} + loss, cells = build_network(config, is_group=True) + assert np.allclose(cells.ax[0].asnumpy(), no_default_group_fc1_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[1].asnumpy(), no_default_group_fc1_bias_asgd, atol=1.e-5) + assert np.allclose(cells.ax[2].asnumpy(), no_default_group_fc2_weight_asgd, atol=1.e-5) + assert np.allclose(cells.ax[3].asnumpy(), no_default_group_fc2_bias_asgd, atol=1.e-5) + assert np.allclose(loss_group_asgd, loss, atol=1.e-5, rtol=1e-3) diff --git a/tests/st/optimizer/test_rprop_ascend.py b/tests/st/optimizer/test_rprop_ascend.py new file mode 100644 index 00000000000..e03bca4f387 --- /dev/null +++ b/tests/st/optimizer/test_rprop_ascend.py @@ -0,0 +1,90 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop + + +def test_default_rprop_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_default_rprop, loss, atol=1.e-5) + + +def test_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_default_rprop, loss, atol=1.e-5) + + +def test_no_default_rprop_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5) + + +def test_no_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5) + + +def test_default_rprop_group_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + loss = build_network(config, is_group=True) + assert np.allclose(loss_group_rprop, loss, atol=1.e-5) + + +def test_default_rprop_group_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + loss = build_network(config, is_group=True) + assert np.allclose(loss_group_rprop, loss, atol=1.e-5) diff --git a/tests/st/optimizer/test_rprop_cpu.py b/tests/st/optimizer/test_rprop_cpu.py new file mode 100644 index 00000000000..58224e4b367 --- /dev/null +++ b/tests/st/optimizer/test_rprop_cpu.py @@ -0,0 +1,54 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop + + +def test_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_default_rprop, loss, atol=1.e-5) + + +def test_no_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5) + + +def test_default_rprop_group_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + loss = build_network(config, is_group=True) + assert np.allclose(loss_group_rprop, loss, atol=1.e-5) diff --git a/tests/st/optimizer/test_rprop_gpu.py b/tests/st/optimizer/test_rprop_gpu.py new file mode 100644 index 00000000000..49ba45ea7a2 --- /dev/null +++ b/tests/st/optimizer/test_rprop_gpu.py @@ -0,0 +1,90 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from .optimizer_utils import build_network, loss_default_rprop, loss_group_rprop, loss_not_default_rprop + + +def test_default_rprop_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_default_rprop, loss, atol=1.e-5) + + +def test_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with default parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.01, 'etas': (0.5, 1.2), 'step_sizes': (1e-6, 50.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_default_rprop, loss, atol=1.e-5) + + +def test_no_default_rprop_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5) + + +def test_no_default_rprop_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with another set of parameter + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-3, 20.), 'weight_decay': 0.0} + loss = build_network(config) + assert np.allclose(loss_not_default_rprop, loss, atol=1.e-5) + + +def test_default_rprop_group_pynative(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Pynative mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + loss = build_network(config, is_group=True) + assert np.allclose(loss_group_rprop, loss, atol=1.e-5) + + +def test_default_rprop_group_graph(): + """ + Feature: Test Rprop optimizer + Description: Test Rprop in Graph mode with parameter grouping + Expectation: Loss values and parameters conform to preset values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + config = {'name': 'Rprop', 'lr': 0.001, 'etas': (0.6, 1.9), 'step_sizes': (1e-2, 10.), 'weight_decay': 0.0} + loss = build_network(config, is_group=True) + assert np.allclose(loss_group_rprop, loss, atol=1.e-5)