From 0271535429811d3067968c7fda0d8a61c6b77189 Mon Sep 17 00:00:00 2001 From: chenlei_autodiff Date: Mon, 9 Aug 2021 15:37:57 +0800 Subject: [PATCH] [GraphKernel] fix bert and add graph kernel ops. --- .jenkins/check/config/filter_pylint.txt | 1 + akg | 2 +- .../graph_kernel/expanders/__init__.py | 1 + .../_extends/graph_kernel/expanders/slice.py | 35 ++++++++++++ .../graph_kernel/model/graph_split.py | 11 ++++ .../_extends/graph_kernel/model/model.py | 1 + .../graph_kernel/graph_kernel_cluster.cc | 1 + .../graph_kernel/graph_kernel_expander.cc | 1 + tests/st/ops/graph_kernel/test_slice.py | 55 +++++++++++++++++++ 9 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 mindspore/_extends/graph_kernel/expanders/slice.py create mode 100644 tests/st/ops/graph_kernel/test_slice.py diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index 1b8f94c2700..51ba5c814be 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -38,6 +38,7 @@ "mindspore/model_zoo/official/cv" "c-extension-no-member" "mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name" "mindspore/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py" "Catching too general exception BaseException" +"mindspore/mindspore/_extends/graph_kernel/model/model.py" "super-on-old-class" # MindData "mindspore/mindspore/dataset/__init__.py" "redefined-builtin" diff --git a/akg b/akg index 15b59fb7399..8902440c825 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 15b59fb739944c1903558659a39b34bb632de448 +Subproject commit 8902440c825f90846a5b0fe5c1644d450dbab631 diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index f412f80e78c..9cc0b8a85cf 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -51,6 +51,7 @@ from .sigmoid import 
Sigmoid from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad from .sigmoid_grad import SigmoidGrad +from .slice import Slice from .softmax import Softmax from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits from .softmax_grad_ext import SoftmaxGradExt diff --git a/mindspore/_extends/graph_kernel/expanders/slice.py b/mindspore/_extends/graph_kernel/expanders/slice.py new file mode 100644 index 00000000000..f886cfdb238 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/slice.py @@ -0,0 +1,35 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =========================================================================== +"""generate json desc for slice""" +from ._utils import Expander, ExpanderInfoValidator as VLD + + +@VLD.check_attrs('begin', 'size') +class Slice(Expander): + """Slice expander: emits a single StridedSlice op built from the 'begin' and 'size' attrs, using stride 1 on every axis and end equal to begin plus size element-wise""" + + def _expand(self, graph_builder): + input_x = self.inputs[0] + begin = self.attrs['begin'] + size = self.attrs['size'] + end = [] + strides = [] + for i in range(len(begin)): + strides.append(1) + end.append(begin[i] + size[i]) + output = graph_builder.tensor(size, input_x.dtype, input_x.data_format) # result tensor: shape is `size`, dtype and data_format copied from the input + graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides}) + + return output diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 363401992eb..6c422706368 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -804,6 +804,16 @@ class GraphSplitGpu(GraphSplitByPattern): fused.append(a) return fused, True + def _strided_slice(dom): # fuse rule: only applies when dom's op is StridedSlice, otherwise returns None + if dom.dom_op().prim != "StridedSlice": + return None + fused = [] + for a, _ in dom.in_relations.items(): # candidates are the areas listed in dom.in_relations + if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ + len(a.out_relations) == 1 and not a.is_output: + fused.append(a) + return fused, True # NOTE(review): meaning of the second element is defined by self.fuse, not visible here - confirm + + def _fuse_loop(): changed = True while changed: @@ -814,6 +824,7 @@ class GraphSplitGpu(GraphSplitByPattern): changed = self.fuse(_reduce_width) or changed changed = self.fuse(_broadcast_depth) or changed changed = self.fuse(_broadcast_width) or changed + changed = self.fuse(_strided_slice) or changed if use_poly_reduce: changed = self.fuse(_reduce_output) or changed if enable_stitch_fusion: diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 4dcec3e1466..634942e108f 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++
b/mindspore/_extends/graph_kernel/model/model.py @@ -216,6 +216,7 @@ class PrimLib: 'Transpose': Prim(OPAQUE), 'Tile': Prim(BROADCAST), 'BroadcastTo': Prim(BROADCAST), + 'StridedSlice': Prim(OPAQUE), 'MatMul': Prim(OPAQUE), 'TransData': Prim(OPAQUE), 'BatchMatMul': Prim(OPAQUE), diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index 2d26a864548..e30bd389085 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -99,6 +99,7 @@ std::vector GetClusterableOpList() { prim::kPrimSelect, prim::kPrimSign, prim::kPrimSin, + prim::kPrimStridedSlice, #endif }; const auto &flags = context::GraphKernelFlags::GetInstance(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 2c30e4b02e1..d18b4a2b169 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -82,6 +82,7 @@ std::vector GetExpandOps() { prim::kPrimSigmoidGrad, prim::kPrimSigmoidCrossEntropyWithLogits, prim::kPrimSigmoidCrossEntropyWithLogitsGrad, + prim::kPrimSlice, prim::kPrimSoftmax, prim::kPrimSoftmaxCrossEntropyWithLogits, prim::kPrimSquaredDifference, diff --git a/tests/st/ops/graph_kernel/test_slice.py b/tests/st/ops/graph_kernel/test_slice.py new file mode 100644 index 00000000000..118675b46db --- /dev/null +++ b/tests/st/ops/graph_kernel/test_slice.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class Net(nn.Cell): # minimal Cell whose construct() just forwards to P.Slice + def __init__(self): + super(Net, self).__init__() + self.slice = P.Slice() + + def construct(self, x, begin, size): + return self.slice(x, begin, size) + + +def get_output(x, begin, size, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) # switch graph kernel on or off before building the net + net = Net() + output = net(x, begin, size) + return output + + +def test_slice(): + in1 = np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]).astype(np.float32) + x1 = Tensor(in1) + begin1 = (0, 1, 0) + size1 = (2, 1, 3) + expect = get_output(x1, begin1, size1, False) # reference result with graph kernel disabled + output = get_output(x1, begin1, size1, True) # result under test with graph kernel enabled + assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001) # the two positional 0.0001 values are rtol and atol + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_slice_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_slice() # test_slice sets neither mode nor device itself; it relies on the context set on the previous line