From 7d55cef106c07752a941a05aae15855b64f59dcf Mon Sep 17 00:00:00 2001 From: chenlei_autodiff Date: Mon, 19 Jul 2021 19:34:23 +0800 Subject: [PATCH] [GraphKernel] add sponge ops. --- akg | 2 +- .../graph_kernel/expanders/__init__.py | 1 + .../expanders/squared_difference.py | 30 ++++++++++ .../_extends/graph_kernel/model/model.py | 5 ++ .../graph_kernel/graph_kernel_cluster.cc | 5 ++ .../graph_kernel/graph_kernel_expander.cc | 1 + .../graph_kernel/test_squared_difference.py | 58 +++++++++++++++++++ 7 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 mindspore/_extends/graph_kernel/expanders/squared_difference.py create mode 100644 tests/st/ops/graph_kernel/test_squared_difference.py diff --git a/akg b/akg index 1e6b226a041..a26aa4330ed 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 1e6b226a0417d23d2d0a2333d5e80f13fe9e8d0f +Subproject commit a26aa4330ed1e15dede3b983e79c78290083b3ac diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index bdfaacba59d..8bec8e34b40 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -53,6 +53,7 @@ from .softmax_grad_ext import SoftmaxGradExt from .sqrt_grad import SqrtGrad from .square import Square from .square_sum_v1 import SquareSumV1 +from .squared_difference import SquaredDifference from .squeeze import Squeeze from .tanh_grad import TanhGrad from .tile import Tile diff --git a/mindspore/_extends/graph_kernel/expanders/squared_difference.py b/mindspore/_extends/graph_kernel/expanders/squared_difference.py new file mode 100644 index 00000000000..316b000e346 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/squared_difference.py @@ -0,0 +1,30 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for squared_difference""" +from ._utils import Expander, ExpanderInfoValidator as VLD + + +@VLD.check_all_formats_same +class SquaredDifference(Expander): + """SquaredDifference expander""" + + def _expand(self, graph_builder): + input_x = self.inputs[0] + input_y = self.inputs[1] + + sub_val = graph_builder.emit('Sub', [input_x, input_y]) + result = graph_builder.emit('Mul', [sub_val, sub_val]) + + return result diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 34b57f21dbe..e01c2ad777b 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -176,6 +176,11 @@ class PrimLib: 'ReduceMax': Prim(REDUCE), 'ReduceMin': Prim(REDUCE), 'Assign': Prim(ELEMWISE), + 'Sign': Prim(ELEMWISE), + 'Sin': Prim(ELEMWISE), + 'Cos': Prim(ELEMWISE), + 'Asin': Prim(ELEMWISE), + 'ACos': Prim(ELEMWISE), 'Tanh': Prim(ELEMWISE), 'InplaceAssign': Prim(ELEMWISE), '@ReduceInit': Prim(ELEMWISE), diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index ec57749f9eb..bec0157944a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -67,6 +67,11 @@ std::vector GetClusterableOpList() { prim::KPrimTransData, prim::kPrimBatchMatMul, #elif ENABLE_GPU + prim::kPrimSin, + prim::kPrimCos, + prim::kPrimAsin, + prim::kPrimACos, + prim::kPrimSign, prim::kPrimReduceMax, prim::kPrimReduceMin, prim::kPrimGreater, diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 7f13aab7477..55f72e2b291 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -84,6 +84,7 @@ std::vector GetExpandOps() { prim::kPrimSigmoidCrossEntropyWithLogits, prim::kPrimSigmoidCrossEntropyWithLogitsGrad, prim::kPrimSoftmaxCrossEntropyWithLogits, + prim::kPrimSquaredDifference, prim::kPrimSqueeze, #endif }; diff --git a/tests/st/ops/graph_kernel/test_squared_difference.py b/tests/st/ops/graph_kernel/test_squared_difference.py new file mode 100644 index 00000000000..f2d49069219 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_squared_difference.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.squared_difference = P.SquaredDifference() + + def construct(self, x, y): + return self.squared_difference(x, y) + + +def get_output(x, y, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + net = Net() + output = net(x, y) + return output + + +def test_squared_difference(shape1, shape2, dtype): + x = Tensor(np.random.normal(0, 10, shape1).astype(dtype)) + y = Tensor(np.random.normal(0, 10, shape2).astype(dtype)) + expect = get_output(x, y, False) + output = get_output(x, y, True) + + expect_np = expect.asnumpy().copy() + output_np = output.asnumpy().copy() + + assert np.allclose(expect_np, output_np, 0.0001, 0.0001) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_squared_difference_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_squared_difference((4, 3), (4, 3), np.float16) + test_squared_difference((6, 2), (1), np.int32) + test_squared_difference((1), (4, 3), np.float32)