diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index dd2abcfe04f..e819261c649 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -53,6 +53,7 @@ static std::map tbe_func_adapter_map = { {"scatter_nd", "scatter_nd_d"}, {"tile", "tile_d"}, {"gather_v2", "gather_v2_d"}, + {"sparse_gather_v2", "gather_v2_d"}, {"batch_mat_mul", "batch_matmul"}, {"b_n_training_reduce", "bn_training_reduce"}, {"b_n_training_update", "bn_training_update"}, diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index c7e63c9a414..6a557388adf 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -47,6 +47,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimCumProd->name(), {1}); Register(prim::kPrimReduceAll->name(), {1}); Register(prim::kPrimUnsortedSegmentMin->name(), {2}); + Register(kSparseGatherV2, {2}); Register(kUnsortedSegmentProdOpName, {2}); Register(kSimpleMeanGradOpName, {1}); Register(kMeanGradOpName, {1}); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index b2771f4b9b7..2b194088995 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -65,6 +65,7 @@ constexpr auto kScatterNdOpName = "ScatterNd"; constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign"; constexpr auto kStridedSliceOpName = "StridedSlice"; constexpr auto kStridedSliceGradOpName = "StridedSliceGrad"; +constexpr auto kSparseGatherV2 = "SparseGatherV2"; constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd"; constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin"; constexpr auto kFlattenGradOpName = "FlattenGrad"; diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index ba86f994bf2..484827ec182 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -248,3 +248,4 @@ from .range import _range_tbe from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe from .lamb_next_right import _lamb_next_right_tbe +from .sparse_gather_v2 import _sparse_gather_v2_tbe diff --git a/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py b/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py new file mode 100644 index 00000000000..b8248363127 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py @@ -0,0 +1,66 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseGatherV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_gather_v2_op_info = TBERegOp("SparseGatherV2") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("gather_v2_d.so") \ + .compute_cost(10) \ + .kernel_name("gather_v2_d") \ + .partial_flag(True) \ + .attr("axis", "optional", "int", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(sparse_gather_v2_op_info) +def _sparse_gather_v2_tbe(): + """SparseGatherV2 TBE register""" + return diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 03eeb9cecf5..7b3a48e25af 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -956,6 +956,11 @@ test_case_nn_ops = [ 'desc_const': [0], 'desc_inputs': [[1152], Tensor(np.array(10).astype(np.int32))], 'desc_bprop': [Tensor(np.array(10).astype(np.float32))]}), + ('SparseGatherV2_0', { + 'block': P.SparseGatherV2(), + 'desc_const': [0], + 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], + 'desc_bprop': [[2, 1, 2]]}), ('Range', { 'block': P.Range(1.0, 5.0), 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))],