forked from mindspore-Ecosystem/mindspore
add op scatter add vm
This commit is contained in:
parent
10076ffe1a
commit
55d1927534
|
@ -198,3 +198,4 @@ from .apply_rms_prop import _apply_rms_prop_tbe
|
|||
from .cumprod import _cumprop_tbe
|
||||
from .reduce_prod import _reduce_prod_tbe
|
||||
from .flatten_grad import _flatten_grad_tbe
|
||||
from .scatter_add import _scatter_add_tbe
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ScatterAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
scatter_add_op_info = TBERegOp("ScatterAdd") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("scatter_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("scatter_add") \
|
||||
.partial_flag(True) \
|
||||
.attr("use_locking", "optional", "bool", "all") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(scatter_add_op_info)
|
||||
def _scatter_add_tbe():
|
||||
"""ScatterAdd TBE register"""
|
||||
return
|
|
@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Fill, GatherNd, GatherV2, InvertPermutation,
|
||||
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
|
||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
|
||||
SameTypeShape, ScatterMax, ScatterUpdate,
|
||||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split,
|
||||
Squeeze, StridedSlice, Tile,
|
||||
|
@ -190,6 +190,7 @@ __all__ = [
|
|||
'BoundingBoxEncode',
|
||||
'BoundingBoxDecode',
|
||||
'L2Normalize',
|
||||
'ScatterAdd',
|
||||
'ScatterNd',
|
||||
'ScatterMax',
|
||||
'ResizeNearestNeighbor',
|
||||
|
|
|
@ -2145,6 +2145,12 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
|
||||
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
||||
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
|
||||
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
||||
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
||||
|
||||
|
||||
class ScatterMax(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -2158,8 +2164,8 @@ class ScatterMax(PrimitiveWithInfer):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target parameter.
|
||||
- **indices** (Tensor) - The index to do max operation whose data type should be int.
|
||||
- **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
|
||||
the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.
|
||||
- **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and data type as `input_x`.
|
||||
|
@ -2180,10 +2186,7 @@ class ScatterMax(PrimitiveWithInfer):
|
|||
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
||||
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
||||
raise ValueError(f"For '{self.name}', the shape of update should be [] or "
|
||||
f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
||||
f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
|
||||
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
|
@ -2193,6 +2196,49 @@ class ScatterMax(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class ScatterAdd(PrimitiveWithInfer):
|
||||
"""
|
||||
Update the value of the input tensor through the add operation.
|
||||
|
||||
Using given values to update tensor value through the add operation, along with the input indices.
|
||||
|
||||
Args:
|
||||
use_locking (bool): Whether protect the assignment by a lock. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Parameter) - The target parameter.
|
||||
- **indices** (Tensor) - The index to do add operation whose data type should be int.
|
||||
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and data type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
|
||||
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
|
||||
>>> scatter_add = P.ScatterAdd()
|
||||
>>> output = scatter_add(input_x, indices, updates)
|
||||
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Init ScatterAdd"""
|
||||
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, updates_shape):
|
||||
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
|
||||
args = {'x': x_dtype, 'updates': updates_dtype}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SpaceToDepth(PrimitiveWithInfer):
|
||||
r"""
|
||||
Rearrange blocks of spatial data into depth.
|
||||
|
|
|
@ -196,6 +196,19 @@ class ScatterMax(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class ScatterAdd(nn.Cell):
|
||||
"""ScatterAdd net definition"""
|
||||
|
||||
def __init__(self, ref_shape):
|
||||
super(ScatterAdd, self).__init__()
|
||||
self.scatter_add = P.ScatterAdd()
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_add(self.ref, indices, updates)
|
||||
return out
|
||||
|
||||
|
||||
class ApplyFtrlNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ApplyFtrlNet, self).__init__()
|
||||
|
@ -1257,6 +1270,17 @@ test_case_other_ops = [
|
|||
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
|
||||
Tensor(np.ones([2, 2, 3], np.float32) * 99)),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAdd', {
|
||||
'block': ScatterAdd((6,)),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAdd2d', {
|
||||
'block': ScatterAdd((3, 4)),
|
||||
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
|
||||
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
|
||||
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('SmoothL1Loss', {
|
||||
'block': P.SmoothL1Loss(),
|
||||
'desc_inputs': [[256, 4], [256, 4]],
|
||||
|
|
Loading…
Reference in New Issue