# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for erfc"""
from ._utils import Expander


class Erfc(Expander):
    """Erfc expander.

    Rewrites ``Erfc(x)`` as ``Sub(1, Erf(x))`` using the graph builder.
    """

    def _expand(self, graph_builder):
        """Emit the ops that compute ``1 - Erf(x)`` for the single input.

        A float16 input is cast to float32 before ``Erf``/``Sub`` and the
        result is cast back to float16; any other dtype is computed directly
        in the input's own dtype.

        Args:
            graph_builder: builder object providing ``value`` and ``emit``.

        Returns:
            The node produced by the last emitted op (``Sub``, or the final
            ``Cast`` for float16 inputs).
        """
        input_x = self.inputs[0]
        is_fp16 = input_x.dtype == "float16"
        # float16 inputs are computed in float32; other dtypes are used as-is.
        compute_dtype = "float32" if is_fp16 else input_x.dtype

        # The constant is created before any op is emitted, matching the
        # order of builder calls used for both dtype paths.
        const_one = graph_builder.value(compute_dtype, 1)
        if is_fp16:
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
        erf_result = graph_builder.emit('Erf', [input_x])
        result = graph_builder.emit('Sub', [const_one, erf_result])
        if is_fp16:
            # Restore the caller-visible dtype.
            result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
        return result
prim::kPrimLogicalOr, - prim::kPrimLogicalNot, prim::kPrimAssign, prim::kPrimInplaceAssign, - prim::kPrimAtan, - prim::kPrimAtan2, - prim::kPrimExpm1, - prim::kPrimAsinh, - prim::kPrimAcosh, #if ENABLE_D prim::kPrimMatMul, prim::KPrimTransData, @@ -88,6 +79,21 @@ std::vector GetClusterableOpList() { prim::kPrimGreaterEqual, prim::kPrimLessEqual, prim::kPrimSelect, + prim::kPrimAtan, + prim::kPrimAtan2, + prim::kPrimExpm1, + prim::kPrimAsinh, + prim::kPrimAcosh, + prim::kPrimDiv, + prim::kPrimFloorDiv, + prim::kPrimMod, + prim::kPrimFloor, + prim::kPrimFloorMod, + prim::kPrimErf, + prim::kPrimNotEqual, + prim::kPrimLogicalAnd, + prim::kPrimLogicalOr, + prim::kPrimLogicalNot, #endif }; const auto &flags = context::GraphKernelFlags::GetInstance(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index cdd60842833..c7cd84ecbf2 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -60,6 +60,7 @@ std::vector GetExpandOps() { prim::kPrimTile, prim::kPrimMatMul, prim::kPrimBatchMatMul, + prim::kPrimErfc, #if ENABLE_D prim::kPrimSqrtGrad, prim::kPrimClipByNormNoDivSum, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index e2558b9d05a..51ee1bc0217 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -622,6 +622,7 @@ inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared("La inline const PrimitivePtr kPrimDType = std::make_shared("DType"); inline const PrimitivePtr kPrimDivFusion = std::make_shared("DivFusion"); inline const PrimitivePtr kPrimErf = std::make_shared("Erf"); +inline const PrimitivePtr kPrimErfc = std::make_shared("Erfc"); inline const PrimitivePtr kPrimSplice = std::make_shared("Splice"); inline const PrimitivePtr kPrimAffine = std::make_shared("Affine"); diff --git 
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class ErfNet(nn.Cell):
    """Cell that applies a single ``P.Erf`` op to its input."""

    def __init__(self):
        super(ErfNet, self).__init__()
        self.erf = P.Erf()

    def construct(self, x):
        return self.erf(x)


class ErfcNet(nn.Cell):
    """Cell that applies a single ``P.Erfc`` op to its input."""

    def __init__(self):
        super(ErfcNet, self).__init__()
        self.erfc = P.Erfc()

    def construct(self, x):
        return self.erfc(x)


def get_output(net, inp, enable_graph_kernel=False):
    """Instantiate ``net`` and run it on ``inp``.

    Args:
        net: Cell class (not an instance); it is constructed with no arguments.
        inp: input passed to the constructed cell.
        enable_graph_kernel: value written to the ``enable_graph_kernel``
            context option before the cell is built and run.

    Returns:
        Whatever the cell returns for ``inp``.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    output = net()(inp)
    return output


def basic_test(net, datatype):
    """Compare ``net`` outputs with graph kernel disabled vs. enabled.

    Random inputs in [0, 1) are generated for a small 2-D shape and a 5-D
    shape, cast to ``datatype``, and the two runs must agree within
    rtol=1e-4 / atol=1e-7.
    """
    # Fixed seed so that a failing input can be reproduced.
    np.random.seed(0)
    for shape in ((2, 3), (2, 3, 3, 4, 5)):
        inp = Tensor(np.random.random(shape).astype(datatype))
        expect = get_output(net, inp, False)
        output = get_output(net, inp, True)
        expect_np = expect.asnumpy().copy()
        output_np = output.asnumpy().copy()
        assert np.allclose(expect_np, output_np, rtol=1.e-4, atol=1.e-7)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_fp16():
    """Erf and Erfc on GPU in graph mode with float16 inputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    basic_test(ErfNet, np.float16)
    basic_test(ErfcNet, np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_fp32():
    """Erf and Erfc on GPU in graph mode with float32 inputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    basic_test(ErfNet, np.float32)
    basic_test(ErfcNet, np.float32)