diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index e37bf48cf8d..cfbcdb0bfd0 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6"; const char kNameReLU6Grad[] = "ReLU6Grad"; const char kNameElu[] = "Elu"; const char kNameEluGrad[] = "EluGrad"; +const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; const char kNameScatterUpdate[] = "ScatterUpdate"; const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; const char kNameScatterMax[] = "ScatterMax"; @@ -261,6 +262,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)}, + {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 568c794a08c..7643fa54e4f 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits())}, {"num", ATTR_DESC(num, AnyTraits())}}; DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; +// TensorScatterUpdate +INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}}; + // ScatterUpdate INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 916ae940765..6e0debd572d 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike) DECLARE_OP_USE_OUTPUT(ZerosLike) DECLARE_OP_ADAPTER(OnesLike) DECLARE_OP_USE_OUTPUT(OnesLike) +DECLARE_OP_ADAPTER(TensorScatterUpdate) +DECLARE_OP_USE_OUTPUT(TensorScatterUpdate) DECLARE_OP_ADAPTER(ScatterUpdate) DECLARE_OP_USE_OUTPUT(ScatterUpdate) DECLARE_OP_ADAPTER(ScatterNdUpdate) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 43a35f99e09..07805d3b45f 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self): return bprop +@bprop_getters.register(P.TensorScatterUpdate) +def get_bprop_tensor_scatter_update(self): + """Generate bprop for TensorScatterUpdate""" + gather_nd = P.GatherNd() + tensor_scatter_update = P.TensorScatterUpdate() + + def bprop(x, indices, update, out, dout): + x_grad = tensor_scatter_update(dout, indices, zeros_like(update)) + update_grad = gather_nd(dout, indices) + return x_grad, zeros_like(indices), update_grad + + return bprop + + @bprop_getters.register(P.Argmax) def get_bprop_argmax(self): """Generate bprop for Argmax""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 7d719103206..f81e1983345 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe from .sparse_gather_v2 import _sparse_gather_v2_tbe from .data_format_dim_map import _data_format_dim_map_tbe from .histogram_fixed_width import _histogram_fixed_width_tbe +from .tensor_scatter_update import _tensor_scatter_update_tbe diff --git a/mindspore/ops/_op_impl/tbe/scatter_nd_update.py b/mindspore/ops/_op_impl/tbe/scatter_nd_update.py index df0996f26f7..74fb7c9b725 100644 --- a/mindspore/ops/_op_impl/tbe/scatter_nd_update.py +++ b/mindspore/ops/_op_impl/tbe/scatter_nd_update.py @@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/scatter_update.py b/mindspore/ops/_op_impl/tbe/scatter_update.py index 3c330fe4353..244b8ab21fa 100644 --- a/mindspore/ops/_op_impl/tbe/scatter_update.py +++ b/mindspore/ops/_op_impl/tbe/scatter_update.py @@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py b/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py new file mode 100644 index 00000000000..46d6b20357e --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TensorScatterUpdate op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("tensor_scatter_update.so") \ + .compute_cost(10) \ + .kernel_name("tensor_scatter_update") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(1, "updates", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(tensor_scatter_update_op_info) +def _tensor_scatter_update_tbe(): + """TensorScatterUpdate TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 2db66bd8c62..92063135e1a 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, EmbeddingLookup, - Squeeze, StridedSlice, Tile, + Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo) @@ -212,6 +212,7 @@ __all__ = [ 'Pad', 'MirrorPad', 'GatherNd', + 'TensorScatterUpdate', 'ScatterUpdate', 'ScatterNdUpdate', 'Floor', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 061aa2e08fb..bd395d94f5b 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer): return x_dtype +class TensorScatterUpdate(PrimitiveWithInfer): + """ + Update tensor value by using input indices and value. + + Using given values to update tensor value, along with the input indices. + + Inputs: + - **input_x** (Tensor) - The target tensor. + - **indices** (Tensor) - The index of input tensor whose data type is int32. + - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, + and update.shape = indices.shape + input_x.shape[1:]. + + Outputs: + Tensor, has the same shape and type as `input_x`. + + Examples: + >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) + >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) + >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) + >>> op = P.TensorScatterUpdate() + >>> output = op(input_x, indices, update) + """ + @prim_attr_register + def __init__(self): + """Init TensorScatterUpdate""" + self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) + + def infer_shape(self, x_shape, indices_shape, value_shape): + validator.check('the dimension of x', len(x_shape), + 'the dimension of indices', indices_shape[-1], Rel.GE) + if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape: + raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.") + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, value_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + args = {"x": x_dtype, "value": value_dtype} + validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + return x_dtype + + class ScatterUpdate(PrimitiveWithInfer): """ Update tensor value by using input indices and value. @@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer): def infer_shape(self, x_shape, indices_shape, value_shape): if indices_shape + x_shape[1:] != value_shape: - raise ValueError('Input value are not match with input indices.') + raise ValueError("For 'ScatterUpdate', input value are not match with input indices.") return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): @@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): validator.check('the dimension of x', len(x_shape), 'the dimension of indices', indices_shape[-1], Rel.GE) if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape: - raise ValueError('Input value are not match with input indices.') + raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.") return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 1308a83ae72..03d24375d55 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ import pipeline_for_compile_grad_ge_graph_for_case_by_case_config +def test_tensor_scatter_update(): + class TensorScatterUpdateNet(nn.Cell): + """TensorScatterUpdate net definition""" + + def __init__(self): + super(TensorScatterUpdateNet, self).__init__() + self.tensor_scatter_update = P.TensorScatterUpdate() + + def construct(self, x, i, u): + out = self.tensor_scatter_update(x, i, u) + return out + net = TensorScatterUpdateNet() + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32) + indices = Tensor(np.array([[0, 0], [1, 1]], np.int32)) + updates = Tensor(np.ones([2, 5], np.float32)) + net(x, indices, updates) + + class InputBackward(nn.Cell): def __init__(self, network): super(InputBackward, self).__init__() @@ -1537,6 +1556,12 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), Tensor(np.ones((2,), np.int32))), 'desc_bprop': [([3, 3], {'dtype': np.int32})]}), + ('TensorScatterUpdate', { + 'block': P.TensorScatterUpdate(), + 'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32), + Tensor(np.array([[0, 1], [1, 2]], np.int32)), + Tensor(np.ones([2, 5], np.float32) * 99)), + 'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}), ('ScatterMax', { 'block': ScatterMax(), 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),