!33367 add function & tensor op for TensorScatterAdd

Merge pull request !33367 from yuchaojie/op_dev
This commit is contained in:
i-robot 2022-05-05 01:32:34 +00:00 committed by Gitee
commit 9c860def02
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 367 additions and 256 deletions


@@ -251,6 +251,7 @@ Array Operations
mindspore.ops.select
mindspore.ops.shape
mindspore.ops.size
mindspore.ops.tensor_scatter_add
mindspore.ops.tile
mindspore.ops.transpose
mindspore.ops.unique
@@ -277,8 +278,6 @@ Array Operations
- Refer to :class:`mindspore.ops.Stack`.
* - mindspore.ops.strided_slice
- Refer to :class:`mindspore.ops.StridedSlice`.
* - mindspore.ops.tensor_scatter_add
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
* - mindspore.ops.tensor_scatter_div
- Refer to :class:`mindspore.ops.TensorScatterDiv`.
* - mindspore.ops.tensor_scatter_update


@@ -686,6 +686,24 @@ mindspore.Tensor
- **ValueError** - `axis` is out of range, or `mode` is set to a value other than 'raise', 'wrap' or 'clip'.
.. py:method:: tensor_scatter_add(indices, updates)
Creates a new tensor by adding `updates` to this Tensor at the positions given by `indices`. When multiple values are given for the same index, the result is the sum of all those values.
**Parameters:**
- **indices** (Tensor) - The index into this Tensor, with data type int32 or int64. Its rank must be at least 2.
- **updates** (Tensor) - The tensor to add to this Tensor, with the same data type as this Tensor. updates.shape should be equal to indices.shape[:-1] + self.shape[indices.shape[-1]:].
**Returns:**
Tensor, with the same shape and data type as the original Tensor.
**Raises:**
- **TypeError** - The data type of `indices` is neither int32 nor int64.
- **ValueError** - The length of this Tensor's shape is less than the last dimension of the shape of `indices`.
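The updates-shape constraint above can be checked mechanically. A minimal sketch in plain Python, using hypothetical shapes (not from this PR):

```python
# Shape rule: updates.shape == indices.shape[:-1] + self.shape[indices.shape[-1]:]
x_shape = (4, 5, 6)        # hypothetical tensor shape
indices_shape = (3, 2)     # 3 index vectors, each of depth 2
expected = indices_shape[:-1] + x_shape[indices_shape[-1]:]
print(expected)  # (3, 6): one update of shape (6,) per index vector
```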
.. py:method:: tensor_scatter_div(indices, updates)
Divides this Tensor by `updates` at the specified `indices`, and returns the result in the output Tensor.


@@ -5,22 +5,4 @@
Creates a new tensor by adding `updates` to the input Tensor at the positions given by `indices`. When multiple values are given for the same index, the result is the sum of all those values. This operation is almost equivalent to :class:`mindspore.ops.ScatterNdAdd`, except that the updated result is returned through the operator output instead of being written to `input` in place.
The last axis of `indices` is the depth of each index vector. For each index vector, there must be a corresponding value in `updates`. The shape of `updates` should be equal to the shape of `input_x[indices]`. For more details, see the use cases.
.. note::
If some values of `indices` are out of bounds, the corresponding `updates` will not be applied to `input_x`, rather than raising an index error.
**Inputs:**
- **input_x** (Tensor) - The input Tensor. The dimension of `input_x` must be no less than indices.shape[-1].
- **indices** (Tensor) - The index into the input Tensor, with data type int32 or int64. Its rank must be at least 2.
- **updates** (Tensor) - The tensor to add to `input_x`, with the same data type as the input. updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
**Outputs:**
Tensor, with the same shape and data type as `input_x`.
**Raises:**
- **TypeError** - The data type of `indices` is neither int32 nor int64.
- **ValueError** - The length of the shape of `input_x` is less than the last dimension of the shape of `indices`.
For more details, see :func:`mindspore.ops.tensor_scatter_add`.


@@ -0,0 +1,26 @@
mindspore.ops.tensor_scatter_add
================================
.. py:function:: mindspore.ops.tensor_scatter_add(input_x, indices, updates)
Creates a new tensor by adding `updates` to the input Tensor at the positions given by `indices`. When multiple values are given for the same index, the result is the sum of all those values. This operation is almost equivalent to :class:`mindspore.ops.ScatterNdAdd`, except that the updated result is returned through the operator output instead of being written to `input` in place.
The last axis of `indices` is the depth of each index vector. For each index vector, there must be a corresponding value in `updates`. The shape of `updates` should be equal to the shape of `input_x[indices]`. For more details, see the use cases.
.. note::
If some values of `indices` are out of bounds, the corresponding `updates` will not be applied to `input_x`, rather than raising an index error.
**Parameters:**
- **input_x** (Tensor) - The input Tensor. The dimension of `input_x` must be no less than indices.shape[-1].
- **indices** (Tensor) - The index into the input Tensor, with data type int32 or int64. Its rank must be at least 2.
- **updates** (Tensor) - The tensor to add to `input_x`, with the same data type as the input. updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
**Returns:**
Tensor, with the same shape and data type as `input_x`.
**Raises:**
- **TypeError** - The data type of `indices` is neither int32 nor int64.
- **ValueError** - The length of the shape of `input_x` is less than the last dimension of the shape of `indices`.
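The documented semantics (duplicate indices summed, out-of-bounds indices silently skipped) can be sketched with a NumPy stand-in; this is an illustration of the contract, not the real operator:

```python
import numpy as np

def tensor_scatter_add_ref(input_x, indices, updates):
    """NumPy stand-in for the documented tensor_scatter_add semantics."""
    out = input_x.copy()
    depth = indices.shape[-1]            # length of each index vector
    idx = indices.reshape(-1, depth)     # one row per index vector
    upd = updates.reshape((idx.shape[0],) + input_x.shape[depth:])
    for vec, val in zip(idx, upd):
        # Out-of-range indices are skipped rather than raising, per the note.
        if all(0 <= v < s for v, s in zip(vec, input_x.shape)):
            out[tuple(vec)] += val
    return out

x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], np.float32)
indices = np.array([[0, 0], [0, 0]], np.int32)
updates = np.array([1.0, 2.2], np.float32)
out = tensor_scatter_add_ref(x, indices, updates)
print(out)
# Duplicates at (0, 0) are summed: -0.1 + 1.0 + 2.2 = 3.1
```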


@@ -251,6 +251,7 @@ Array Operation
mindspore.ops.select
mindspore.ops.shape
mindspore.ops.size
mindspore.ops.tensor_scatter_add
mindspore.ops.tensor_scatter_div
mindspore.ops.tile
mindspore.ops.transpose
@@ -278,8 +279,6 @@ Array Operation
- Refer to :class:`mindspore.ops.Stack`.
* - mindspore.ops.strided_slice
- Refer to :class:`mindspore.ops.StridedSlice`.
* - mindspore.ops.tensor_scatter_add
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
* - mindspore.ops.tensor_scatter_update
- Refer to :class:`mindspore.ops.TensorScatterUpdate`.
* - mindspore.ops.tensor_slice


@@ -30,213 +30,215 @@ namespace mindspore {
namespace pipeline {
BuiltInTypeMap &GetMethodMap() {
static BuiltInTypeMap method_map = {
{kObjectTypeString,
{{"__bool__", std::string("str_bool")}, // C.str_bool
{"format", std::string("_format")}}},
{kMetaTypeNone,
{
{"__bool__", std::string("none_bool")} // C.none_bool
}},
{kObjectTypeFunction,
{{"__bool__", std::string("func_bool")}, // C.func_bool
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
{kNumberTypeBool,
{
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
{"__or__", prim::kPrimBoolOr}, // P.bool_or
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
{"__ne__", std::string("bool_ne")}, // C.bool_ne
{"__bool__", prim::kPrimIdentity} // P.identity
}},
{kNumberTypeInt,
{
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
{"__floor__", prim::kPrimIdentity}, // P.identity
{"__trunc__", prim::kPrimIdentity}, // P.identity
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
{"__bool__", std::string("int_bool")}, // C.int_bool
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
}},
{kNumberTypeUInt,
{
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
{"__floor__", prim::kPrimIdentity}, // P.identity,
{"__trunc__", prim::kPrimIdentity}, // P.identity,
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
{"__bool__", std::string("int_bool")}, // C.int_bool
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
}},
{kNumberTypeFloat,
{
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
{"__bool__", std::string("float_bool")}, // C.float_bool
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
}},
{kObjectTypeTuple,
{
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
}},
{kObjectTypeList,
{
{"__len__", prim::kPrimListLen}, // P.list_len,
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
{"__ms_next__", std::string("list_next")}, // C.list_next
{"append", std::string("list_append")}, // C.list_append
{"__bool__", std::string("list_bool")}, // C.list_bool
{"__ms_hasnext__", std::string("list_hasnext")},
{"insert", std::string("list_insert")},
}},
{kObjectTypeDictionary,
{
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"items", prim::kPrimDictItems}, // P.dict_items
{"__bool__", std::string("dict_bool")} // C.dict_bool
}},
{kObjectTypeTensorType,
{
{"all", std::string("all_")}, // C.reduce_all
{"any", std::string("any_")}, // C.reduce_any
{"__add__", std::string("add")}, // C.add
{"__sub__", std::string("sub")}, // C.sub
{"__mul__", std::string("mul")}, // C.mul
{"abs", std::string("abs_")}, // C.abs_
{"mean", std::string("mean")}, // C.mean
{"__truediv__", std::string("truediv")}, // C.truediv
{"__floordiv__", std::string("floordiv")}, // C.floordiv
{"__mod__", std::string("mod")}, // C.mod
{"__pow__", std::string("pow_")}, // C.pow
{"__floor__", std::string("array_floor")}, // C.array_floor
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
{"__pos__", std::string("array_uadd")}, // C.array_uadd
{"__neg__", std::string("array_usub")}, // C.array_usub
{"__eq__", std::string("eq")}, // C.eq
{"__ne__", std::string("ne")}, // C.ne
{"__lt__", std::string("lt")}, // C.lt
{"__gt__", std::string("gt")}, // C.gt
{"__le__", std::string("le")}, // C.le
{"__ge__", std::string("ge")}, // C.ge
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
{"view", std::string("view")}, // C.view
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", std::string("item")}, // P.item,
{"itemset", std::string("itemset")}, // P.itemset,
{"transpose", std::string("transpose")}, // P.transpose
{"flatten", std::string("flatten")}, // P.reshape(,-1)
{"reshape", std::string("reshape")}, // P.reshape()
{"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
{"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
{"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
{"swapaxes", std::string("swapaxes")}, // P.transpose()
{"narrow", std::string("narrow")}, // narrow()
{"masked_fill", std::string("masked_fill")}, // masked_fill()
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"astype", std::string("astype")}, // P.cast()
{"cumsum", std::string("cumsum")}, // P.cumsum()
{"copy", std::string("copy")}, // copy()
{"max", std::string("max")}, // P.reduce_max()
{"min", std::string("min")}, // P.reduce_min()
{"fill", std::string("fill")}, // P.fill()
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
{"clip", std::string("clip")}, // P.maximum(P.minimum)
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
{"argmax", std::string("argmax")}, // P.Argmax()
{"argmin", std::string("argmin")}, // P.Argmin()
{"resize", std::string("resize")}, // P.Reshape()
{"choose", std::string("choose")}, // P.Select()
{"diagonal", std::string("diagonal")}, // P.Eye()
{"searchsorted", std::string("searchsorted")}, // P.Select()
{"take", std::string("take")}, // P.GatherNd()
{"tensor_scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd()
{"trace", std::string("trace")}, // P.Eye()
{"var", std::string("var")}, // P.ReduceSum
{"std", std::string("std")}, // P.ReduceSum
{"sum", std::string("sum")}, // P.ReduceSum
{"repeat", std::string("repeat")}, // C.repeat_elements
}},
{kObjectTypeRowTensorType,
{
{"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
}},
{kObjectTypeCSRTensorType,
{
{"astype", std::string("csr_astype")}, // C.csr_astype
{"abs", std::string("csr_abs")}, // C.csr_abs
{"sum", std::string("csr_sum")}, // C.csr_sum
{"mv", std::string("csr_mv")}, // C.csr_mv
{"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
}},
{kObjectTypeCOOTensorType,
{
{"astype", std::string("coo_astype")}, // C.coo_astype
{"abs", std::string("coo_abs")}, // C.coo_abs
{"to_tuple", std::string("coo_to_tuple")}, // C.coo_to_tuple
{"to_csr", std::string("coo_to_csr")}, // C.coo_to_csr
{"to_dense", std::string("coo_to_dense")}, // C.coo_to_dense
}},
{kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}},
{kObjectTypeEnvType, {}}};
return method_map;
}


@@ -1560,6 +1560,15 @@ def while_cond(x):
return x
def tensor_scatter_add(x, indices, updates):
"""
Creates a new tensor by adding the values from the positions in `x` indicated by
`indices`, with values from `updates`. When multiple values are given for the same
index, the updated result will be the sum of all values.
"""
return F.tensor_scatter_add(x, indices, updates)
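A typical use of this summation behavior is accumulating several updates aimed at the same row, for example per-token contributions into a fresh copy of a table. A hypothetical sketch with a NumPy stand-in (the table, token ids, and shapes are made up for illustration; this is not the MindSpore operator itself):

```python
import numpy as np

# Accumulate one row of ones per token into a copy of a 4x3 table.
table = np.zeros((4, 3), np.float32)
token_ids = np.array([[1], [1], [3]], np.int32)  # rank-2 indices, depth 1
values = np.ones((3, 3), np.float32)             # one row per index vector

out = table.copy()
for vec, v in zip(token_ids, values):
    out[tuple(vec)] += v   # repeated token 1 accumulates both rows

print(out[1])  # [2. 2. 2.]
```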
def coo_to_csr(x):
"""convert coo to csr."""
row_indices = x.indices[:, 0]


@@ -1415,6 +1415,52 @@ class Tensor(Tensor_):
return reduce_(self, reduce_min(keepdims), cmp_fn=minimum(), axis=axis, keepdims=keepdims,
initial=initial, where=where)
def tensor_scatter_add(self, indices, updates):
"""
Creates a new tensor by adding the values from the positions in self tensor indicated by
`indices`, with values from `updates`. When multiple values are given for the same
index, the updated result will be the sum of all values. This operation is almost
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
instead of input `Parameter`.
The last axis of `indices` is the depth of each index vector. For each index vector,
there must be a corresponding value in `updates`. The shape of `updates` should be
equal to the shape of `self[indices]`. For more details, see use cases.
Note:
If some values of `indices` are out of bounds, the corresponding `updates` will not
be applied to this tensor, rather than raising an index error.
Args:
indices (Tensor): The index of input tensor whose data type is int32 or int64.
The rank must be at least 2.
updates (Tensor): The tensor to add to this tensor, with the same data type as this tensor.
updates.shape should be equal to indices.shape[:-1] + self.shape[indices.shape[-1]:].
Returns:
Tensor, has the same shape and type as self tensor.
Raises:
TypeError: If dtype of `indices` is neither int32 nor int64.
ValueError: If length of shape of self tensor is less than the last dimension of shape of `indices`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype('float32'))
>>> indices = Tensor(np.array([[0, 0], [0, 0]]).astype('int32'))
>>> updates = Tensor(np.array([1.0, 2.2]).astype('float32'))
>>> output = x.tensor_scatter_add(indices, updates)
>>> print(output)
[[ 3.1 0.3 3.6]
[ 0.4 0.5 -3.2]]
"""
self._init_check()
return tensor_operator_registry.get('tensor_scatter_add')()(self, indices, updates)
def fill(self, value):
"""
Fill the tensor with a scalar value.


@@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
scatter_max, scatter_min)
tensor_scatter_add, scatter_max, scatter_min)
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,


@@ -948,6 +948,57 @@ def gather_nd(input_x, indices):
return gather_nd_(input_x, indices)
tensor_scatter_add_ = P.TensorScatterAdd()
def tensor_scatter_add(input_x, indices, updates):
"""
Creates a new tensor by adding the values from the positions in `input_x` indicated by
`indices`, with values from `updates`. When multiple values are given for the same
index, the updated result will be the sum of all values. This operation is almost
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
instead of input `Parameter`.
The last axis of `indices` is the depth of each index vector. For each index vector,
there must be a corresponding value in `updates`. The shape of `updates` should be
equal to the shape of `input_x[indices]`. For more details, see use cases.
Note:
If some values of `indices` are out of bounds, the corresponding `updates` will not
be applied to `input_x`, rather than raising an index error.
Args:
input_x (Tensor): The target tensor. The dimension of input_x must be no less than indices.shape[-1].
indices (Tensor): The index of the input tensor, with data type int32 or int64.
The rank must be at least 2.
updates (Tensor): The tensor to add to the input tensor, with the same data type as the input.
updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
Returns:
Tensor, has the same shape and type as `input_x`.
Raises:
TypeError: If dtype of `indices` is neither int32 nor int64.
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> from mindspore import ops
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
>>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> output = ops.tensor_scatter_add(input_x, indices, updates)
>>> print(output)
[[ 3.1 0.3 3.6]
[ 0.4 0.5 -3.2]]
"""
return tensor_scatter_add_(input_x, indices, updates)
##############################
# Type Conversion Functions.
##############################
@@ -1132,6 +1183,7 @@ __all__ = [
'expand_dims',
'transpose',
'scatter_nd',
'tensor_scatter_add',
'gather',
'gather_d',
'gather_nd',

View File

@@ -71,7 +71,6 @@ scatter_update = P.ScatterUpdate()
tensor_scatter_update = P.TensorScatterUpdate()
tensor_scatter_min = P.TensorScatterMin()
tensor_scatter_max = P.TensorScatterMax()
tensor_scatter_add = P.TensorScatterAdd()
tensor_scatter_sub = P.TensorScatterSub()
tensor_scatter_mul = P.TensorScatterMul()
tensor_scatter_div = P.TensorScatterDiv()
@@ -900,7 +899,6 @@ logical_or = P.LogicalOr()
logical_not = P.LogicalNot()
cumsum = P.CumSum()
cumprod = P.CumProd()
tensor_scatter_add = P.TensorScatterAdd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
@@ -1022,6 +1020,6 @@ tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
tensor_operator_registry.register('tensor_scatter_add', tensor_scatter_add)
tensor_operator_registry.register('tensor_scatter_add', P.TensorScatterAdd)
__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')


@@ -6951,30 +6951,10 @@ class TensorScatterAdd(_TensorScatterOp):
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
instead of input `Parameter`.
The last axis of `indices` is the depth of each index vector. For each index vector,
there must be a corresponding value in `updates`. The shape of `updates` should be
equal to the shape of `input_x[indices]`. For more details, see use cases.
Note:
If some values of the `indices` are out of bound, instead of raising an index error,
the corresponding `updates` will not be updated to `input_x`.
Inputs:
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
The rank must be at least 2.
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Raises:
TypeError: If dtype of `indices` is neither int32 nor int64.
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
Refer to :func:`mindspore.ops.tensor_scatter_add` for more detail.
Supported Platforms:
``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)