!9069 register fusion operator for lamb optimizer

From: @shibeiji
Reviewed-by: @c_34, @liangchenghui
Signed-off-by: @liangchenghui
mindspore-ci-bot 2020-12-31 09:46:25 +08:00 committed by Gitee
commit 142f9c2d3e
12 changed files with 328 additions and 8 deletions


@@ -202,6 +202,8 @@ constexpr const char kNameCase[] = "Case";
constexpr const char kNameAssert[] = "Assert";
constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
constexpr const char kNameReverseV2[] = "ReverseV2";
constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign";
constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign";
class OpAdapterMap {
public:


@@ -362,4 +362,24 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Atan2, kNameAtan2, ADPT_DESC(Atan2))
// LambApplyOptimizerAssign
INPUT_MAP(LambApplyOptimizerAssign) = {
{1, INPUT_DESC(grad)}, {2, INPUT_DESC(inputv)}, {3, INPUT_DESC(inputm)},
{4, INPUT_DESC(input3)}, {5, INPUT_DESC(mul0_x)}, {6, INPUT_DESC(mul1_x)},
{7, INPUT_DESC(mul2_x)}, {8, INPUT_DESC(mul3_x)}, {9, INPUT_DESC(add2_y)},
{10, INPUT_DESC(steps)}, {11, INPUT_DESC(do_use_weight)}, {12, INPUT_DESC(weight_decay_rate)}};
ATTR_MAP(LambApplyOptimizerAssign) = EMPTY_ATTR_MAP;
OUTPUT_MAP(LambApplyOptimizerAssign) = {{0, OUTPUT_DESC(output0)}, {1, OUTPUT_DESC(inputv)}, {2, OUTPUT_DESC(inputm)}};
REG_ADPT_DESC(LambApplyOptimizerAssign, kNameLambApplyOptimizerAssign, ADPT_DESC(LambApplyOptimizerAssign))
// LambApplyWeightAssign
INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)},
{2, INPUT_DESC(input1)},
{3, INPUT_DESC(input2)},
{4, INPUT_DESC(input3)},
{5, INPUT_DESC(input_param)}};
ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP;
OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}};
REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign))
} // namespace mindspore::transform


@@ -189,5 +189,11 @@ DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(Atan2)
DECLARE_OP_USE_OUTPUT(Atan2)
DECLARE_OP_ADAPTER(LambApplyOptimizerAssign)
DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign)
DECLARE_OP_ADAPTER(LambApplyWeightAssign)
DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_


@@ -111,6 +111,52 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
return op_cast(next_param, F.dtype(param))
return gradient
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
optim_filter):
"""
Update parameters when the device target is Ascend.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
global_step (Tensor): Global step.
lr (Tensor): Learning rate.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether the parameter is updated with weight decay.
optim_filter (bool): Whether to apply the parameter update.
Returns:
Tensor, the update applied to the parameter, or the original gradient if `optim_filter` is False.
"""
if optim_filter:
op_cast = P.Cast()
op_norm = layer.Norm()
op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
op_lamb_apply_weight_assign = P.LambApplyWeightAssign()
param_fp32 = op_cast(param, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
new_global_step = op_cast(global_step + num_one, mstype.float32)
weight_decay_flag = op_cast(decay_flag, mstype.float32)
update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps,
new_global_step, weight_decay_flag, weight_decay)
w_norm = op_norm(param_fp32)
g_norm = op_norm(update)
update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
return update
return gradient
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
@@ -279,6 +325,7 @@ class Lamb(Optimizer):
self.hyper_map = C.HyperMap()
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
context.get_context("enable_graph_kernel")
self.device_ascend = context.get_context("device_target") == "Ascend"
def construct(self, gradients):
lr = self.get_lr()
@@ -299,19 +346,20 @@
self.global_step, lr, self.weight_decay),
self.params, self.moments1, self.moments2, gradients, self.decay_flags)
else:
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step),
lr, self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr),
self.weight_decay, self.params, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
self.global_step, lr, self.weight_decay),
self.params, self.moments1, self.moments2, gradients,
self.decay_flags, self.optim_filter)
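
For context, a minimal training sketch that exercises this optimizer path; on Ascend the fused _lamb_opt_ascend branch above is selected automatically, so nothing changes at the call site. The Dense layer, batch shapes, and hyper-parameters are illustrative placeholders, not taken from this change.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dense(16, 4)  # stand-in network, illustrative only
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = nn.Lamb(net.trainable_params(), learning_rate=1e-3, weight_decay=0.01)

train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
data = Tensor(np.random.randn(8, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 4, (8,)).astype(np.int32))
loss = train_net(data, label)  # one optimizer step through the LAMB update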


@@ -351,3 +351,5 @@ from .conv3d import _conv3d_tbe
from .conv3d_backprop_input import _conv3d_backprop_input_tbe
from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe
from .conv3d_transpose import _conv3d_transpose_tbe
from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe


@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LambApplyOptimizerAssign op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lamb_apply_optimizer_assign_op_info = TBERegOp("LambApplyOptimizerAssign") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("lamb_apply_optimizer_assign.so") \
.compute_cost(10) \
.kernel_name("lamb_apply_optimizer_assign") \
.partial_flag(True) \
.input(0, "grad", False, "required", "all") \
.input(1, "inputv", False, "required", "all") \
.input(2, "inputm", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.input(4, "mul0_x", False, "required", "all") \
.input(5, "mul1_x", False, "required", "all") \
.input(6, "mul2_x", False, "required", "all") \
.input(7, "mul3_x", False, "required", "all") \
.input(8, "add2_y", False, "required", "all") \
.input(9, "steps", False, "required", "all") \
.input(10, "do_use_weight", False, "required", "all") \
.input(11, "weight_decay_rate", False, "required", "all") \
.output(0, "output0", False, "required", "all") \
.output(0, "inputv", False, "required", "all") \
.output(0, "inputm", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(lamb_apply_optimizer_assign_op_info)
def _lamb_apply_optimizer_assign_tbe():
"""LambApplyOptimizerAssign TBE register"""
return


@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LambApplyWeightAssign op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lamb_apply_weight_assign_op_info = TBERegOp("LambApplyWeightAssign") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("lamb_apply_weight_assign.so") \
.compute_cost(10) \
.kernel_name("lamb_apply_weight_assign") \
.partial_flag(True) \
.input(0, "input0", False, "required", "all") \
.input(1, "input1", False, "required", "all") \
.input(2, "input2", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.input(4, "input_param", False, "required", "all") \
.output(0, "input_param", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(lamb_apply_weight_assign_op_info)
def _lamb_apply_weight_assign_tbe():
"""LambApplyWeightAssign TBE register"""
return


@@ -112,3 +112,15 @@ class LambUpdateWithLR:
class LambNextMV:
def __call__(self, *args):
pass
@op_selector
class LambApplyOptimizerAssign:
def __call__(self, *args):
pass
@op_selector
class LambApplyWeightAssign:
def __call__(self, *args):
pass


@@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,


@@ -172,3 +172,132 @@ class NoRepeatNGram(PrimitiveWithInfer):
valid_values = (mstype.float16, mstype.float32, mstype.float64)
validator.check_type_name("log_type", log_type, valid_values, self.name)
return log_type
class LambApplyOptimizerAssign(PrimitiveWithInfer):
r"""
Updates gradients by the LAMB optimizer algorithm and computes the update ratio.
The LAMB optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
<https://arxiv.org/abs/1904.00962>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
m = \frac{m}{1 - \beta_1^t} \\
v = \frac{v}{1 - \beta_2^t} \\
r = \frac{m}{\sqrt{v} + \epsilon} \\
w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
\end{array}
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
:math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
`beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
`epsilon`.
Inputs:
- **gradient** (Tensor) - Gradient of parameters, float32/float16.
- **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`.
- **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`.
- **var** (Tensor) - Weights to be updated, has the same type as `gradient`.
- **beta1** (Tensor) - :math:`\beta_1` in the updating formula, float32/float16.
- **sub1** (Tensor) - :math:`1 - \beta_1` in the updating formula, has the same type as `beta1`.
- **beta2** (Tensor) - :math:`\beta_2` in the updating formula, has the same type as `beta1`.
- **sub2** (Tensor) - :math:`1 - \beta_2` in the updating formula, has the same type as `beta1`.
- **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`.
- **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`.
- **lr** (Tensor) - :math:`l` in the updating formula, learning rate, has the same type as `beta1`.
- **decay_flag** (Tensor) - Specifies whether the parameter is updated with weight decay, has the same type as `beta1`.
- **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`.
Outputs:
Tuple of 3 Tensors, the update ratio and the updated moment vectors.
- **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`.
- **v** (Tensor) - The 2nd moment vector in the updating formula after the in-place update,
has the same type as `gradient`.
- **m** (Tensor) - The 1st moment vector in the updating formula after the in-place update,
has the same type as `gradient`.
Supported Platforms:
``Ascend``
"""
@prim_attr_register
def __init__(self):
"""Initialize LambApplyOptimizerAssign"""
def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
return m_shape, v_shape, m_shape
def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype,
"eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype,
"weight_decay": weight_decay_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
return m_dtype, v_dtype, v_dtype
class LambApplyWeightAssign(PrimitiveWithInfer):
r"""
Updates the parameters by the LAMB optimizer algorithm. This is the weight-update part.
The LAMB optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
<https://arxiv.org/abs/1904.00962>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
m = \frac{m}{1 - \beta_1^t} \\
v = \frac{v}{1 - \beta_2^t} \\
r = \frac{m}{\sqrt{v} + \epsilon} \\
w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
\end{array}
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
:math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
`beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
`epsilon`.
Inputs:
- **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16.
- **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`.
- **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16.
- **update** (Tensor) - :math:`r + \lambda * w` in the updating formula, float32/float16.
- **var** (Tensor) - Weights to be updated, the same shape and type as `update`.
Outputs:
- **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs.
Supported Platforms:
``Ascend``
"""
@prim_attr_register
def __init__(self):
"""Initialize LambApplyWeightAssign"""
def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
return var_shape
def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype):
args = {"var": var_dtype, "update": update_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
return var_dtype
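
For reference, a plain NumPy sketch of what the two fused kernels compute, following the formulas in the docstrings above. The `do_use_weight` gating and the zero-norm guard on the trust ratio are assumptions based on how `_update_run_op_ascend` calls these ops, not behaviour spelled out here; names and values are illustrative only.

import numpy as np

def lamb_apply_optimizer_assign_ref(grad, v, m, var, beta1, one_minus_beta1,
                                    beta2, one_minus_beta2, eps, steps,
                                    do_use_weight, weight_decay_rate):
    """Reference-only sketch of (output0, v, m) from LambApplyOptimizerAssign."""
    m_new = beta1 * m + one_minus_beta1 * grad          # 1st moment
    v_new = beta2 * v + one_minus_beta2 * grad * grad   # 2nd moment
    m_hat = m_new / (1.0 - np.power(beta1, steps))      # bias correction
    v_hat = v_new / (1.0 - np.power(beta2, steps))
    rate = m_hat / (np.sqrt(v_hat) + eps)
    # do_use_weight is assumed to gate the decoupled weight-decay term, matching
    # how _update_run_op_ascend passes the per-parameter decay flag.
    update = rate + do_use_weight * weight_decay_rate * var
    return update, v_new, m_new

def lamb_apply_weight_assign_ref(w_norm, g_norm, lr, update, var):
    """Reference-only sketch of the in-place update of LambApplyWeightAssign."""
    # Falling back to a trust ratio of 1.0 when either norm is zero is an
    # assumption borrowed from the usual LAMB recipe.
    ratio = w_norm / g_norm if (w_norm > 0 and g_norm > 0) else 1.0
    return var - lr * ratio * update

# One full parameter step with illustrative values.
var = np.ones((4,), np.float32)
m = np.zeros((4,), np.float32)
v = np.zeros((4,), np.float32)
grad = np.full((4,), 0.1, np.float32)
update, v, m = lamb_apply_optimizer_assign_ref(grad, v, m, var, 0.9, 0.1, 0.999,
                                               0.001, 1e-6, 1.0, 1.0, 0.01)
var = lamb_apply_weight_assign_ref(np.linalg.norm(var), np.linalg.norm(update),
                                   0.001, update, var)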


@@ -229,7 +229,7 @@ def test_bert_performance():
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
expect_loss_value = [11.325791, 11.285011, 11.284766]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
@@ -239,7 +239,7 @@ def test_bert_performance():
assert (overflow == expect_overflow).all()
loss_scale = np.array(callback.lossscale_list)
expect_loss_scale = [65536.0, 65536.0, 65536.0]
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)


@@ -225,8 +225,12 @@ def test_bert_percision(enable_graph_kernel=False):
loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.2065868, 0, 0.000001)
if enable_graph_kernel:
expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466,
12.6212320, 12.2229223, 12.4272099]
else:
expect_loss_value = [12.2065868, 11.94102, 11.931558, 11.938105, 11.932648, 12.556579, 12.130686, 12.783716,
12.360179, 12.578461]
print("loss value: {}".format(loss_value)) print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)