!13341 Add MomentumWithWeightDecayScale kernel-2

From: @VectorSL
Reviewed-by: @kingxian,@chujinjin
Signed-off-by: @kingxian
mindspore-ci-bot 2021-03-22 09:46:40 +08:00 committed by Gitee
commit f4a5eb5219
4 changed files with 228 additions and 3 deletions


@@ -41,7 +41,8 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey,
FusedWeightScaleApplyMomentum)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr,
@@ -162,6 +163,7 @@ __all__ = [
'NLLLoss',
'SGD',
'ApplyMomentum',
'FusedWeightScaleApplyMomentum',
'ExpandDims',
'Cast',
'IsSubClass',


@@ -21,7 +21,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
from .. import signature as sig
class ScalarCast(PrimitiveWithInfer):
"""
@@ -374,3 +374,70 @@ class MakeRefKey(Primitive):
def __call__(self):
pass
class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
"""
Optimizer that implements the Momentum algorithm with weight decay and loss scale.
Refer to the paper `On the importance of initialization and momentum in deep
learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.
Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
to make the data types consistent.
If they have different data types, the lower-priority data type will be converted to
the relatively highest-priority data type.
Data type conversion of Parameter is not supported; a RuntimeError exception will be thrown.
Inputs:
- **weight_decay** (Tensor) - The weight decay value, must be a scalar tensor with float data type.
Default: 0.0.
- **loss_scale** (Tensor) - The loss scale value, must be a scalar tensor with float data type.
Default: 1.0.
- **variable** (Parameter) - Weights to be updated. The data type must be float.
- **accumulation** (Parameter) - Accumulated gradient value by moment weight.
Has the same data type as `variable`.
- **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
a scalar tensor with float data type.
- **gradient** (Tensor) - Gradient, has the same data type as `variable`.
- **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
a scalar tensor with float data type.
Outputs:
Tensor, parameters to be updated.
Supported Platforms:
``GPU``
Examples:
Please refer to the usage in :class:`mindspore.nn.Momentum`, and add weight_decay and loss_scale as inputs.
"""
__mindspore_signature__ = (
sig.make_sig('weight_decay', dtype=sig.sig_dtype.T3),
sig.make_sig('loss_scale', dtype=sig.sig_dtype.T3),
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
sig.make_sig('gradient', dtype=sig.sig_dtype.T),
sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
)
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['weight_decay', 'loss_scale', 'variable', 'accumulation', 'learning_rate',
'gradient', 'momentum'], outputs=['output'])
def infer_shape(self, d_shape, s_shape, v_shape, a_shape, l_shape, g_shape, m_shape):
return v_shape
def infer_dtype(self, d_dtype, s_dtype, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
valid_dtypes = [mstype.float16, mstype.float32]
if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
return v_dtype
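
For reference, a minimal usage sketch of the new primitive (hypothetical shapes and values; it assumes a GPU target and PyNative mode so the primitive can be invoked directly), following the input order declared in init_prim_io_names above:

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# Hypothetical 2x2 weight; the primitive updates `variable` and `accumulation` in place.
variable = Parameter(Tensor(np.ones((2, 2)), mstype.float32), name="variable")
accumulation = Parameter(Tensor(np.zeros((2, 2)), mstype.float32), name="accumulation")
gradient = Tensor(np.full((2, 2), 0.1), mstype.float32)

fused_momentum = P.FusedWeightScaleApplyMomentum()
# Input order: weight_decay, loss_scale, variable, accumulation, learning_rate, gradient, momentum.
# weight_decay and loss_scale are passed as scalar float tensors (documented defaults 0.0 and 1.0).
output = fused_momentum(Tensor(1e-4, mstype.float32), Tensor(1.0, mstype.float32),
                        variable, accumulation, Tensor(0.1, mstype.float32),
                        gradient, Tensor(0.9, mstype.float32))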


@@ -33,6 +33,7 @@ import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
from src.resnet_gpu_benchmark import resnet50 as resnet
from src.CrossEntropySmooth import CrossEntropySmooth
from src.momentum import Momentum as MomentumWeightDecay
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.')
@@ -228,7 +229,10 @@ def train():
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Mixed precision
if compute_type == "fp16":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
if mode == context.PYNATIVE_MODE:
opt = MomentumWeightDecay(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
else:
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
# define callbacks


@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, weight_decay, scale, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = F.depend(True, opt(weight_decay, scale, weight, moment, learning_rate, gradient, momentum))
return success
class Momentum(Optimizer):
r"""
Implements the Momentum algorithm.
Refer to the paper `On the importance of initialization and momentum in deep learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
.. math::
v_{t} = v_{t-1} \ast u + gradients
If use_nesterov is True:
.. math::
p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
If use_nesterov is False:
.. math::
p_{t} = p_{t-1} - lr \ast v_{t}
Here, grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve the performance of parameter groups, a customized order of parameters is supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
in the value of 'order_params' must be in one of group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used: the i-th step
takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
a dynamic learning rate is used and the i-th learning rate is calculated during training
according to the formula of the LearningRateSchedule. When the learning_rate is a float or a
zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float
learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
momentum (float): Hyperparameter of type float, meaning the momentum for the moving average.
It must be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
loss_scale (int, float): A floating point value for the loss scale. It must be greater than 0.0. Default: 1.0.
use_nesterov (bool): Enable Nesterov momentum. Default: False.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[bool], all elements are True.
Raises:
ValueError: If the momentum is less than 0.0.
TypeError: If the momentum is not a float or use_nesterov is not a bool.
Supported Platforms:
``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
>>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
>>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
>>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
# Use FusedWeightScaleApplyMomentum to avoid extra kernel launch.
self.opt = P.FusedWeightScaleApplyMomentum()
def construct(self, gradients):
params = self.params
moments = self.moments
weight_decay = Tensor(0.0, mstype.float32)
scale = Tensor(1.0, mstype.float32)
if self.exec_weight_decay:
weight_decay = self.weight_decay_tensor
if self.need_scale:
scale = self.reciprocal_scale
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, weight_decay, scale, self.momentum),
lr, gradients, params, moments)
else:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, weight_decay, scale, self.momentum, lr),
gradients, params, moments)
return success
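
For intuition, a rough NumPy sketch of what one update step works out to, under the assumption that the fused kernel folds the (reciprocal) loss scale and the weight decay into the gradient before the standard use_nesterov=False momentum step. This is an unfused approximation for illustration, not the GPU kernel itself:

import numpy as np

def fused_momentum_step(weight, accum, grad, lr, momentum,
                        weight_decay=0.0, reciprocal_scale=1.0):
    # Undo the loss scale and add the weight decay (L2) term in one pass
    # (assumption about the fused kernel's semantics).
    effective_grad = grad * reciprocal_scale + weight_decay * weight
    # Standard momentum update, matching the use_nesterov=False formula above.
    accum = accum * momentum + effective_grad
    weight = weight - lr * accum
    return weight, accum

w = np.ones(4, dtype=np.float32)
v = np.zeros(4, dtype=np.float32)
g = np.full(4, 1024 * 0.1, dtype=np.float32)  # gradients still carry the loss scale of 1024
w, v = fused_momentum_step(w, v, g, lr=0.1, momentum=0.9,
                           weight_decay=1e-4, reciprocal_scale=1.0 / 1024)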