!9069 register fusion operator for lamb optimizer
From: @shibeiji Reviewed-by: @c_34,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
142f9c2d3e
|
@ -202,6 +202,8 @@ constexpr const char kNameCase[] = "Case";
|
|||
constexpr const char kNameAssert[] = "Assert";
|
||||
constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
|
||||
constexpr const char kNameReverseV2[] = "ReverseV2";
|
||||
constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign";
|
||||
constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign";
|
||||
|
||||
class OpAdapterMap {
|
||||
public:
|
||||
|
|
|
@ -362,4 +362,24 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
|||
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Atan2, kNameAtan2, ADPT_DESC(Atan2))
|
||||
|
||||
// LambApplyOptimizerAssign
|
||||
INPUT_MAP(LambApplyOptimizerAssign) = {
|
||||
{1, INPUT_DESC(grad)}, {2, INPUT_DESC(inputv)}, {3, INPUT_DESC(inputm)},
|
||||
{4, INPUT_DESC(input3)}, {5, INPUT_DESC(mul0_x)}, {6, INPUT_DESC(mul1_x)},
|
||||
{7, INPUT_DESC(mul2_x)}, {8, INPUT_DESC(mul3_x)}, {9, INPUT_DESC(add2_y)},
|
||||
{10, INPUT_DESC(steps)}, {11, INPUT_DESC(do_use_weight)}, {12, INPUT_DESC(weight_decay_rate)}};
|
||||
ATTR_MAP(LambApplyOptimizerAssign) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(LambApplyOptimizerAssign) = {{0, OUTPUT_DESC(output0)}, {1, OUTPUT_DESC(inputv)}, {2, OUTPUT_DESC(inputm)}};
|
||||
REG_ADPT_DESC(LambApplyOptimizerAssign, kNameLambApplyOptimizerAssign, ADPT_DESC(LambApplyOptimizerAssign))
|
||||
|
||||
// LambApplyWeightAssign
|
||||
INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)},
|
||||
{2, INPUT_DESC(input1)},
|
||||
{3, INPUT_DESC(input2)},
|
||||
{4, INPUT_DESC(input3)},
|
||||
{5, INPUT_DESC(input_param)}};
|
||||
ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}};
|
||||
REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -189,5 +189,11 @@ DECLARE_OP_USE_OUTPUT(Round)
|
|||
|
||||
DECLARE_OP_ADAPTER(Atan2)
|
||||
DECLARE_OP_USE_OUTPUT(Atan2)
|
||||
|
||||
DECLARE_OP_ADAPTER(LambApplyOptimizerAssign)
|
||||
DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign)
|
||||
|
||||
DECLARE_OP_ADAPTER(LambApplyWeightAssign)
|
||||
DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_
|
||||
|
|
|
@ -111,6 +111,52 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
|
|||
return op_cast(next_param, F.dtype(param))
|
||||
return gradient
|
||||
|
||||
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
|
||||
|
||||
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
|
||||
optim_filter):
|
||||
"""
|
||||
Update parameters function when device target is ascend.
|
||||
|
||||
Args:
|
||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
||||
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
lr (Tensor): Learning rate.
|
||||
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
|
||||
global_step (Tensor): Global step.
|
||||
param (Tensor): Parameters.
|
||||
m (Tensor): m value of parameters.
|
||||
v (Tensor): v value of parameters.
|
||||
gradient (Tensor): Gradient of parameters.
|
||||
decay_flag (bool): Specifies whether param update with weight decay.
|
||||
optim_filter(bool): Applies parameter update or not.
|
||||
|
||||
Returns:
|
||||
Tensor, the new value of v after updating.
|
||||
"""
|
||||
if optim_filter:
|
||||
op_cast = P.Cast()
|
||||
op_norm = layer.Norm()
|
||||
op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
|
||||
op_lamb_apply_weight_assign = P.LambApplyWeightAssign()
|
||||
|
||||
param_fp32 = op_cast(param, mstype.float32)
|
||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||
new_global_step = op_cast(global_step + num_one, mstype.float32)
|
||||
weight_decay_flag = op_cast(decay_flag, mstype.float32)
|
||||
|
||||
update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
|
||||
beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps,
|
||||
new_global_step, weight_decay_flag, weight_decay)
|
||||
w_norm = op_norm(param_fp32)
|
||||
g_norm = op_norm(update)
|
||||
update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
|
||||
return update
|
||||
return gradient
|
||||
|
||||
|
||||
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
|
||||
|
||||
|
@ -279,6 +325,7 @@ class Lamb(Optimizer):
|
|||
self.hyper_map = C.HyperMap()
|
||||
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("enable_graph_kernel")
|
||||
self.device_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
def construct(self, gradients):
|
||||
lr = self.get_lr()
|
||||
|
@ -299,19 +346,20 @@ class Lamb(Optimizer):
|
|||
self.global_step, lr, self.weight_decay),
|
||||
self.params, self.moments1, self.moments2, gradients, self.decay_flags)
|
||||
else:
|
||||
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
|
||||
if self.is_group:
|
||||
if self.is_group_lr:
|
||||
optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step),
|
||||
lr, self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr),
|
||||
self.weight_decay, self.params, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
else:
|
||||
optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
|
||||
self.global_step, lr, self.weight_decay),
|
||||
self.params, self.moments1, self.moments2, gradients,
|
||||
self.decay_flags, self.optim_filter)
|
||||
|
|
|
@ -351,3 +351,5 @@ from .conv3d import _conv3d_tbe
|
|||
from .conv3d_backprop_input import _conv3d_backprop_input_tbe
|
||||
from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe
|
||||
from .conv3d_transpose import _conv3d_transpose_tbe
|
||||
from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
|
||||
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LambApplyOptimizerAssign op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_apply_optimizer_assign_op_info = TBERegOp("LambApplyOptimizerAssign") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_apply_optimizer_assign.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_apply_optimizer_assign") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "grad", False, "required", "all") \
|
||||
.input(1, "inputv", False, "required", "all") \
|
||||
.input(2, "inputm", False, "required", "all") \
|
||||
.input(3, "input3", False, "required", "all") \
|
||||
.input(4, "mul0_x", False, "required", "all") \
|
||||
.input(5, "mul1_x", False, "required", "all") \
|
||||
.input(6, "mul2_x", False, "required", "all") \
|
||||
.input(7, "mul3_x", False, "required", "all") \
|
||||
.input(8, "add2_y", False, "required", "all") \
|
||||
.input(9, "steps", False, "required", "all") \
|
||||
.input(10, "do_use_weight", False, "required", "all") \
|
||||
.input(11, "weight_decay_rate", False, "required", "all") \
|
||||
.output(0, "output0", False, "required", "all") \
|
||||
.output(0, "inputv", False, "required", "all") \
|
||||
.output(0, "inputm", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lamb_apply_optimizer_assign_op_info)
|
||||
def _lamb_apply_optimizer_assign_tbe():
|
||||
"""LambApplyOptimizerAssign TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LambApplyWeightAssign op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lamb_apply_weight_assign_op_info = TBERegOp("LambApplyWeightAssign") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lamb_apply_weight_assign.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lamb_apply_weight_assign") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input0", False, "required", "all") \
|
||||
.input(1, "input1", False, "required", "all") \
|
||||
.input(2, "input2", False, "required", "all") \
|
||||
.input(3, "input3", False, "required", "all") \
|
||||
.input(4, "input_param", False, "required", "all") \
|
||||
.output(0, "input_param", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lamb_apply_weight_assign_op_info)
|
||||
def _lamb_apply_weight_assign_tbe():
|
||||
"""LambApplyWeightAssign TBE register"""
|
||||
return
|
|
@ -112,3 +112,15 @@ class LambUpdateWithLR:
|
|||
class LambNextMV:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambApplyOptimizerAssign:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambApplyWeightAssign:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
|
|
@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
|
|||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Print, Assert)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram
|
||||
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign
|
||||
|
||||
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
|
||||
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
|
||||
|
|
|
@ -172,3 +172,132 @@ class NoRepeatNGram(PrimitiveWithInfer):
|
|||
valid_values = (mstype.float16, mstype.float32, mstype.float64)
|
||||
validator.check_type_name("log_type", log_type, valid_values, self.name)
|
||||
return log_type
|
||||
|
||||
|
||||
class LambApplyOptimizerAssign(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates gradients by LAMB optimizer algorithm. Get the compute ratio.
|
||||
|
||||
The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
|
||||
<https://arxiv.org/abs/1904.00962>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
m = \frac{m}{1 - \beta_1^t} \\
|
||||
v = \frac{v}{1 - \beta_2^t} \\
|
||||
r = \frac{m}{\sqrt{v} + \epsilon} \\
|
||||
w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w))
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
|
||||
`gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
|
||||
:math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
|
||||
`beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
|
||||
`epsilon`.
|
||||
|
||||
Inputs:
|
||||
- **gradient** (Tensor) - Gradient of parameters, float32/float16.
|
||||
- **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`.
|
||||
- **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`.
|
||||
- **var** (Tensor) - Weights to be updated, has the same type as `gradient`.
|
||||
- **beta1** (Tensor) - :math:`beta_1` in the updating formula, float32/float16.
|
||||
- **sub1** (Tensor) - :math:`1-beta_1` in the updating formula, has the same type as `beta1`.
|
||||
- **beta2** (Tensor) - :math:`beta_2` in the updating formula, has the same type as `beta1`.
|
||||
- **sub2** (Tensor) - :math:`1-beta_2` in the updating formula, has the same type as `beta1`.
|
||||
- **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`.
|
||||
- **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`.
|
||||
- **lr** (Tensor) - :math:`l` in the updating formula, learning rate, has the same type as `beta1`.
|
||||
- **decay_flag** (Tensor) -Specify whether param upadte with weight decay, has the same type as `beta1`.
|
||||
- **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the compute ratio r.
|
||||
- **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`.
|
||||
- **v** (Tensor) - the 2nd moment vector in the updating formula after updated inplace,
|
||||
has the same type as `gradient`.
|
||||
- **m** (Tensor) - The 1st moment vector in the updating formula after updated inplace,
|
||||
has the same type as `gradient`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize LambApplyOptimizerAssign"""
|
||||
|
||||
def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
|
||||
beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
|
||||
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
|
||||
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
|
||||
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
||||
return m_shape, v_shape, m_shape
|
||||
|
||||
def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
|
||||
beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype):
|
||||
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||
|
||||
args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype,
|
||||
"eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype,
|
||||
"weight_decay": weight_decay_dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
|
||||
return m_dtype, v_dtype, v_dtype
|
||||
|
||||
|
||||
class LambApplyWeightAssign(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates gradients by LAMB optimizer algorithm. The weight update part.
|
||||
|
||||
The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
|
||||
<https://arxiv.org/abs/1904.00962>`_.
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
m = \beta_1 * m + (1 - \beta_1) * g \\
|
||||
v = \beta_2 * v + (1 - \beta_2) * g * g \\
|
||||
m = \frac{m}{1 - \beta_1^t} \\
|
||||
v = \frac{v}{1 - \beta_2^t} \\
|
||||
r = \frac{m}{\sqrt{v} + \epsilon} \\
|
||||
w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w))
|
||||
\end{array}
|
||||
|
||||
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
|
||||
`gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
|
||||
:math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
|
||||
`beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
|
||||
`epsilon`.
|
||||
|
||||
Inputs:
|
||||
- **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16.
|
||||
- **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`.
|
||||
- **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16.
|
||||
- **update** (Tensor) -:math:`r + \lambda * w`in the updating formula, float32/float16.
|
||||
- **var** (Tensor) - Weights to be updated, the same shape and type as `update`.
|
||||
|
||||
Outputs:
|
||||
- **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize LambApplyWeightAssign"""
|
||||
|
||||
def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
|
||||
validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
|
||||
return var_shape
|
||||
|
||||
def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype):
|
||||
args = {"var": var_dtype, "update": update_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||
|
||||
args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
|
||||
return var_dtype
|
||||
|
|
|
@ -229,7 +229,7 @@ def test_bert_performance():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [10.235566, 10.207392, 10.206976]
|
||||
expect_loss_value = [11.325791, 11.285011, 11.284766]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
|
@ -239,7 +239,7 @@ def test_bert_performance():
|
|||
assert (overflow == expect_overflow).all()
|
||||
|
||||
loss_scale = np.array(callback.lossscale_list)
|
||||
expect_loss_scale = [262144.0, 262144.0, 262144.0]
|
||||
expect_loss_scale = [65536.0, 65536.0, 65536.0]
|
||||
print("loss scale: {}".format(loss_scale))
|
||||
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
|
||||
|
||||
|
|
|
@ -225,8 +225,12 @@ def test_bert_percision(enable_graph_kernel=False):
|
|||
loss_value = np.array(callback.loss_list)
|
||||
assert np.allclose(loss_value[0], 12.2065868, 0, 0.000001)
|
||||
|
||||
expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466,
|
||||
12.6212320, 12.2229223, 12.4272099]
|
||||
if enable_graph_kernel:
|
||||
expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466,
|
||||
12.6212320, 12.2229223, 12.4272099]
|
||||
else:
|
||||
expect_loss_value = [12.2065868, 11.94102, 11.931558, 11.938105, 11.932648, 12.556579, 12.130686, 12.783716,
|
||||
12.360179, 12.578461]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
|
|
Loading…
Reference in New Issue