diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index c32cb3b3ac3..9e69cc7445c 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -87,7 +87,8 @@ static std::map<std::string, std::string> tbe_func_adapter_map = {
   {"i_ou", "iou"},
   {"s_gd", "sgd"},
   {"l_ars_update", "lars_v2_update"},
-  {"n_ms_with_mask", "nms_with_mask"}};
+  {"n_ms_with_mask", "nms_with_mask"},
+  {"square_sum_all", "square_sum_all"}};
 
 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 915601b44f6..ac5c6dcd922 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -198,6 +198,7 @@ const char kNameApplyRMSProp[] = "ApplyRMSProp";
 const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
 const char kNameL2Loss[] = "L2Loss";
 const char kNameCTCLoss[] = "CTCLoss";
+const char kNameSquareSumAll[] = "SquareSumAll";
 
 // -----------------OpAdapter initialization--------------
 std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@@ -395,7 +396,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map
     {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
     {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
     {string(kNameL2Loss), ADPT_DESC(L2Loss)},
-    {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}};
+    {string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
+    {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
   adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdam);
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index f6dae27455b..654d80f23e3 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -889,6 +889,11 @@ INPUT_MAP(Square) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Square) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Square) = {{0, OUTPUT_DESC(y)}};
 
+// SquareSumAll
+INPUT_MAP(SquareSumAll) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
+ATTR_MAP(SquareSumAll) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(SquareSumAll) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
+
 // Tanh
 INPUT_MAP(Tanh) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Tanh) = EMPTY_ATTR_MAP;
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index 7671b2a631f..8820483e89f 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -190,6 +190,8 @@ DECLARE_OP_ADAPTER(SplitD)
 DECLARE_OP_USE_DYN_OUTPUT(SplitD)
 DECLARE_OP_ADAPTER(SGD)
 DECLARE_OP_USE_OUTPUT(SGD)
+DECLARE_OP_ADAPTER(SquareSumAll)
+DECLARE_OP_USE_OUTPUT(SquareSumAll)
 DECLARE_OP_ADAPTER(Tanh)
 DECLARE_OP_USE_OUTPUT(Tanh)
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index d2bd1eb9541..c8aa30f2c25 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -180,3 +180,4 @@ from .random_choice_with_mask import random_choice_with_mask_op_info
 from .sgd import sgd_op_info
 from .lars_update import lars_update_op_info
 from .bn_training_update_v2 import _bn_training_update_v2_tbe
+from .square_sum_all import square_sum_all_op_info
diff --git a/mindspore/ops/_op_impl/tbe/square_sum_all.py b/mindspore/ops/_op_impl/tbe/square_sum_all.py
new file mode 100644
index 00000000000..e9d56e44b13
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/square_sum_all.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SquareSumAll op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+square_sum_all_op_info = TBERegOp("SquareSumAll") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("square_sum_all.so") \
+    .compute_cost(10) \
+    .kernel_name("square_sum") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y1", False, "required", "all") \
+    .output(1, "y2", False, "required", "all") \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(square_sum_all_op_info)
+def _square_sum_all_tbe():
+    """SquareSumAll TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 235b593c7af..94e9913241e 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -47,7 +47,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
                        NPUGetFloatStatus, Pow, RealDiv,
                        IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt,
-                       Square, Sub, TensorAdd, Sign, Round)
+                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll)
 from .random_ops import (RandomChoiceWithMask)
 from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D,
@@ -251,7 +251,8 @@ __all__ = [
     "BatchToSpace",
     "Atan2",
     "ApplyRMSProp",
-    "ApplyCenteredRMSProp"
+    "ApplyCenteredRMSProp",
+    "SquareSumAll"
 ]
 
 __all__.extend(_quant_ops.__all__)
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index fc92bfc8fb5..2533de0e466 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2090,3 +2090,39 @@ class Atan2(_MathBinaryOp):
         >>> atan2(input_x, input_y)
         [[0. 0.7853982]]
     """
+
+
+class SquareSumAll(PrimitiveWithInfer):
+    """
+    Returns the square sum of all elements in each input tensor.
+
+    Note:
+        SquareSumAll only supports float16 and float32 data types.
+
+    Inputs:
+        - **input_x1** (Tensor) - The input tensor.
+        - **input_x2** (Tensor) - The input tensor, with the same type and shape as `input_x1`.
+
+    Outputs:
+        - **output_y1** (Tensor) - A scalar tensor of the same type as `input_x1`, the square sum of `input_x1`.
+        - **output_y2** (Tensor) - A scalar tensor of the same type as `input_x1`, the square sum of `input_x2`.
+
+    Examples:
+        >>> input_x1 = Tensor(np.random.randn(3, 2, 5, 7), mindspore.float32)
+        >>> input_x2 = Tensor(np.random.randn(3, 2, 5, 7), mindspore.float32)
+        >>> square_sum_all = P.SquareSumAll()
+        >>> square_sum_all(input_x1, input_x2)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init SquareSumAll"""
+
+    def infer_shape(self, x_shape, y_shape):
+        validator.check("x1_shape", x_shape, "x2_shape", y_shape, Rel.EQ, self.name)
+        return [], []
+
+    def infer_dtype(self, x_type, y_type):
+        validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name)
+        return x_type, y_type
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 5fc201f9517..a02290a4c48 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -573,7 +573,12 @@ test_case_math_ops = [
         'block': P.Atan2(),
         'desc_inputs': [Tensor(np.array([0, 1]).astype(np.float32)),
                         Tensor(np.array([1, 1]).astype(np.float32))],
-        'desc_bprop': [[2]]})
+        'desc_bprop': [[2]]}),
+    ('SquareSumAll', {
+        'block': P.SquareSumAll(),
+        'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)),
+                        Tensor(np.array([1, 1, 3, 7]).astype(np.float32))],
+        'skip': ['backward']}),
 ]
 
 test_case_nn_ops = [
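
As a sanity check on the semantics this patch wires up: each output of SquareSumAll is the scalar sum of squares of the corresponding input, which is what `infer_shape` returning `[], []` reflects. The sketch below is illustrative and not part of the patch; `square_sum_all_ref` is a hypothetical NumPy reference helper. For the vectors added to `test_case_math_ops` it gives the values a device run of the op should reproduce.

    import numpy as np

    def square_sum_all_ref(x1, x2):
        """Reference semantics: y1 = sum(x1 ** 2), y2 = sum(x2 ** 2), both scalars."""
        return np.sum(np.square(x1)), np.sum(np.square(x2))

    # Same inputs as the new 'SquareSumAll' entry in test_case_math_ops.
    x1 = np.array([0, 1, 4, 5], dtype=np.float32)
    x2 = np.array([1, 1, 3, 7], dtype=np.float32)

    print(square_sum_all_ref(x1, x2))  # (42.0, 60.0): 0+1+16+25 and 1+1+9+49

On device the op is invoked as in the new docstring example, `P.SquareSumAll()(input_x1, input_x2)`, and returns the two scalar tensors.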