forked from OSSInnovation/mindspore
!3175 add ScatterNdAdd ScatterNdSub ScatterNonAliasingAdd ops
Merge pull request !3175 from fangzehua/scatter_add_vm
This commit is contained in:
commit
37cc6e2628
|
@ -37,6 +37,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"re_lu6", "relu6"},
|
||||
{"re_lu6_grad", "relu6_grad"},
|
||||
{"re_lu", "relu"},
|
||||
{"reverse_v2", "reverse_v2_d"},
|
||||
{"re_luv2", "relu_v2"},
|
||||
{"p_re_lu", "prelu"},
|
||||
{"p_re_lu_grad", "prelu_grad"},
|
||||
|
|
|
@ -377,6 +377,18 @@ def get_bprop_pack(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ReverseV2)
|
||||
def get_bprop_reverse_v2(self):
|
||||
"""Generate bprop for ReverseV2"""
|
||||
axis = self.axis
|
||||
|
||||
def bprop(x, out, dout):
|
||||
reverse_grad = P.ReverseV2(axis)
|
||||
dx = reverse_grad(dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
@bprop_getters.register(P.Unpack)
|
||||
def get_bprop_unpack(self):
|
||||
"""Generate bprop for Unpack"""
|
||||
|
@ -495,6 +507,16 @@ def get_bprop_scatter_nd_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ScatterNonAliasingAdd)
|
||||
def get_bprop_scatter_non_aliasing_add_update(self):
|
||||
"""Generate bprop for ScatterNonAliasingAdd"""
|
||||
op = P.GatherNd()
|
||||
|
||||
def bprop(x, indices, update, out, dout):
|
||||
return dout, zeros_like(indices), op(dout, indices)
|
||||
|
||||
return bprop
|
||||
|
||||
@bprop_getters.register(P.TensorScatterUpdate)
|
||||
def get_bprop_tensor_scatter_update(self):
|
||||
"""Generate bprop for TensorScatterUpdate"""
|
||||
|
@ -509,6 +531,7 @@ def get_bprop_tensor_scatter_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
|
||||
@bprop_getters.register(P.ScatterMax)
|
||||
def get_bprop_scatter_max(self):
|
||||
"""Generate bprop for ScatterMax"""
|
||||
|
|
|
@ -81,6 +81,9 @@ from .sub import _sub_tbe
|
|||
from .reduce_mean_d import _reduce_mean_d_tbe
|
||||
from .scatter_nd import _scatter_nd_tbe
|
||||
from .scatter_nd_d import _scatter_nd_d_tbe
|
||||
from .scatter_nd_add import _scatter_nd_add_tbe
|
||||
from .scatter_nd_sub import _scatter_nd_sub_tbe
|
||||
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
|
||||
from .reduce_mean import _reduce_mean_tbe
|
||||
from .tile import _tile_tbe
|
||||
from .atomic_addr_clean import _atomic_addr_clean_tbe
|
||||
|
@ -93,6 +96,8 @@ from .bn_training_update_grad import _bn_training_update_grad_tbe
|
|||
from .bn_infer import _bn_infer_tbe
|
||||
from .bn_infer_grad import _bn_infer_grad_tbe
|
||||
from .reciprocal import _reciprocal_tbe
|
||||
from .reverse_v2_d import _reverse_v2_d_tbe
|
||||
from .rint import _rint_tbe
|
||||
from .strided_slice_d import _strided_slice_d_tbe
|
||||
from .strided_slice_grad_d import _strided_slice_grad_d_tbe
|
||||
from .split_d import _split_d_tbe
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReverseV2D op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reverse_v2_d_op_info = TBERegOp("ReverseV2") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reverse_v2_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reverse_v2_d") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.attr("axis", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reverse_v2_d_op_info)
|
||||
def _reverse_v2_d_tbe():
|
||||
"""ReverseV2D TBE register"""
|
||||
return
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Rint op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
rint_op_info = TBERegOp("Rint") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("rint.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("rint") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(rint_op_info)
|
||||
def _rint_tbe():
|
||||
"""Rint TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterNdAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_nd_add_op_info = TBERegOp("ScatterNdAdd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_nd_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_nd_add") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_locking", "optional", "bool", "all") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_nd_add_op_info)
|
||||
def _scatter_nd_add_tbe():
|
||||
"""ScatterNdAdd TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterNdSub op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_nd_sub_op_info = TBERegOp("ScatterNdSub") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_nd_sub.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_nd_sub") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_locking", "optional", "bool", "all") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_nd_sub_op_info)
|
||||
def _scatter_nd_sub_tbe():
|
||||
"""ScatterNdSub TBE register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterNonAliasingAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_non_aliasing_add_op_info = TBERegOp("ScatterNonAliasingAdd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_non_aliasing_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_non_aliasing_add") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_non_aliasing_add_op_info)
|
||||
def _scatter_non_aliasing_add_tbe():
|
||||
"""ScatterNonAliasingAdd TBE register"""
|
||||
return
|
|
@ -28,6 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split, TransShape, ParallelConcat,
|
||||
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
|
@ -233,6 +234,11 @@ __all__ = [
|
|||
'ScatterNd',
|
||||
'ScatterMax',
|
||||
'ScatterMin',
|
||||
'ScatterNdAdd',
|
||||
'ScatterNdSub',
|
||||
'ScatterNonAliasingAdd',
|
||||
'ReverseV2',
|
||||
'Rint',
|
||||
'ResizeNearestNeighbor',
|
||||
'HistogramFixedWidth',
|
||||
'Pad',
|
||||
|
|
|
@ -47,8 +47,8 @@ class _ScatterOp(PrimitiveWithInfer):
|
|||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
@staticmethod
|
||||
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
|
||||
|
||||
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
|
||||
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
||||
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
|
||||
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
||||
|
@ -61,7 +61,7 @@ class _ScatterOp(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
||||
_ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
||||
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
|
@ -71,6 +71,19 @@ class _ScatterOp(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class _ScatterNdOp(_ScatterOp):
|
||||
"""
|
||||
Define _ScatterNd operators
|
||||
"""
|
||||
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
|
||||
validator.check('the dimension of x', len(x_shape),
|
||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
|
||||
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = "
|
||||
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
|
||||
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
||||
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||
validator.check_value_type('axis', axis, [int, tuple], prim_name)
|
||||
|
@ -1759,6 +1772,75 @@ class Slice(PrimitiveWithInfer):
|
|||
'value': None}
|
||||
|
||||
|
||||
class ReverseV2(PrimitiveWithInfer):
|
||||
"""
|
||||
Reverse specific dimensions of a tensor.
|
||||
|
||||
Args:
|
||||
axis (Union[tuple(int), list(int)): The indices of the dimensions to reverse.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
|
||||
>>> op = P.ReverseV2(axis=[1])
|
||||
>>> output = op(input_x)
|
||||
[[4, 3, 2, 1], [8, 7, 6, 5]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis):
|
||||
validator.check_value_type('axis', axis, [list, tuple], self.name)
|
||||
for i, each in enumerate(axis):
|
||||
validator.check_value_type(f'axis[{i}]', each, [int], self.name)
|
||||
self.axis = axis
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
dim = len(x_shape)
|
||||
for i, each in enumerate(self.axis):
|
||||
validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Rint(PrimitiveWithInfer):
|
||||
"""
|
||||
Return element-wise integer closest to x.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor, which must be one of the following types:
|
||||
float16, float32.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
|
||||
>>> op = P.Rint()
|
||||
>>> output = op(input_x)
|
||||
[-2., 0., 2., 2.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class Select(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
||||
|
@ -2404,7 +2486,7 @@ class ScatterUpdate(_ScatterOp):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class ScatterNdUpdate(PrimitiveWithInfer):
|
||||
class ScatterNdUpdate(_ScatterNdOp):
|
||||
"""
|
||||
Update tensor value by using input indices and value.
|
||||
|
||||
|
@ -2429,11 +2511,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
>>> op = P.ScatterNdUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
|
@ -2441,13 +2519,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
validator.check_value_type('use_locking', use_locking, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
validator.check('the dimension of x', len(x_shape),
|
||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
|
@ -2635,6 +2706,101 @@ class ScatterDiv(_ScatterOp):
|
|||
"""
|
||||
|
||||
|
||||
class ScatterNdAdd(_ScatterNdOp):
|
||||
"""
|
||||
Applies sparse addition to individual values or slices in a Tensor.
|
||||
|
||||
Using given values to update tensor value through the add operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
Args:
|
||||
use_locking (bool): Whether protect the assignment by a lock. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target parameter.
|
||||
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
|
||||
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
|
||||
|
||||
Outputs:
|
||||
Parameter, the updated `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
|
||||
>>> scatter_nd_add = P.ScatterNdAdd()
|
||||
>>> output = scatter_nd_add(input_x, indices, updates)
|
||||
[1, 10, 9, 4, 12, 6, 7, 17]
|
||||
"""
|
||||
|
||||
|
||||
class ScatterNdSub(_ScatterNdOp):
|
||||
"""
|
||||
Applies sparse subtraction to individual values or slices in a Tensor.
|
||||
|
||||
Using given values to update tensor value through the sub operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
Args:
|
||||
use_locking (bool): Whether protect the assignment by a lock. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target parameter.
|
||||
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
|
||||
- **updates** (Tensor) - The tensor doing the sub operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
|
||||
|
||||
Outputs:
|
||||
Parameter, the updated `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
|
||||
>>> scatter_nd_sub = P.ScatterNdSub()
|
||||
>>> output = scatter_nd_sub(input_x, indices, updates)
|
||||
[1, -6, -3, 4, -2, 6, 7, -1]
|
||||
"""
|
||||
|
||||
|
||||
class ScatterNonAliasingAdd(_ScatterNdOp):
|
||||
"""
|
||||
Applies sparse addition to input using individual values or slices.
|
||||
|
||||
Using given values to update tensor value through the add operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target parameter.
|
||||
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
|
||||
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
|
||||
|
||||
Outputs:
|
||||
Parameter, the updated `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
|
||||
>>> scatter_non_aliasing_add = P.ScatterNonAliasingAdd()
|
||||
>>> output = scatter_non_aliasing_add(input_x, indices, updates)
|
||||
[1, 10, 9, 4, 12, 6, 7, 17]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Init ScatterNonAliasingAdd"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "updates": updates_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SpaceToDepth(PrimitiveWithInfer):
|
||||
r"""
|
||||
Rearrange blocks of spatial data into depth.
|
||||
|
|
|
@ -237,6 +237,44 @@ class ScatterAdd(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class ScatterNonAliasingAdd(nn.Cell):
|
||||
"""ScatterNonAliasingAdd net definition"""
|
||||
|
||||
def __init__(self, ref_shape, dtype=np.float32):
|
||||
super(ScatterNonAliasingAdd, self).__init__()
|
||||
self.scatter_no_aliasing_add = P.ScatterNonAliasingAdd()
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_no_aliasing_add(self.ref, indices, updates)
|
||||
return out
|
||||
|
||||
|
||||
class ScatterNdSub(nn.Cell):
|
||||
"""ScatterNdSub net definition"""
|
||||
|
||||
def __init__(self, ref_shape, dtype=np.float32):
|
||||
super(ScatterNdSub, self).__init__()
|
||||
self.scatter_nd_sub = P.ScatterNdSub()
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_nd_sub(self.ref, indices, updates)
|
||||
return out
|
||||
|
||||
class ScatterNdAdd(nn.Cell):
|
||||
"""ScatterNdAdd net definition"""
|
||||
|
||||
def __init__(self, ref_shape, dtype=np.float32):
|
||||
super(ScatterNdAdd, self).__init__()
|
||||
self.scatter_nd_add = P.ScatterNdAdd()
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_nd_add(self.ref, indices, updates)
|
||||
return out
|
||||
|
||||
|
||||
class ScatterSub(nn.Cell):
|
||||
"""ScatterSub net definition"""
|
||||
|
||||
|
@ -1811,6 +1849,14 @@ test_case_array_ops = [
|
|||
'desc_const': [(2, 1, 1, 2)],
|
||||
'desc_inputs': [[2, 2, 2]],
|
||||
'desc_bprop': [[2, 2, 2, 4]]}),
|
||||
('ReverseV2', {
|
||||
'block': P.ReverseV2(axis=[1]),
|
||||
'desc_inputs': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))],
|
||||
'desc_bprop': [(Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)))]}),
|
||||
('Rint', {
|
||||
'block': P.Rint(),
|
||||
'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))],
|
||||
'skip': ['backward']}),
|
||||
('ConcatV2_0', {
|
||||
'block': P.Concat(),
|
||||
'desc_inputs': [
|
||||
|
@ -2074,6 +2120,21 @@ test_case_other_ops = [
|
|||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterNonAliasingAdd_1d', {
|
||||
'block': ScatterNonAliasingAdd((8,)),
|
||||
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterNdAdd', {
|
||||
'block': ScatterNdAdd((8,)),
|
||||
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterNdSub', {
|
||||
'block': ScatterNdAdd((8,)),
|
||||
'desc_inputs': (Tensor(np.array([[2], [3], [4], [5]], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0, 8.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAdd', {
|
||||
'block': ScatterAdd((6,)),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
|
|
Loading…
Reference in New Issue