diff --git a/akg b/akg
index 32af460cac1..97dc7e96c2f 160000
--- a/akg
+++ b/akg
@@ -1 +1 @@
-Subproject commit 32af460cac1bb7d76bc1fd41f5866107cfffe1b9
+Subproject commit 97dc7e96c2ffedf2e6e38310a903ffa205a6e656
diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index ae7f988724f..9c840e9d2b7 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -14,43 +14,44 @@
 # ============================================================================
 """expanders init"""
 
+from .addn import AddN
 from .assign_add import AssignAdd
 from .batchnorm import BatchNorm
 from .batchnorm_grad import BatchNormGrad
 from .bias_add import BiasAdd
 from .bias_add_grad import BiasAddGrad
 from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
+from .conv2d import Conv2D
 from .dropout_grad import DropoutGrad
 from .expand_dims import ExpandDims
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay
+from .fused_mul_add import FusedMulAdd
 from .gelu import GeLU
 from .gelu_grad import GeLUGrad
 from .gkdropout import GkDropout
+from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
+from .lamb_apply_weight_assign import LambApplyWeightAssign
 from .layernorm import LayerNorm
 from .layernorm_grad import LayerNormGrad
 from .logsoftmax import LogSoftmax
 from .logsoftmax_grad import LogSoftmaxGrad
+from .matmul import BatchMatMul, MatMul
 from .maximum_grad import MaximumGrad
 from .minimum_grad import MinimumGrad
 from .reduce_mean import ReduceMean
 from .relu import ReLU
 from .relu_grad import ReluGrad
-from .softmax import Softmax
 from .sigmoid import Sigmoid
-from .sigmoid_grad import SigmoidGrad
 from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
 from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
+from .sigmoid_grad import SigmoidGrad
+from .softmax import Softmax
 from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
+from .softmax_grad_ext import SoftmaxGradExt
 from .sqrt_grad import SqrtGrad
 from .square import Square
+from .square_sum_v1 import SquareSumV1
 from .squeeze import Squeeze
 from .tanh_grad import TanhGrad
 from .tile import Tile
-from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
-from .lamb_apply_weight_assign import LambApplyWeightAssign
-from .softmax_grad_ext import SoftmaxGradExt
-from .square_sum_v1 import SquareSumV1
-from .fused_mul_add import FusedMulAdd
-from .conv2d import Conv2D
-from .matmul import MatMul, BatchMatMul
diff --git a/mindspore/_extends/graph_kernel/expanders/addn.py b/mindspore/_extends/graph_kernel/expanders/addn.py
new file mode 100644
index 00000000000..b566dbd6a75
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/addn.py
@@ -0,0 +1,32 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for addn"""
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
+from ._utils import Expander, ExpanderInfoValidator as VLD
+
+
+@VLD.check_all_formats_same
+class AddN(Expander):
+    """Expand AddN to multiple Adds"""
+
+    def _check(self):
+        if len(self.inputs) < 2:
+            raise GKException("Inputs num of AddN should be greater than 1 but got {}".format(len(self.inputs)))
+
+    def _expand(self, graph_builder):
+        result = self.inputs[0]
+        for inp in self.inputs[1:]:
+            result = graph_builder.emit('Add', [result, inp])
+        return result
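For intuition, `_expand` above performs a left fold: `AddN(a, b, c, d)` becomes `Add(Add(Add(a, b), c), d)`. A minimal standalone sketch of that rewrite (plain Python with tuples standing in for graph nodes, not MindSpore's actual `graph_builder` API):

```python
from functools import reduce

def expand_addn(inputs):
    """Left-fold N addends into a chain of N-1 binary 'Add' nodes."""
    if len(inputs) < 2:
        raise ValueError("AddN expects at least 2 inputs, got {}".format(len(inputs)))
    # Mirrors `result = graph_builder.emit('Add', [result, inp])` in the expander.
    return reduce(lambda acc, inp: ('Add', acc, inp), inputs)

print(expand_addn(['a', 'b', 'c']))  # ('Add', ('Add', 'a', 'b'), 'c')
```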
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
index 2b576b1b0b5..ec57749f9eb 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
@@ -53,7 +53,6 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimSub,
     prim::kPrimRsqrt,
     prim::kPrimSqrt,
-    prim::kPrimAddN,
     prim::kPrimReciprocal,
     prim::kPrimTanh,
     prim::kPrimReshape,
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 488d6c8bf11..68473856bba 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -44,6 +44,7 @@ constexpr size_t kLambWeightInputIdx = 4;
 
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
+    prim::kPrimAddN,
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
     prim::kPrimAssignAdd,
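The net effect of these two C++ changes: `AddN` is no longer handed to the clustering pass as an opaque basic op; it is routed through the expander first, so the Python `AddN` expander above decomposes it into `Add` chains whose pieces can then be fused with neighbouring ops. A toy Python sketch of that dispatch (simplified names, not the actual pass pipeline):

```python
# Hypothetical, simplified view of the two op lists after this change.
EXPAND_OPS = {"AddN", "Square", "GeLUGrad", "AssignAdd"}         # GetExpandOps(): AddN added
CLUSTERABLE_OPS = {"Sub", "Rsqrt", "Sqrt", "Reciprocal", "Add"}  # GetClusterableOpList(): AddN removed

def lower(op_name):
    if op_name in EXPAND_OPS:
        return "expand to basic ops, then cluster the pieces"
    if op_name in CLUSTERABLE_OPS:
        return "cluster directly"
    return "leave as a standalone kernel"

print(lower("AddN"))  # expand to basic ops, then cluster the pieces
```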
diff --git a/tests/st/ops/graph_kernel/test_addn.py b/tests/st/ops/graph_kernel/test_addn.py
new file mode 100644
index 00000000000..4027a51b7d8
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_addn.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.addn = P.AddN()
+
+    def construct(self, *args):
+        return self.addn(*args)
+
+
+def get_output(*tensors):
+    net = Net()
+    output = net(tensors)
+    return output
+
+
+def test_basic():
+    np.random.seed(0)
+    tensors = []
+    expect = np.array([0], np.float32)
+    for _ in range(10):
+        t = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
+        expect = t + expect
+        tensors.append(Tensor(t))
+
+    output = get_output(*tensors).asnumpy()
+
+    assert np.allclose(expect, output, 1.e-4, 1.e-7)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_basic_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
+    test_basic()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_graph_kernel=True)
+    test_basic()
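To exercise the new path outside the test harness, a minimal sketch (assumes a GPU build of MindSpore; `device_target="Ascend"` works the same way, per the tests above):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P


class AddNNet(Cell):
    def __init__(self):
        super(AddNNet, self).__init__()
        self.addn = P.AddN()

    def construct(self, inputs):
        # AddN takes a single tuple/list of tensors; with
        # enable_graph_kernel=True it is expanded into a chain of Adds.
        return self.addn(inputs)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
x, y, z = (Tensor(np.ones((2, 3), np.float32)) for _ in range(3))
print(AddNNet()((x, y, z)).asnumpy())  # every element is 3.0
```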