From 53043ae18f16f70a72100c7100989d9a208035c9 Mon Sep 17 00:00:00 2001 From: zengzitao Date: Thu, 5 Nov 2020 16:39:05 +0800 Subject: [PATCH] support expand fused_adam and fused_adam_weight_decay op --- akg | 2 +- .../graph_kernel/expanders/__init__.py | 2 + .../graph_kernel/expanders/fused_adam.py | 71 +++++++++++++++++ .../expanders/fused_adam_weight_decay.py | 76 +++++++++++++++++++ .../_extends/graph_kernel/model/model.py | 2 + .../graph_kernel/model/model_builder.py | 1 + .../graph_kernel/arithmetic_simplify.cc | 4 +- .../graph_kernel/graph_kernel_helper.cc | 5 +- mindspore/core/base/core_ops.h | 2 + mindspore/ops/_op_impl/akg/gpu/__init__.py | 1 + .../ops/_op_impl/akg/gpu/inplace_assign.py | 41 ++++++++++ mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/other_ops.py | 33 ++++++++ tests/st/ops/graph_kernel/test_gelu.py | 2 + 14 files changed, 238 insertions(+), 6 deletions(-) create mode 100644 mindspore/_extends/graph_kernel/expanders/fused_adam.py create mode 100644 mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/inplace_assign.py diff --git a/akg b/akg index f308919c398..2956e64803c 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit f308919c39811c2c3e07fb0dcc8054a533c84cbc +Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index e32ce450c8b..f0bdd539696 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -21,3 +21,5 @@ from .softmax import expand_softmax from .square import expand_square from .bias_add import expand_biasadd from .bias_add_grad import expand_biasaddgrad +from .fused_adam import expand_fusedadam +from .fused_adam_weight_decay import expand_fusedadamweightdecay diff --git a/mindspore/_extends/graph_kernel/expanders/fused_adam.py b/mindspore/_extends/graph_kernel/expanders/fused_adam.py new file mode 100644 index 00000000000..beea8b3ca03 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/fused_adam.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for fused_adam""" +from mindspore._extends.graph_kernel.model import model_builder as builder + + +def expand_fusedadam(expand_info): + """FusedAdma expander""" + # get op info. + input_desc_0 = expand_info['input_desc'][0] + input_desc_1 = expand_info['input_desc'][1] + input_desc_2 = expand_info['input_desc'][2] + input_desc_3 = expand_info['input_desc'][3] + input_desc_4 = expand_info['input_desc'][4] + input_desc_5 = expand_info['input_desc'][5] + input_desc_6 = expand_info['input_desc'][6] + input_desc_7 = expand_info['input_desc'][7] + input_desc_8 = expand_info['input_desc'][8] + input_desc_9 = expand_info['input_desc'][9] + graph_builder = builder.GraphBuilder() + + # generate a graph. + with graph_builder.graph_scope('main') as graph_scope: + # create tensor input. + beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) + one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) + beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) + one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format']) + eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format']) + lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format']) + param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format']) + m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format']) + v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format']) + gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format']) + graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient) + + # compute result + beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) + one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) + next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad]) + beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) + grad_square = graph_builder.emit('Mul', [gradient, gradient]) + one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) + next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) + sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) + sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps]) + update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) + update_with_lr = graph_builder.emit('Mul', [lr, update]) + next_para = graph_builder.emit('Sub', [param, update_with_lr]) + + param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) + m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) + v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) + + # set graph output. + graph_scope.set_output(param_result, m_result, v_result) + + graph = graph_builder.get()[0] + return graph diff --git a/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py b/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py new file mode 100644 index 00000000000..d01170be603 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for fused_adam_weight_decay""" +from mindspore._extends.graph_kernel.model import model_builder as builder + + +def expand_fusedadamweightdecay(expand_info): + """FusedAdmaWeightDecay expander""" + # get op info. + input_desc_0 = expand_info['input_desc'][0] + input_desc_1 = expand_info['input_desc'][1] + input_desc_2 = expand_info['input_desc'][2] + input_desc_3 = expand_info['input_desc'][3] + input_desc_4 = expand_info['input_desc'][4] + input_desc_5 = expand_info['input_desc'][5] + input_desc_6 = expand_info['input_desc'][6] + input_desc_7 = expand_info['input_desc'][7] + input_desc_8 = expand_info['input_desc'][8] + input_desc_9 = expand_info['input_desc'][9] + input_desc_10 = expand_info['input_desc'][10] + graph_builder = builder.GraphBuilder() + + # generate a graph. + with graph_builder.graph_scope('main') as graph_scope: + # create tensor input. + beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) + one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) + beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) + one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format']) + eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format']) + lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format']) + param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format']) + m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format']) + v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format']) + gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format']) + weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format']) + graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, + eps, lr, param, m, v, gradient, weight_decay) + + # compute result + beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) + one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) + next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad]) + beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) + grad_square = graph_builder.emit('Mul', [gradient, gradient]) + one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) + next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) + sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) + sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps]) + update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) + param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param]) + update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay]) + update_with_lr = graph_builder.emit('Mul', [lr, update]) + next_para = graph_builder.emit('Sub', [param, update_with_lr]) + + para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) + m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True}) + v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True}) + + # set graph output. + graph_scope.set_output(para_result, m_result, v_result) + + graph = graph_builder.get()[0] + return graph diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index e8dd4fd15f2..5561ef213f7 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -154,6 +154,8 @@ class PrimLib: 'ControlDepend': Prim(CONTROL), 'Assign': Prim(ELEMWISE), 'Tanh': Prim(ELEMWISE), + 'ExpandDims': Prim(ELEMWISE), + 'InplaceAssign': Prim(ELEMWISE), '@ReduceInit': Prim(ELEMWISE), } diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py index 873568c7c08..8e51e2ce3d9 100644 --- a/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/mindspore/_extends/graph_kernel/model/model_builder.py @@ -70,6 +70,7 @@ class OpInfer: infer_shape_func = { # add special infer func here + 'InplaceAssign': lambda inputs, attrs: inputs[2].shape } infer_dtype_func = { # add special infer func here diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 908d33ca6d0..a80222c7408 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -560,7 +560,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { auto shape = GetNodeShape(node); if (shape.size() != 0 && shape.size() != 1) { - return node; + return nullptr; } else { auto tmp_node = node->cast(); auto transpose_node = tmp_node->input(1); @@ -635,7 +635,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); return new_cnode; } - return node; + return nullptr; }; auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr { auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 6526bc998d5..e2bd570762a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -702,7 +702,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector GetExpandOps() { std::unordered_set expand_ops = { - prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, prim::kPrimGeluGrad, + prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, + prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, }; return expand_ops; } @@ -729,7 +730,7 @@ std::vector GetFusibleOpList() { prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, - prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh}; + prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape}; return fusible_basic_ops; } diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 47f9f83f466..2442dc7c713 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -174,6 +174,8 @@ inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared("SparseApplyFtrl"); inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared("SparseApplyProximalAdagrad"); +inline const PrimitivePtr kPrimFusedAdam = std::make_shared("FusedAdam"); +inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared("FusedAdamWeightDecay"); // Comm ops inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); diff --git a/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/ops/_op_impl/akg/gpu/__init__.py index 62285cec437..bb1376bb75f 100644 --- a/mindspore/ops/_op_impl/akg/gpu/__init__.py +++ b/mindspore/ops/_op_impl/akg/gpu/__init__.py @@ -20,6 +20,7 @@ from .hsigmoid import _hsigmoid_akg from .hsigmoid_grad import _hsigmoid_grad_akg from .hswish import _hswish_akg from .hswish_grad import _hswish_grad_akg +from .inplace_assign import _inplace_assign_akg from .lessequal import _lessequal_akg from .logical_and import _logical_and_akg from .logical_not import _logical_not_akg diff --git a/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py b/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py new file mode 100644 index 00000000000..b1a934acc2b --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/inplace_assign.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""InplaceAssign op""" +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT + +op_info = AkgGpuRegOp("InplaceAssign") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .input(2, "z") \ + .output(0, "output") \ + .attr("fake_output", "optional", "bool") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _inplace_assign_akg(): + """InplaceAssign Akg register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 5ae98954e8a..43e3a903efa 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -82,7 +82,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler) from . import _quant_ops from ._quant_ops import * -from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, +from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull) from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index b5b6673db7b..c3d4c1aa5df 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -65,6 +65,36 @@ class Assign(PrimitiveWithCheck): validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name) +class InplaceAssign(PrimitiveWithInfer): + """ + Inplace assign `Parameter` with a value. + This primitive can only use in graph kernel. + Inputs: + - **variable** (Parameter) - The `Parameter`. + - **value** (Tensor) - The value to be assigned. + - **depend** (Tensor) - The dependent tensor to keep this op connected in graph. + Outputs: + Tensor, has the same type as original `variable`. + Examples: + >>> def construct(self, x): + >>> val = x - 1.0 + >>> ret = x + 2.0 + >>> return InplaceAssign()(x, val, ret) + >>> x = Tensor([2.0], mindspore.float32) + >>> net = Net() + >>> net(x) + """ + @ prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output']) + + def infer_shape(self, x, y, z): + return z + + def infer_dtype(self, x, y, z): + return z + + class BoundingBoxEncode(PrimitiveWithInfer): """ Encodes bounding boxes locations. @@ -509,6 +539,7 @@ class PopulationCount(PrimitiveWithInfer): validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name) return mstype.tensor_type(mstype.uint8) + class Push(PrimitiveWithInfer): """ Pushes the inputs of the corresponding optimizer to parameter server. @@ -539,6 +570,7 @@ class Push(PrimitiveWithInfer): def infer_dtype(self, inputs, shapes): return mstype.uint64 + class Pull(PrimitiveWithInfer): """ Pulls weight from parameter server. @@ -563,6 +595,7 @@ class Pull(PrimitiveWithInfer): def infer_dtype(self, key_dtype, weight_dtype): return mstype.float32 + class identity(Primitive): """ Makes a identify primitive, used for pynative mode. diff --git a/tests/st/ops/graph_kernel/test_gelu.py b/tests/st/ops/graph_kernel/test_gelu.py index 52e55b0bd5d..a208650f852 100644 --- a/tests/st/ops/graph_kernel/test_gelu.py +++ b/tests/st/ops/graph_kernel/test_gelu.py @@ -52,6 +52,7 @@ def CalGelu(x): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_gelu(): + np.random.seed(0) input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) net = GeluNet() @@ -67,6 +68,7 @@ def test_gelu(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_gelu_grad(): + np.random.seed(0) input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) input_y = CalGelu(input_x)