diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 4639229c414..3f97fbf83c8 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -369,7 +369,7 @@ class CosineEmbeddingLoss(_Loss): >>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32) >>> y = Tensor(np.array([1,-1]), mindspore.int32) >>> cosine_embedding_loss = P.CosineEmbeddingLoss() - >>> cosine_embedding_loss(x1, x2, target) + >>> cosine_embedding_loss(x1, x2, y) [0.0003426671] """ def __init__(self, margin=0.0, reduction="mean"): diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 6067295010f..be3d3135540 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -276,4 +276,6 @@ from .lrn_grad import _lrn_grad_tbe from .scatter_max import _scatter_max_tbe from .scatter_min import _scatter_min_tbe from .scatter_sub import _scatter_sub_tbe +from .scatter_mul import _scatter_mul_tbe +from .scatter_div import _scatter_div_tbe from .mod import _mod_tbe diff --git a/mindspore/ops/_op_impl/tbe/scatter_div.py b/mindspore/ops/_op_impl/tbe/scatter_div.py new file mode 100644 index 00000000000..5c6572dfd78 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_div.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterDiv op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_div_op_info = TBERegOp("ScatterDiv") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_div.so") \ + .compute_cost(10) \ + .kernel_name("scatter_div") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(scatter_div_op_info) +def _scatter_div_tbe(): + """ScatterDiv TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_mul.py b/mindspore/ops/_op_impl/tbe/scatter_mul.py new file mode 100644 index 00000000000..dd77e33f6d5 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_mul.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterMul op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_mul_op_info = TBERegOp("ScatterMul") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_mul.so") \ + .compute_cost(10) \ + .kernel_name("scatter_mul") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(scatter_mul_op_info) +def _scatter_mul_tbe(): + """ScatterMul TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 4e76f55cd49..bc4edce193b 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -25,8 +25,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, - SameTypeShape, ScatterAdd, ScatterSub, ScatterMax, ScatterMin, ScatterUpdate, - ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, + SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, + ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, @@ -215,6 +215,8 @@ __all__ = [ 'L2Normalize', 'ScatterAdd', 'ScatterSub', + 'ScatterMul', + 'ScatterDiv', 'ScatterNd', 'ScatterMax', 'ScatterMin', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b02d1b3c877..daefea64c80 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -38,6 +38,39 @@ from ..._c_expression import signature_dtype as sig_dtype from ..._c_expression import typing +class _ScatterOp(PrimitiveWithInfer): + """ + Define Scatter operators + """ + __mindspore_signature__ = ( + ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), + ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + ) + @staticmethod + def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): + if updates_shape and updates_shape != indices_shape + x_shape[1:]: + raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " + f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " + f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") + + @prim_attr_register + def __init__(self, use_locking=False): + """Init _ScatterOp""" + validator.check_value_type('use_locking', use_locking, [bool], self.name) + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + + def infer_shape(self, x_shape, indices_shape, updates_shape): + _ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + args = {"x": x_dtype, "updates": updates_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x_dtype + + def _check_infer_attr_reduce(axis, keep_dims, prim_name): validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('axis', axis, [int, tuple], prim_name) @@ -2221,7 +2254,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): return x_dtype -class ScatterUpdate(PrimitiveWithInfer): +class ScatterUpdate(_ScatterOp): """ Update tensor value by using input indices and value. @@ -2233,8 +2266,8 @@ class ScatterUpdate(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. - **indices** (Tensor) - The index of input tensor. With int32 data type. - - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, - and update.shape = indices.shape + input_x.shape[1:]. + - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input, + and updates.shape = indices.shape + input_x.shape[1:]. Outputs: Tensor, has the same shape and type as `input_x`. @@ -2243,27 +2276,17 @@ class ScatterUpdate(PrimitiveWithInfer): >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]) >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x") >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) - >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]]) - >>> update = Tensor(np_update, mindspore.float32) + >>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]]) + >>> updates = Tensor(np_updates, mindspore.float32) >>> op = P.ScatterUpdate() - >>> output = op(input_x, indices, update) + >>> output = op(input_x, indices, updates) """ - __mindspore_signature__ = ( - ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) - ) @prim_attr_register def __init__(self, use_locking=True): """Init ScatterUpdate""" validator.check_value_type('use_locking', use_locking, [bool], self.name) - self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) - - def infer_shape(self, x_shape, indices_shape, value_shape): - if indices_shape + x_shape[1:] != value_shape: - raise ValueError("For 'ScatterUpdate', input value are not match with input indices.") - return x_shape + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) def infer_dtype(self, x_dtype, indices_dtype, value_dtype): validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) @@ -2323,14 +2346,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): return x_dtype -def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): - if updates_shape and updates_shape != indices_shape + x_shape[1:]: - raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " - f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " - f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") - - -class ScatterMax(PrimitiveWithInfer): +class ScatterMax(_ScatterOp): """ Update the value of the input tensor through the max operation. @@ -2364,18 +2380,8 @@ class ScatterMax(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) validator.check_value_type('use_locking', use_locking, (bool,), self.name) - def infer_shape(self, x_shape, indices_shape, updates_shape): - _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) - return x_shape - def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) - args = {"x": x_dtype, "updates": updates_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) - return x_dtype - - -class ScatterMin(PrimitiveWithInfer): +class ScatterMin(_ScatterOp): """ Update the value of the input tensor through the min operation. @@ -2403,24 +2409,8 @@ class ScatterMin(PrimitiveWithInfer): [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]] """ - @prim_attr_register - def __init__(self, use_locking=False): - """Init ScatterMin""" - self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) - validator.check_value_type('use_locking', use_locking, (bool,), self.name) - def infer_shape(self, x_shape, indices_shape, updates_shape): - _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) - return x_shape - - def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) - args = {"x": x_dtype, "updates": updates_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) - return x_dtype - - -class ScatterAdd(PrimitiveWithInfer): +class ScatterAdd(_ScatterOp): """ Update the value of the input tensor through the add operation. @@ -2448,23 +2438,8 @@ class ScatterAdd(PrimitiveWithInfer): [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] """ - @prim_attr_register - def __init__(self, use_locking=False): - """Init ScatterAdd""" - validator.check_value_type('use_locking', use_locking, (bool,), self.name) - def infer_shape(self, x_shape, indices_shape, updates_shape): - _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) - return x_shape - - def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) - args = {'x': x_dtype, 'updates': updates_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) - return x_dtype - - -class ScatterSub(PrimitiveWithInfer): +class ScatterSub(_ScatterOp): """ Update the value of the input tensor through the sub operation. @@ -2492,20 +2467,63 @@ class ScatterSub(PrimitiveWithInfer): [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]] """ - @prim_attr_register - def __init__(self, use_locking=False): - """Init ScatterSub""" - validator.check_value_type('use_locking', use_locking, (bool,), self.name) - def infer_shape(self, x_shape, indices_shape, updates_shape): - _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) - return x_shape +class ScatterMul(_ScatterOp): + """ + Update the value of the input tensor through the mul operation. - def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) - args = {'x': x_dtype, 'updates': updates_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) - return x_dtype + Using given values to update tensor value through the mul operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do mul operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the mul operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([0, 1]), mindspore.int32) + >>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32) + >>> scatter_mul = P.ScatterMul() + >>> output = scatter_mul(input_x, indices, updates) + [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] + """ + + +class ScatterDiv(_ScatterOp): + """ + Update the value of the input tensor through the div operation. + + Using given values to update tensor value through the div operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do div operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the div operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([0, 1]), mindspore.int32) + >>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32) + >>> scatter_div = P.ScatterDiv() + >>> output = scatter_div(input_x, indices, updates) + [[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]] + """ class SpaceToDepth(PrimitiveWithInfer): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index b3f4e50c386..98a7b766e7b 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -185,6 +185,19 @@ class HistogramSummaryNet(nn.Cell): return out +class ScatterUpdate(nn.Cell): + """ScatterUpdate net definition""" + + def __init__(self, ref_shape, dtype=np.float32, use_locking=False): + super(ScatterUpdate, self).__init__() + self.scatter_update = P.ScatterUpdate(use_locking) + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_update(self.ref, indices, updates) + return out + + class ScatterMax(nn.Cell): """ScatterMax net definition""" @@ -237,6 +250,32 @@ class ScatterSub(nn.Cell): return out +class ScatterMul(nn.Cell): + """ScatterMul net definition""" + + def __init__(self, ref_shape, dtype=np.float32, use_locking=False): + super(ScatterMul, self).__init__() + self.scatter_mul = P.ScatterMul(use_locking) + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_mul(self.ref, indices, updates) + return out + + +class ScatterDiv(nn.Cell): + """ScatterDiv net definition""" + + def __init__(self, ref_shape, dtype=np.float32, use_locking=False): + super(ScatterDiv, self).__init__() + self.scatter_div = P.ScatterDiv(use_locking) + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)*10), name="ref") + + def construct(self, indices, updates): + out = self.scatter_div(self.ref, indices, updates) + return out + + class ApplyFtrlNet(nn.Cell): def __init__(self): super(ApplyFtrlNet, self).__init__() @@ -1861,6 +1900,11 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), Tensor(np.ones([2, 2, 3], np.int32))), 'skip': ['backward']}), + ('ScatterUpdate', { + 'block': ScatterUpdate((6,)), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float32))), + 'skip': ['backward']}), ('ScatterAddUseLocking', { 'block': ScatterAdd((6,), use_locking=True), 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), @@ -1902,6 +1946,73 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), Tensor(np.array([2, 3, 4], np.uint8))), 'skip': ['backward']}), + ('ScatterMulUseLocking', { + 'block': ScatterMul((6,), use_locking=True), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterMulScalar', { + 'block': ScatterMul((6,)), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterMul2d', { + 'block': ScatterMul((3, 4)), + 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), + Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], + [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), + 'skip': ['backward']}), + ('ScatterMulF16', { + 'block': ScatterMul((6,), np.float16), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float16))), + 'skip': ['backward']}), + ('ScatterMulI8', { + 'block': ScatterMul((6,), np.int8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int8))), + 'skip': ['backward']}), + ('ScatterMulI32', { + 'block': ScatterMul((6,), np.int32), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int32))), + 'skip': ['backward']}), + ('ScatterMulU8', { + 'block': ScatterMul((6,), np.uint8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.uint8))), + 'skip': ['backward']}), + ('ScatterDivUseLocking', { + 'block': ScatterDiv((6,), use_locking=True), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterDivScalar', { + 'block': ScatterDiv((6,)), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterDiv2d', { + 'block': ScatterDiv((3, 4)), + 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), + Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], + [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), + 'skip': ['backward']}), + ('ScatterDivF16', { + 'block': ScatterDiv((6,), np.float16), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float16))), + 'skip': ['backward']}), + ('ScatterDivI8', { + 'block': ScatterDiv((6,), np.int8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int8))), + 'skip': ['backward']}), + ('ScatterDivU8', { + 'block': ScatterDiv((6,), np.uint8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.uint8))), + 'skip': ['backward']}), ('ScatterSubUseLocking', { 'block': ScatterSub((6,), use_locking=True), 'desc_inputs': (Tensor(np.array([2], np.int32)),