!33367 add function & tensor op for TensorScatterAdd
Merge pull request !33367 from yuchaojie/op_dev
This commit is contained in:
commit
9c860def02
|
@ -251,6 +251,7 @@ Array操作
|
|||
mindspore.ops.select
|
||||
mindspore.ops.shape
|
||||
mindspore.ops.size
|
||||
mindspore.ops.tensor_scatter_add
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unique
|
||||
|
@ -277,8 +278,6 @@ Array操作
|
|||
- Refer to :class:`mindspore.ops.Stack`.
|
||||
* - mindspore.ops.strided_slice
|
||||
- Refer to :class:`mindspore.ops.StridedSlice`.
|
||||
* - mindspore.ops.tensor_scatter_add
|
||||
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
|
||||
* - mindspore.ops.tensor_scatter_div
|
||||
- Refer to :class:`mindspore.ops.TensorScatterDiv`.
|
||||
* - mindspore.ops.tensor_scatter_update
|
||||
|
|
|
@ -686,6 +686,24 @@ mindspore.Tensor
|
|||
|
||||
- **ValueError** - `axis` 超出范围,或 `mode` 被设置为'raise'、'wrap'和'clip'以外的值。
|
||||
|
||||
.. py:method:: tensor_scatter_add(indices, updates)
|
||||
|
||||
根据指定的更新值和输入索引,通过相加运算更新本Tensor的值。当同一索引有不同值时,更新的结果将是所有值的总和。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **indices** (Tensor) - Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
- **updates** (Tensor) - 指定与本Tensor相加操作的Tensor,其数据类型与该Tensor相同。updates.shape应等于indices.shape[:-1] + self.shape[indices.shape[-1]:]。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape和数据类型与原Tensor相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - Tensor的shape长度小于 `indices` 的shape的最后一个维度。
|
||||
|
||||
.. py:method:: tensor_scatter_div(indices, updates)
|
||||
|
||||
根据指定的索引, 通过除法进行计算, 将输出赋值到输出Tensor中。
|
||||
|
|
|
@ -5,22 +5,4 @@
|
|||
|
||||
根据指定的更新值和输入索引,通过相加运算更新输入Tensor的值。当同一索引有不同值时,更新的结果将是所有值的总和。此操作与 :class:`mindspore.ops.ScatterNdAdd` 类似,只是更新后的结果是通过算子output返回,而不是直接原地更新input。
|
||||
|
||||
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息,请参见使用用例。
|
||||
|
||||
.. note::
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
|
||||
|
||||
**输入:**
|
||||
|
||||
- **input_x** (Tensor) - 输入Tensor。 `input_x` 的维度必须不小于indices.shape[-1]。
|
||||
- **indices** (Tensor) - 输入Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
|
||||
|
||||
**输出:**
|
||||
|
||||
Tensor,shape和数据类型与输入 `input_x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。
|
||||
更多参考详见 :func:`mindspore.ops.tensor_scatter_add`。
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
mindspore.ops.tensor_scatter_add
|
||||
================================
|
||||
|
||||
.. py:function:: mindspore.ops.tensor_scatter_add(input_x, indices, updates)
|
||||
|
||||
根据指定的更新值和输入索引,通过相加运算更新输入Tensor的值。当同一索引有不同值时,更新的结果将是所有值的总和。此操作与 :class:`mindspore.ops.ScatterNdAdd` 类似,只是更新后的结果是通过算子output返回,而不是直接原地更新input。
|
||||
|
||||
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息,请参见使用用例。
|
||||
|
||||
.. note::
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **input_x** (Tensor) - 输入Tensor。 `input_x` 的维度必须不小于indices.shape[-1]。
|
||||
- **indices** (Tensor) - 输入Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,其数据类型与输入相同。updates.shape应等于indices.shape[:-1] + input_x.shape[indices.shape[-1]:]。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape和数据类型与输入 `input_x` 相同。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `indices` 的数据类型既不是int32,也不是int64。
|
||||
- **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。
|
|
@ -251,6 +251,7 @@ Array Operation
|
|||
mindspore.ops.select
|
||||
mindspore.ops.shape
|
||||
mindspore.ops.size
|
||||
mindspore.ops.tensor_scatter_add
|
||||
mindspore.ops.tensor_scatter_div
|
||||
mindspore.ops.tile
|
||||
mindspore.ops.transpose
|
||||
|
@ -278,8 +279,6 @@ Array Operation
|
|||
- Refer to :class:`mindspore.ops.Stack`.
|
||||
* - mindspore.ops.strided_slice
|
||||
- Refer to :class:`mindspore.ops.StridedSlice`.
|
||||
* - mindspore.ops.tensor_scatter_add
|
||||
- Refer to :class:`mindspore.ops.TensorScatterAdd`.
|
||||
* - mindspore.ops.tensor_scatter_update
|
||||
- Refer to :class:`mindspore.ops.TensorScatterUpdate`.
|
||||
* - mindspore.ops.tensor_slice
|
||||
|
|
|
@ -30,213 +30,215 @@ namespace mindspore {
|
|||
namespace pipeline {
|
||||
|
||||
BuiltInTypeMap &GetMethodMap() {
|
||||
static BuiltInTypeMap method_map = {{kObjectTypeString,
|
||||
{{"__bool__", std::string("str_bool")}, // C.str_bool
|
||||
{"format", std::string("_format")}}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kObjectTypeFunction,
|
||||
{{"__bool__", std::string("func_bool")}, // C.str_bool
|
||||
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
{"insert", std::string("list_insert")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"items", prim::kPrimDictItems}, // P.dict_items
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"all", std::string("all_")}, // C.reduce_all
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"abs", std::string("abs_")}, // C.abs_
|
||||
{"mean", std::string("mean")}, // C.mean
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
|
||||
{"view", std::string("view")}, // C.view
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", std::string("item")}, // P.item,
|
||||
{"itemset", std::string("itemset")}, // P.itemset,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"flatten", std::string("flatten")}, // P.reshape(,-1)
|
||||
{"reshape", std::string("reshape")}, // P.reshape()
|
||||
{"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
|
||||
{"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
|
||||
{"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()
|
||||
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
|
||||
{"swapaxes", std::string("swapaxes")}, // P.transpose()
|
||||
{"narrow", std::string("narrow")}, // narrow()
|
||||
{"masked_fill", std::string("masked_fill")}, // masked_fill()
|
||||
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
{"cumsum", std::string("cumsum")}, // P.cumsum()
|
||||
{"copy", std::string("copy")}, // copy()
|
||||
{"max", std::string("max")}, // P.reduce_max()
|
||||
{"min", std::string("min")}, // P.reduce_min()
|
||||
{"fill", std::string("fill")}, // P.fill()
|
||||
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
|
||||
{"clip", std::string("clip")}, // P.maximum(P.minimum)
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"argmax", std::string("argmax")}, // P.Argmax()
|
||||
{"argmin", std::string("argmin")}, // P.Argmax()
|
||||
{"resize", std::string("resize")}, // P.Reshape()
|
||||
{"choose", std::string("choose")}, // P.Select()
|
||||
{"diagonal", std::string("diagonal")}, // P.Eye()
|
||||
{"searchsorted", std::string("searchsorted")}, // P.Select()
|
||||
{"take", std::string("take")}, // P.GatherNd()
|
||||
{"trace", std::string("trace")}, // P.Eye()
|
||||
{"var", std::string("var")}, // P.ReduceSum
|
||||
{"std", std::string("std")}, // P.ReduceSum
|
||||
{"sum", std::string("sum")}, // P.ReduceSum
|
||||
{"repeat", std::string("repeat")}, // C.repeat_elements
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
{"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
|
||||
}},
|
||||
{kObjectTypeCSRTensorType,
|
||||
{
|
||||
{"astype", std::string("csr_astype")}, // C.csr_astype
|
||||
{"abs", std::string("csr_abs")}, // C.csr_abs
|
||||
{"sum", std::string("csr_sum")}, // C.csr_sum
|
||||
{"mv", std::string("csr_mv")}, // C.csr_mv
|
||||
{"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
|
||||
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
|
||||
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
|
||||
}},
|
||||
{kObjectTypeCOOTensorType,
|
||||
{
|
||||
{"astype", std::string("coo_astype")}, // C.coo_astype
|
||||
{"abs", std::string("coo_abs")}, // C.coo_abs
|
||||
{"to_tuple", std::string("coo_to_tuple")}, // C.coo_to_tuple
|
||||
{"to_csr", std::string("coo_to_csr")}, // C.coo_to_csr
|
||||
{"to_dense", std::string("coo_to_dense")}, // C.coo_to_dense
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
static BuiltInTypeMap method_map = {
|
||||
{kObjectTypeString,
|
||||
{{"__bool__", std::string("str_bool")}, // C.str_bool
|
||||
{"format", std::string("_format")}}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kObjectTypeFunction,
|
||||
{{"__bool__", std::string("func_bool")}, // C.str_bool
|
||||
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
{"__or__", prim::kPrimBoolOr}, // P.bool_or
|
||||
{"__eq__", prim::kPrimBoolEq}, // P.bool_eq
|
||||
{"__ne__", std::string("bool_ne")}, // C.bool_ne
|
||||
{"__bool__", prim::kPrimIdentity} // P.identity
|
||||
}},
|
||||
{kNumberTypeInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
|
||||
{"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
|
||||
}},
|
||||
{kNumberTypeUInt,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__truediv__", std::string("int_truediv")}, // C.int_truediv
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__trunc__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("int_bool")}, // C.int_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kNumberTypeFloat,
|
||||
{
|
||||
{"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
|
||||
{"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
|
||||
{"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
|
||||
{"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
|
||||
{"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
|
||||
{"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
|
||||
{"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
|
||||
{"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
|
||||
{"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
|
||||
{"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
|
||||
{"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
|
||||
{"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
|
||||
{"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
|
||||
{"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
|
||||
{"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
|
||||
{"__le__", prim::kPrimScalarLe}, // P.scalar_le,
|
||||
{"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
|
||||
{"__bool__", std::string("float_bool")}, // C.float_bool
|
||||
{"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
|
||||
}},
|
||||
{kObjectTypeTuple,
|
||||
{
|
||||
{"__len__", prim::kPrimTupleLen}, // P.tuple_len,
|
||||
{"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
|
||||
{"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity,
|
||||
{"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
|
||||
{"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
|
||||
{"__bool__", std::string("tuple_bool")} // C.tuple_bool
|
||||
}},
|
||||
{kObjectTypeList,
|
||||
{
|
||||
{"__len__", prim::kPrimListLen}, // P.list_len,
|
||||
{"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
|
||||
{"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
|
||||
{"__ms_iter__", prim::kPrimIdentity}, // P.identity
|
||||
{"__ms_next__", std::string("list_next")}, // C.list_next
|
||||
{"append", std::string("list_append")}, // C.list_next
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
{"insert", std::string("list_insert")},
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"items", prim::kPrimDictItems}, // P.dict_items
|
||||
{"__bool__", std::string("dict_bool")} // C.dict_bool
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"all", std::string("all_")}, // C.reduce_all
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
{"__sub__", std::string("sub")}, // C.sub
|
||||
{"__mul__", std::string("mul")}, // C.mul
|
||||
{"abs", std::string("abs_")}, // C.abs_
|
||||
{"mean", std::string("mean")}, // C.mean
|
||||
{"__truediv__", std::string("truediv")}, // C.truediv
|
||||
{"__floordiv__", std::string("floordiv")}, // C.floordiv
|
||||
{"__mod__", std::string("mod")}, // C.mod
|
||||
{"__pow__", std::string("pow_")}, // C.pow
|
||||
{"__floor__", std::string("array_floor")}, // C.array_floor
|
||||
{"__trunc__", std::string("array_trunc")}, // C.array_trunc
|
||||
{"__pos__", std::string("array_uadd")}, // C.array_uadd
|
||||
{"__neg__", std::string("array_usub")}, // C.array_usub
|
||||
{"__eq__", std::string("eq")}, // C.eq
|
||||
{"__ne__", std::string("ne")}, // C.ne
|
||||
{"__lt__", std::string("lt")}, // C.lt
|
||||
{"__gt__", std::string("gt")}, // C.gt
|
||||
{"__le__", std::string("le")}, // C.le
|
||||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
|
||||
{"view", std::string("view")}, // C.view
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
{"__ms_iter__", std::string("array_iter")}, // C.array_iter
|
||||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", std::string("item")}, // P.item,
|
||||
{"itemset", std::string("itemset")}, // P.itemset,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"flatten", std::string("flatten")}, // P.reshape(,-1)
|
||||
{"reshape", std::string("reshape")}, // P.reshape()
|
||||
{"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
|
||||
{"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
|
||||
{"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()
|
||||
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
|
||||
{"swapaxes", std::string("swapaxes")}, // P.transpose()
|
||||
{"narrow", std::string("narrow")}, // narrow()
|
||||
{"masked_fill", std::string("masked_fill")}, // masked_fill()
|
||||
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
{"cumsum", std::string("cumsum")}, // P.cumsum()
|
||||
{"copy", std::string("copy")}, // copy()
|
||||
{"max", std::string("max")}, // P.reduce_max()
|
||||
{"min", std::string("min")}, // P.reduce_min()
|
||||
{"fill", std::string("fill")}, // P.fill()
|
||||
{"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
|
||||
{"clip", std::string("clip")}, // P.maximum(P.minimum)
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
{"argmax", std::string("argmax")}, // P.Argmax()
|
||||
{"argmin", std::string("argmin")}, // P.Argmax()
|
||||
{"resize", std::string("resize")}, // P.Reshape()
|
||||
{"choose", std::string("choose")}, // P.Select()
|
||||
{"diagonal", std::string("diagonal")}, // P.Eye()
|
||||
{"searchsorted", std::string("searchsorted")}, // P.Select()
|
||||
{"take", std::string("take")}, // P.GatherNd()
|
||||
{"tensor_scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd()
|
||||
{"trace", std::string("trace")}, // P.Eye()
|
||||
{"var", std::string("var")}, // P.ReduceSum
|
||||
{"std", std::string("std")}, // P.ReduceSum
|
||||
{"sum", std::string("sum")}, // P.ReduceSum
|
||||
{"repeat", std::string("repeat")}, // C.repeat_elements
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
{"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
|
||||
}},
|
||||
{kObjectTypeCSRTensorType,
|
||||
{
|
||||
{"astype", std::string("csr_astype")}, // C.csr_astype
|
||||
{"abs", std::string("csr_abs")}, // C.csr_abs
|
||||
{"sum", std::string("csr_sum")}, // C.csr_sum
|
||||
{"mv", std::string("csr_mv")}, // C.csr_mv
|
||||
{"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
|
||||
{"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
|
||||
{"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
|
||||
}},
|
||||
{kObjectTypeCOOTensorType,
|
||||
{
|
||||
{"astype", std::string("coo_astype")}, // C.coo_astype
|
||||
{"abs", std::string("coo_abs")}, // C.coo_abs
|
||||
{"to_tuple", std::string("coo_to_tuple")}, // C.coo_to_tuple
|
||||
{"to_csr", std::string("coo_to_csr")}, // C.coo_to_csr
|
||||
{"to_dense", std::string("coo_to_dense")}, // C.coo_to_dense
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
return method_map;
|
||||
}
|
||||
|
||||
|
|
|
@ -1560,6 +1560,15 @@ def while_cond(x):
|
|||
return x
|
||||
|
||||
|
||||
def tensor_scatter_add(x, indices, updates):
|
||||
"""
|
||||
Creates a new tensor by adding the values from the positions in `x` indicated by
|
||||
`indices`, with values from `updates`. When multiple values are given for the same
|
||||
index, the updated result will be the sum of all values.
|
||||
"""
|
||||
return F.tensor_scatter_add(x, indices, updates)
|
||||
|
||||
|
||||
def coo_to_csr(x):
|
||||
"""convert coo to csr."""
|
||||
row_indices = x.indices[:, 0]
|
||||
|
|
|
@ -1415,6 +1415,52 @@ class Tensor(Tensor_):
|
|||
return reduce_(self, reduce_min(keepdims), cmp_fn=minimum(), axis=axis, keepdims=keepdims,
|
||||
initial=initial, where=where)
|
||||
|
||||
def tensor_scatter_add(self, indices, updates):
|
||||
"""
|
||||
Creates a new tensor by adding the values from the positions in self tensor indicated by
|
||||
`indices`, with values from `updates`. When multiple values are given for the same
|
||||
index, the updated result will be the sum of all values. This operation is almost
|
||||
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
|
||||
instead of input `Parameter`.
|
||||
|
||||
The last axis of `indices` is the depth of each index vectors. For each index vector,
|
||||
there must be a corresponding value in `updates`. The shape of `updates` should be
|
||||
equal to the shape of `self[indices]`. For more details, see use cases.
|
||||
|
||||
Note:
|
||||
If some values of the `indices` are out of bound, instead of raising an index error,
|
||||
the corresponding `updates` will not be updated to self tensor.
|
||||
|
||||
Args:
|
||||
indices (Tensor): The index of input tensor whose data type is int32 or int64.
|
||||
The rank must be at least 2.
|
||||
updates (Tensor): The tensor to update the input tensor, has the same type as self tensor,
|
||||
and updates. Shape should be equal to indices.shape[:-1] + self.shape[indices.shape[-1]:].
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and type as self tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices` is neither int32 nor int64.
|
||||
ValueError: If length of shape of self tensor is less than the last dimension of shape of `indices`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype('float32'))
|
||||
>>> indices = Tensor(np.array([[0, 0], [0, 0]]).astype('int32'))
|
||||
>>> updates = Tensor(np.array([1.0, 2.2]).astype('float32'))
|
||||
>>> output = x.tensor_scatter_add(indices, updates)
|
||||
>>> print(output)
|
||||
[[ 3.1 0.3 3.6]
|
||||
[ 0.4 0.5 -3.2]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('tensor_scatter_add')()(self, indices, updates)
|
||||
|
||||
def fill(self, value):
|
||||
"""
|
||||
Fill the tensor with a scalar value.
|
||||
|
|
|
@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
|
|||
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
|
||||
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
|
||||
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
|
||||
scatter_max, scatter_min)
|
||||
tensor_scatter_add, scatter_max, scatter_min)
|
||||
from .parameter_func import assign, assign_add, assign_sub, index_add
|
||||
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
|
||||
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
|
||||
|
|
|
@ -948,6 +948,57 @@ def gather_nd(input_x, indices):
|
|||
return gather_nd_(input_x, indices)
|
||||
|
||||
|
||||
tensor_scatter_add_ = P.TensorScatterAdd()
|
||||
def tensor_scatter_add(input_x, indices, updates):
|
||||
"""
|
||||
Creates a new tensor by adding the values from the positions in `input_x` indicated by
|
||||
`indices`, with values from `updates`. When multiple values are given for the same
|
||||
index, the updated result will be the sum of all values. This operation is almost
|
||||
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
|
||||
instead of input `Parameter`.
|
||||
|
||||
The last axis of `indices` is the depth of each index vectors. For each index vector,
|
||||
there must be a corresponding value in `updates`. The shape of `updates` should be
|
||||
equal to the shape of `input_x[indices]`. For more details, see use cases.
|
||||
|
||||
Note:
|
||||
If some values of the `indices` are out of bound, instead of raising an index error,
|
||||
the corresponding `updates` will not be updated to `input_x`.
|
||||
|
||||
Args:
|
||||
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
||||
The rank must be at least 2.
|
||||
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and updates. Shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices` is neither int32 nor int64.
|
||||
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor, nn
|
||||
>>> from mindspore import ops
|
||||
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
|
||||
>>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> output = ops.tensor_scatter_add(input_x, indices, updates)
|
||||
>>> print(output)
|
||||
[[ 3.1 0.3 3.6]
|
||||
[ 0.4 0.5 -3.2]]
|
||||
"""
|
||||
|
||||
return tensor_scatter_add_(input_x, indices, updates)
|
||||
|
||||
|
||||
##############################
|
||||
# Type Conversion Functions.
|
||||
##############################
|
||||
|
@ -1132,6 +1183,7 @@ __all__ = [
|
|||
'expand_dims',
|
||||
'transpose',
|
||||
'scatter_nd',
|
||||
'tensor_scatter_add',
|
||||
'gather',
|
||||
'gather_d',
|
||||
'gather_nd',
|
||||
|
|
|
@ -71,7 +71,6 @@ scatter_update = P.ScatterUpdate()
|
|||
tensor_scatter_update = P.TensorScatterUpdate()
|
||||
tensor_scatter_min = P.TensorScatterMin()
|
||||
tensor_scatter_max = P.TensorScatterMax()
|
||||
tensor_scatter_add = P.TensorScatterAdd()
|
||||
tensor_scatter_sub = P.TensorScatterSub()
|
||||
tensor_scatter_mul = P.TensorScatterMul()
|
||||
tensor_scatter_div = P.TensorScatterDiv()
|
||||
|
@ -900,7 +899,6 @@ logical_or = P.LogicalOr()
|
|||
logical_not = P.LogicalNot()
|
||||
cumsum = P.CumSum()
|
||||
cumprod = P.CumProd()
|
||||
tensor_scatter_add = P.TensorScatterAdd()
|
||||
array_to_scalar = Primitive('array_to_scalar')
|
||||
is_ = Primitive("is_")
|
||||
is_not = Primitive("is_not")
|
||||
|
@ -1022,6 +1020,6 @@ tensor_operator_registry.register('narrow', narrow)
|
|||
tensor_operator_registry.register('sort', sort)
|
||||
tensor_operator_registry.register('zeros', zeros)
|
||||
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
|
||||
tensor_operator_registry.register('tensor_scatter_add', tensor_scatter_add)
|
||||
tensor_operator_registry.register('tensor_scatter_add', P.TensorScatterAdd)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -6951,30 +6951,10 @@ class TensorScatterAdd(_TensorScatterOp):
|
|||
equivalent to using ScatterNdAdd, except that the updates are applied on output `Tensor`
|
||||
instead of input `Parameter`.
|
||||
|
||||
The last axis of `indices` is the depth of each index vectors. For each index vector,
|
||||
there must be a corresponding value in `updates`. The shape of `updates` should be
|
||||
equal to the shape of `input_x[indices]`. For more details, see use cases.
|
||||
|
||||
Note:
|
||||
If some values of the `indices` are out of bound, instead of raising an index error,
|
||||
the corresponding `updates` will not be updated to `input_x`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
||||
The rank must be at least 2.
|
||||
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices` is neither int32 nor int64.
|
||||
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
|
||||
Refer to :func:`mindspore.ops.tensor_scatter_add` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
|
|
Loading…
Reference in New Issue