add op scatter add vm

This commit is contained in:
zhaozhenlong 2020-05-25 17:29:44 +08:00
parent 10076ffe1a
commit 55d1927534
5 changed files with 119 additions and 7 deletions

View File

@ -198,3 +198,4 @@ from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_add_op_info = TBERegOp("ScatterAdd") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_add.so") \
.compute_cost(10) \
.kernel_name("scatter_add") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_add_op_info)
def _scatter_add_tbe():
"""ScatterAdd TBE register"""
return

View File

@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
SameTypeShape, ScatterMax, ScatterUpdate,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile,
@ -190,6 +190,7 @@ __all__ = [
'BoundingBoxEncode',
'BoundingBoxDecode',
'L2Normalize',
'ScatterAdd',
'ScatterNd',
'ScatterMax',
'ResizeNearestNeighbor',

View File

@ -2145,6 +2145,12 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
class ScatterMax(PrimitiveWithInfer):
"""
@ -2158,8 +2164,8 @@ class ScatterMax(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do max operation whose data type should be int.
- **updates** (Tensor) - The tensor doing the maximum operation with 'input_x',
the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'.
- **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
Outputs:
Tensor, has the same shape and data type as `input_x`.
@ -2180,10 +2186,7 @@ class ScatterMax(PrimitiveWithInfer):
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of update should be [] or "
f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, update_shape: {updates_shape}.")
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
@ -2193,6 +2196,49 @@ class ScatterMax(PrimitiveWithInfer):
return x_dtype
class ScatterAdd(PrimitiveWithInfer):
"""
Update the value of the input tensor through the add operation.
Using given values to update tensor value through the add operation, along with the input indices.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: False.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be int.
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
Outputs:
Tensor, has the same shape and data type as `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
>>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> scatter_add = P.ScatterAdd()
>>> output = scatter_add(input_x, indices, updates)
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Init ScatterAdd"""
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
args = {'x': x_dtype, 'updates': updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
class SpaceToDepth(PrimitiveWithInfer):
r"""
Rearrange blocks of spatial data into depth.

View File

@ -196,6 +196,19 @@ class ScatterMax(nn.Cell):
return out
class ScatterAdd(nn.Cell):
"""ScatterAdd net definition"""
def __init__(self, ref_shape):
super(ScatterAdd, self).__init__()
self.scatter_add = P.ScatterAdd()
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")
def construct(self, indices, updates):
out = self.scatter_add(self.ref, indices, updates)
return out
class ApplyFtrlNet(nn.Cell):
def __init__(self):
super(ApplyFtrlNet, self).__init__()
@ -1257,6 +1270,17 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float32) * 99)),
'skip': ['backward']}),
('ScatterAdd', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}),
('ScatterAdd2d', {
'block': ScatterAdd((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}),
('SmoothL1Loss', {
'block': P.SmoothL1Loss(),
'desc_inputs': [[256, 4], [256, 4]],