diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 2202f99910f..f412f80e78c 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -33,6 +33,7 @@ from .fused_mul_add import FusedMulAdd
 from .gelu import GeLU
 from .gelu_grad import GeLUGrad
 from .gkdropout import GkDropout
+from .identity import Identity
 from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
 from .lamb_apply_weight_assign import LambApplyWeightAssign
 from .layernorm import LayerNorm
@@ -42,6 +43,7 @@ from .logsoftmax_grad import LogSoftmaxGrad
 from .matmul import BatchMatMul, MatMul
 from .maximum_grad import MaximumGrad
 from .minimum_grad import MinimumGrad
+from .oneslike import OnesLike
 from .reduce_mean import ReduceMean
 from .relu import ReLU
 from .relu_grad import ReluGrad
@@ -56,6 +58,7 @@ from .sqrt_grad import SqrtGrad
 from .square import Square
 from .square_sum_v1 import SquareSumV1
 from .squared_difference import SquaredDifference
+from .square_sum_all import SquareSumAll
 from .squeeze import Squeeze
 from .tanh_grad import TanhGrad
 from .tile import Tile
diff --git a/mindspore/_extends/graph_kernel/expanders/identity.py b/mindspore/_extends/graph_kernel/expanders/identity.py
new file mode 100644
index 00000000000..cc6705ee92e
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/identity.py
@@ -0,0 +1,23 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for Identity"""
+from ._utils import Expander
+
+class Identity(Expander):
+    """Identity expander"""
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
+        result = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape})
+        return result
diff --git a/mindspore/_extends/graph_kernel/expanders/oneslike.py b/mindspore/_extends/graph_kernel/expanders/oneslike.py
new file mode 100644
index 00000000000..96ba01cd175
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/oneslike.py
@@ -0,0 +1,24 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for OnesLike"""
+from ._utils import Expander
+
+class OnesLike(Expander):
+    """OnesLike expander"""
+    def _expand(self, graph_builder):
+        input_x = self.inputs[0]
+        const_one = graph_builder.value(input_x.dtype, 1)
+        result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
+        return result
diff --git a/mindspore/_extends/graph_kernel/expanders/square_sum_all.py b/mindspore/_extends/graph_kernel/expanders/square_sum_all.py
new file mode 100644
index 00000000000..0e1e36fde72
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/square_sum_all.py
@@ -0,0 +1,42 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for SquareSumAll"""
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
+from ._utils import Expander
+
+class SquareSumAll(Expander):
+    """SquareSumAll expander"""
+    def _check(self):
+        """check inputs"""
+        input_num = len(self.inputs)
+        if input_num != 2:
+            raise GKException("SquareSumAll inputs number should be 2, but got {}.".format(input_num))
+
+    def _expand(self, graph_builder):
+        """do expand"""
+        x0 = self.inputs[0]
+        x1 = self.inputs[1]
+
+        ori_shape = x0.shape
+        axis = []
+        for i, _ in enumerate(ori_shape):
+            axis.append(i)
+
+        square_res0 = graph_builder.emit('Mul', [x0, x0])
+        square_res1 = graph_builder.emit('Mul', [x1, x1])
+        result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False})
+        result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False})
+
+        return result0, result1
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index 6fd929de3d5..2c0debb6729 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -155,6 +155,9 @@ class PrimLib:
         'Mul': Prim(ELEMWISE),
         'Sub': Prim(ELEMWISE),
         'Log': Prim(ELEMWISE),
+        'IsNan': Prim(ELEMWISE),
+        'IsInf': Prim(ELEMWISE),
+        'IsFinite': Prim(ELEMWISE),
         'Exp': Prim(ELEMWISE),
         'Rsqrt': Prim(ELEMWISE),
         'Sqrt': Prim(ELEMWISE),
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
index cff9cbc04e7..d594752b11a 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc
@@ -94,6 +94,9 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimLogicalAnd,
     prim::kPrimLogicalOr,
     prim::kPrimLogicalNot,
+    prim::kPrimIsNan,
+    prim::kPrimIsInf,
+    prim::kPrimIsFinite,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index c7cd84ecbf2..d7eac8960da 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -88,6 +88,9 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimSquaredDifference,
     prim::kPrimSqueeze,
     prim::kPrimEqualCount,
+    prim::kPrimSquareSumAll,
+    prim::kPrimIdentityMath,
+    prim::kPrimOnesLike,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
index e4c3bcea4ea..4e26022a162 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
@@ -2580,7 +2580,9 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
   // Check the data source actors.
   for (const auto &data_source_actor : actor_set->data_source_actors_) {
     MS_EXCEPTION_IF_NULL(data_source_actor);
-    if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() == 0) {
+    if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
+          data_source_actor->output_control_arrows_.size() ==
+        0) {
       MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
       return false;
     }
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index c8b3252cacc..0471fa22c0e 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -488,6 +488,10 @@ inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad"
 inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
 inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
 inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
+inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan");
+inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf");
+inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite");
+inline const PrimitivePtr kPrimSquareSumAll = std::make_shared<Primitive>("SquareSumAll");

 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
diff --git a/tests/st/ops/graph_kernel/test_identity.py b/tests/st/ops/graph_kernel/test_identity.py
new file mode 100644
index 00000000000..08365c682e2
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_identity.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.identity = P.Identity()
+
+    def construct(self, x):
+        return self.identity(x)
+
+def get_output(x, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = Net()
+    output = net(x)
+    return output
+
+def test_basic(dtype):
+    expect_np = np.random.normal(0, 10, (16, 32)).astype(dtype)
+    x = Tensor(expect_np)
+    output = get_output(x, True)
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_2():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic(np.float32)
diff --git a/tests/st/ops/graph_kernel/test_oneslike.py b/tests/st/ops/graph_kernel/test_oneslike.py
new file mode 100644
index 00000000000..5f4f6e14baa
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_oneslike.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+import mindspore.common.dtype as mstype
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.oneslike = P.OnesLike()
+
+    def construct(self, shape, dtype, x):
+        return self.oneslike(x)
+
+def get_output(shape, dtype, nptype, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = Net()
+    x = Tensor(np.random.normal(0, 10, shape).astype(nptype))
+    output = net(shape, dtype, x)
+    return output
+
+def test_basic(shape, dtype, nptype):
+    expect = get_output(shape, dtype, nptype, False)
+    output = get_output(shape, dtype, nptype, True)
+    assert np.allclose(expect.asnumpy(), output.asnumpy(), 1.e-4, 1.e-7)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic((2, 16), mstype.float16, np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_2():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic((4, 32), mstype.float32, np.float32)
diff --git a/tests/st/ops/graph_kernel/test_square_sum_all.py b/tests/st/ops/graph_kernel/test_square_sum_all.py
new file mode 100644
index 00000000000..2a5c3559ad0
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_square_sum_all.py
@@ -0,0 +1,61 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.squaresumall = P.SquareSumAll()
+
+    def construct(self, x0, x1):
+        return self.squaresumall(x0, x1)
+
+def get_output(inp0, inp1, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = Net()
+    output = net(inp0, inp1)
+    return output
+
+def test_basic(datatype):
+    inp0 = Tensor(np.random.normal(1, 0.1, [800, 96]).astype(datatype))
+    inp1 = Tensor(np.random.normal(1, 0.1, [800, 96]).astype(datatype))
+    expect = get_output(inp0, inp1, False)
+    output = get_output(inp0, inp1, True)
+    expect_np0 = expect[0].asnumpy().copy()
+    output_np0 = output[0].asnumpy().copy()
+    expect_np1 = expect[1].asnumpy().copy()
+    output_np1 = output[1].asnumpy().copy()
+    assert np.allclose(expect_np0, output_np0, 1.e-4, 1.e-7)
+    assert np.allclose(expect_np1, output_np1, 1.e-4, 1.e-7)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_2():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_basic(np.float32)