!35287 Add tensor & functional interface for TensorScatterMax ops.

Merge pull request !35287 from wuweikang/tensor-scatter-max

This commit is contained in commit 55da1a87d5.
@@ -356,8 +356,9 @@ Array Operations
 mindspore.ops.space_to_batch_nd
 mindspore.ops.split
 mindspore.ops.tensor_scatter_add
-mindspore.ops.tensor_scatter_min
 mindspore.ops.tensor_scatter_div
+mindspore.ops.tensor_scatter_max
+mindspore.ops.tensor_scatter_min
 mindspore.ops.tensor_scatter_mul
 mindspore.ops.tensor_scatter_sub
 mindspore.ops.tensor_scatter_elements
@@ -1392,6 +1392,29 @@ mindspore.Tensor
     - **TypeError** - The dtype of `indices` is neither int32 nor int64.
     - **ValueError** - The length of the shape of the Tensor is less than the last dimension of the shape of `indices`.
 
+    .. py:method:: scatter_max(indices, updates)
+
+        Applies a max operation with the given update values at the given indices, and returns the result as a new Tensor.
+
+        The last axis of `indices` is the depth of each index vector. For each index vector, there must be a corresponding value in `updates`. The shape of `updates` should be equal to the shape of `input_x[indices]`. See the examples below for more details.
+
+        .. note::
+            If some values of `indices` are out of bounds, the corresponding `updates` are not applied, and no index error is raised.
+
+        **Parameters:**
+
+        - **indices** (Tensor) - The indices into the Tensor, with dtype int32 or int64. Its rank must be at least 2.
+        - **updates** (Tensor) - The tensor whose values are taken the maximum with this Tensor; it has the same dtype as this Tensor. updates.shape should be equal to indices.shape[:-1] + self.shape[indices.shape[-1]:].
+
+        **Returns:**
+
+        Tensor, with the same shape and dtype as the original Tensor.
+
+        **Raises:**
+
+        - **TypeError** - The dtype of `indices` is neither int32 nor int64.
+        - **ValueError** - The length of the shape of the Tensor is less than the last dimension of the shape of `indices`.
+
     .. py:method:: scatter_mul(indices, updates)
 
         Computes a multiplication at the specified indices and assigns the output to the output Tensor.
@@ -3,24 +3,5 @@
 .. py:class:: mindspore.ops.TensorScatterMax
 
-    Updates the values of the input Tensor by a max operation with the given update values and input indices.
-
-    The last axis of `indices` is the depth of each index vector. For each index vector, there must be a corresponding value in `updates`. The shape of `updates` should be equal to the shape of input_x[indices]. For more details, see the use cases.
-
-    .. note::
-        If some values of `indices` are out of bounds, the corresponding `updates` are not applied to `input_x`, and no index error is raised.
-
-    **Inputs:**
-
-    - **input_x** (Tensor) - The input Tensor. The dimension of `input_x` must be no less than indices.shape[-1].
-    - **indices** (Tensor) - The indices into the input Tensor, with dtype int32 or int64. Its rank must be at least 2.
-    - **updates** (Tensor) - The tensor whose values are taken the maximum with `input_x`; it has the same dtype as the input. updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
-
-    **Outputs:**
-
-    Tensor, with the same shape and dtype as `input_x`.
-
-    **Raises:**
-
-    - **TypeError** - The dtype of `indices` is neither int32 nor int64.
-    - **ValueError** - The length of the shape of `input_x` is less than the last dimension of the shape of `indices`.
+    Assigns the result of a max operation with the given update values and input indices to the output Tensor.
+
+    Refer to :func:`mindspore.ops.tensor_scatter_max` for more details.
@@ -0,0 +1,26 @@
+mindspore.ops.tensor_scatter_max
+================================
+
+.. py:function:: mindspore.ops.tensor_scatter_max(input_x, indices, updates)
+
+    Applies a max operation with the given update values at the given indices, and returns the result as a new Tensor.
+
+    The last axis of `indices` is the depth of each index vector. For each index vector, there must be a corresponding value in `updates`. The shape of `updates` should be equal to the shape of `input_x[indices]`. See the examples below for more details.
+
+    .. note::
+        If some values of `indices` are out of bounds, the corresponding `updates` are not applied to `input_x`, and no index error is raised.
+
+    **Parameters:**
+
+    - **input_x** (Tensor) - The input Tensor. The dimension of `input_x` must be no less than indices.shape[-1].
+    - **indices** (Tensor) - The indices into the input Tensor, with dtype int32 or int64. Its rank must be at least 2.
+    - **updates** (Tensor) - The tensor whose values are taken the maximum with `input_x`; it has the same dtype as the input. updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
+
+    **Returns:**
+
+    Tensor, with the same shape and dtype as `input_x`.
+
+    **Raises:**
+
+    - **TypeError** - The dtype of `indices` is neither int32 nor int64.
+    - **ValueError** - The length of the shape of `input_x` is less than the last dimension of the shape of `indices`.
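The scatter-max semantics documented above can be sketched in plain NumPy. This is an illustrative reference only (the function `tensor_scatter_max_ref` and its explicit loop are assumptions for this sketch, not MindSpore code), following the documented shape contract and the note that out-of-bounds index vectors are silently skipped:

```python
import numpy as np

def tensor_scatter_max_ref(input_x, indices, updates):
    # The output starts as a copy of input_x; each index vector along the
    # last axis of `indices` selects a slice whose value becomes the
    # elementwise max of itself and the matching slice of `updates`.
    out = input_x.copy()
    depth = indices.shape[-1]
    idx = indices.reshape(-1, depth)
    upd = updates.reshape((idx.shape[0],) + input_x.shape[depth:])
    for vec, u in zip(idx, upd):
        key = tuple(int(k) for k in vec)
        # Out-of-bounds index vectors are skipped rather than raising an error.
        if all(0 <= k < s for k, s in zip(key, input_x.shape)):
            out[key] = np.maximum(out[key], u)
    return out
```

With the documented example (`input_x[0, 0] = -0.1`, two index vectors `[0, 0]`, updates `[1.0, 2.2]`) the sketch yields `[[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]`.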
@@ -355,6 +355,7 @@ Array Operation
 mindspore.ops.split
 mindspore.ops.tensor_scatter_add
 mindspore.ops.tensor_scatter_min
+mindspore.ops.tensor_scatter_max
 mindspore.ops.tensor_scatter_div
 mindspore.ops.tensor_scatter_mul
 mindspore.ops.tensor_scatter_sub
@@ -235,6 +235,7 @@ BuiltInTypeMap &GetMethodMap() {
     {"scatter_mul", std::string("tensor_scatter_mul")},  // tensor_scatter_mul()
     {"scatter_sub", std::string("tensor_scatter_sub")},  // P.TensorScatterSub()
     {"scatter_min", std::string("tensor_scatter_min")},  // P.TensorScatterMin()
+    {"scatter_max", std::string("tensor_scatter_max")},  // P.TensorScatterMax()
     {"scatter_div", std::string("tensor_scatter_div")},  // P.TensorScatterDiv()
     {"norm", std::string("norm")},  // norm()
     {"unsorted_segment_min", std::string("unsorted_segment_min")},  // P.UnsortedSegmentMin()
@@ -107,5 +107,21 @@ const BaseRef TensorScatterMinFission::DefinePattern() const {
   VarPtr updates = std::make_shared<Var>();
   return VectorRef({prim::kPrimTensorScatterMin, input, indices, updates});
 }
+
+const AnfNodePtr TensorScatterMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
+  MS_EXCEPTION_IF_NULL(scatter_nd_node);
+  common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMin"), scatter_nd_node);
+  return scatter_nd_node;
+}
+
+const AnfNodePtr TensorScatterMaxFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
+  MS_EXCEPTION_IF_NULL(scatter_nd_node);
+  common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMax"), scatter_nd_node);
+  return scatter_nd_node;
+}
 }  // namespace opt
 }  // namespace mindspore
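The `Process` overrides above lower the non-mutating `TensorScatterMax`/`TensorScatterMin` nodes to ScatterNd-style kernels and tag them with a `cust_aicpu` attribute. Behaviorally, that fission amounts to copying the input and then running an in-place scatter. A Python sketch of that equivalence (the names `scatter_nd_max_inplace` and `tensor_scatter_max_via_fission` are illustrative, not MindSpore APIs, and the in-place loop stands in for the actual AICPU ScatterNdMax kernel):

```python
import numpy as np

def scatter_nd_max_inplace(buf, indices, updates):
    # In-place ScatterNd-style max over full index vectors (depth == buf.ndim).
    for vec, u in zip(indices, updates):
        key = tuple(int(k) for k in vec)
        buf[key] = max(buf[key], u)

def tensor_scatter_max_via_fission(input_x, indices, updates):
    # TensorScatterMax does not mutate its input: copy first, then scatter.
    out = input_x.copy()
    scatter_nd_max_inplace(out, indices, updates)
    return out
```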
@@ -74,6 +74,7 @@ class TensorScatterMaxFission : public TensorScatterFission {
       : TensorScatterFission(multigraph, name) {}
   ~TensorScatterMaxFission() override = default;
   const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
 
  protected:
   ValueNodePtr GetScatterNdPrimNode() const override;
@@ -85,6 +86,7 @@ class TensorScatterMinFission : public TensorScatterFission {
       : TensorScatterFission(multigraph, name) {}
   ~TensorScatterMinFission() override = default;
   const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
 
  protected:
   ValueNodePtr GetScatterNdPrimNode() const override;
@@ -1841,6 +1841,14 @@ def tensor_sactter_div(input_x, indices, updates):
     return F.tensor_scatter_div(input_x, indices, updates)
 
 
+def tensor_scatter_max(x, indices, updates):
+    """
+    By comparing the value at the position indicated by `indices` in `x` with the value in the `updates`,
+    the value at the index will eventually be equal to the largest one to create a new tensor.
+    """
+    return F.tensor_scatter_max(x, indices, updates)
+
+
 def tensor_scatter_min(x, indices, updates):
     """
     By comparing the value at the position indicated by `indices` in `x` with the value in the `updates`,
@@ -2449,6 +2449,57 @@ class Tensor(Tensor_):
         self._init_check()
         return tensor_operator_registry.get('tensor_scatter_min')()(self, indices, updates)
 
+    def scatter_max(self, indices, updates):
+        """
+        By comparing the value at the position indicated by `indices` in this Tensor with the value
+        in the `updates`, the value at the index will eventually be equal to the largest one to
+        create a new tensor.
+
+        The last axis of the index is the depth of each index vector. For each index vector,
+        there must be a corresponding value in `updates`. The shape of `updates` should be
+        equal to the shape of `input_x[indices]`. For more details, see the case below.
+
+        Note:
+            If some values of the `indices` are out of bound, instead of raising an index error,
+            the corresponding `updates` will not be updated to `input_x`.
+
+        Args:
+            indices (Tensor): The index of the input tensor whose data type is int32 or int64.
+                The rank must be at least 2.
+            updates (Tensor): The tensor to update the input tensor, has the same type as the input,
+                and updates.shape should be equal to indices.shape[:-1] + self.shape[indices.shape[-1]:].
+
+        Returns:
+            Tensor, has the same shape and type as the original Tensor.
+
+        Raises:
+            TypeError: If dtype of `indices` is neither int32 nor int64.
+            ValueError: If length of shape of the Tensor is less than the last dimension of shape of `indices`.
+
+        Supported Platforms:
+            ``GPU`` ``CPU``
+
+        Examples:
+            >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
+            >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
+            >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
+            >>> # Next, demonstrate the approximate operation process of this operator:
+            >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
+            >>> # 2, And input_x[0, 0] = -0.1
+            >>> # 3, So input_x[indices] = [-0.1, -0.1]
+            >>> # 4, Satisfy the above formula: input_x[indices].shape=(2) == updates.shape=(2)
+            >>> # 5, Perform the max operation for the first time:
+            >>> #      first_input_x = Max(input_x[0][0], updates[0]) = [[1.0, 0.3, 3.6], [0.4, 0.5, -3.2]]
+            >>> # 6, Perform the max operation for the second time:
+            >>> #      second_input_x = Max(first_input_x[0][0], updates[1]) = [[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]
+            >>> output = input_x.scatter_max(indices, updates)
+            >>> print(output)
+            [[ 2.2  0.3  3.6]
+             [ 0.4  0.5 -3.2]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('tensor_scatter_max')()(self, indices, updates)
+
     def fill(self, value):
         """
         Fill the tensor with a scalar value.
@@ -598,6 +598,7 @@ def get_slice_grad_vmap_rule(prim, axis_size):
 @vmap_rules_getters.register(P.TensorScatterSub)
 @vmap_rules_getters.register(P.TensorScatterMul)
 @vmap_rules_getters.register(P.TensorScatterDiv)
+@vmap_rules_getters.register(P.TensorScatterMax)
 def get_tensor_scatter_op_vmap_rule(prim, axis_size):
     """
     VmapRule for `TensorScatter*` operations, such as `TensorScatterMul`.
@@ -608,6 +609,7 @@ def get_tensor_scatter_op_vmap_rule(prim, axis_size):
         "TensorScatterSub": P.TensorScatterSub,
         "TensorScatterMul": P.TensorScatterMul,
         "TensorScatterDiv": P.TensorScatterDiv,
+        "TensorScatterMax": P.TensorScatterMax,
     }
     if isinstance(prim, str):
         prim_name = prim
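Registering `TensorScatterMax` on the shared `TensorScatter*` vmap rule means its batched behavior must equal applying the op independently to each element of the mapped axis. A loop-based sketch of that contract, assuming the batch axis is axis 0 on every operand (the function name `batched_tensor_scatter_max` is illustrative, not a MindSpore API):

```python
import numpy as np

def batched_tensor_scatter_max(xs, batched_indices, batched_updates):
    # Apply scatter-max to each batch element separately; any correct vmap
    # rule for the op must be equivalent to this per-sample loop.
    outs = []
    for x, indices, updates in zip(xs, batched_indices, batched_updates):
        out = x.copy()
        for vec, u in zip(indices, updates):
            key = tuple(int(k) for k in vec)
            out[key] = max(out[key], u)
        outs.append(out)
    return np.stack(outs)
```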
@@ -72,6 +72,7 @@ from .array_func import (
     tensor_scatter_mul,
     unique_consecutive,
     tensor_scatter_div,
+    tensor_scatter_max,
     tensor_scatter_min,
     tensor_scatter_elements,
     scatter_add,
@@ -67,6 +67,7 @@ tensor_scatter_sub_ = P.TensorScatterSub()
 tensor_scatter_mul_ = P.TensorScatterMul()
 tensor_scatter_div_ = P.TensorScatterDiv()
 tensor_scatter_min_ = P.TensorScatterMin()
+tensor_scatter_max_ = P.TensorScatterMax()
 scalar_to_array_ = P.ScalarToArray()
 scalar_to_tensor_ = P.ScalarToTensor()
 tuple_to_array_ = P.TupleToArray()
@@ -2282,6 +2283,60 @@ def tensor_scatter_sub(input_x, indices, updates):
     return tensor_scatter_sub_(input_x, indices, updates)
 
 
+def tensor_scatter_max(input_x, indices, updates):
+    """
+    By comparing the value at the position indicated by `indices` in `input_x` with the value in the `updates`,
+    the value at the index will eventually be equal to the largest one to create a new tensor.
+
+    The last axis of the index is the depth of each index vector. For each index vector,
+    there must be a corresponding value in `updates`. The shape of `updates` should be
+    equal to the shape of input_x[indices].
+    For more details, see use cases.
+
+    Note:
+        If some values of the `indices` are out of bound, instead of raising an index error,
+        the corresponding `updates` will not be updated to `input_x`.
+
+    Args:
+        input_x (Tensor): The target tensor. The dimension of input_x must be no less than indices.shape[-1].
+        indices (Tensor): The index of the input tensor whose data type is int32 or int64.
+            The rank must be at least 2.
+        updates (Tensor): The tensor to update the input tensor, has the same type as the input,
+            and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
+
+    Returns:
+        Tensor, has the same shape and type as `input_x`.
+
+    Raises:
+        TypeError: If dtype of `indices` is neither int32 nor int64.
+        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
+
+    Supported Platforms:
+        ``GPU`` ``CPU``
+
+    Examples:
+        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
+        >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
+        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
+        >>> # Next, demonstrate the approximate operation process of this operator:
+        >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
+        >>> # 2, And input_x[0, 0] = -0.1
+        >>> # 3, So input_x[indices] = [-0.1, -0.1]
+        >>> # 4, Satisfy the above formula: input_x[indices].shape=(2) == updates.shape=(2)
+        >>> # 5, Perform the max operation for the first time:
+        >>> #      first_input_x = Max(input_x[0][0], updates[0]) = [[1.0, 0.3, 3.6], [0.4, 0.5, -3.2]]
+        >>> # 6, Perform the max operation for the second time:
+        >>> #      second_input_x = Max(first_input_x[0][0], updates[1]) = [[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]
+        >>> output = ops.tensor_scatter_max(input_x, indices, updates)
+        >>> print(output)
+        [[ 2.2  0.3  3.6]
+         [ 0.4  0.5 -3.2]]
+    """
+    return tensor_scatter_max_(input_x, indices, updates)
+
+
 def tensor_scatter_min(input_x, indices, updates):
     """
     By comparing the value at the position indicated by `indices` in `input_x` with the value in the `updates`,
@@ -3641,6 +3696,7 @@ __all__ = [
     'tensor_scatter_sub',
     'tensor_scatter_mul',
    'tensor_scatter_div',
+    'tensor_scatter_max',
     'tensor_scatter_min',
     'tensor_scatter_elements',
     'unsorted_segment_min',
@@ -65,7 +65,6 @@ if not security.enable_security():
     print_ = P.Print()
 squeeze = P.Squeeze()
 tensor_scatter_update = P.TensorScatterUpdate()
-tensor_scatter_max = P.TensorScatterMax()
 scatter_nd_update = P.ScatterNdUpdate()
 stack = P.Stack()
@@ -1023,6 +1022,7 @@ tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update
 tensor_operator_registry.register('tensor_scatter_mul', tensor_scatter_mul)
 tensor_operator_registry.register('tensor_scatter_div', tensor_scatter_div)
 tensor_operator_registry.register('tensor_scatter_min', P.TensorScatterMin)
+tensor_operator_registry.register('tensor_scatter_max', P.TensorScatterMax)
 tensor_operator_registry.register('tensor_scatter_sub', P.TensorScatterSub)
 tensor_operator_registry.register('tensor_scatter_add', P.TensorScatterAdd)
 tensor_operator_registry.register('bernoulli', bernoulli)
@@ -6665,34 +6665,13 @@ class TensorScatterUpdate(_TensorScatterOp):
 
 class TensorScatterMax(_TensorScatterOp):
     """
-    By comparing the value at the position indicated by the index in input_x with the value in the update,
+    By comparing the value at the position indicated by `indices` in `input_x` with the value in the `updates`,
     the value at the index will eventually be equal to the largest one to create a new tensor.
 
-    The last axis of the index is the depth of each index vector. For each index vector,
-    there must be a corresponding value in `updates`. The shape of `updates` should be
-    equal to the shape of input_x[indices].
-    For more details, see use cases.
-
-    Note:
-        If some values of the `indices` are out of bound, instead of raising an index error,
-        the corresponding `updates` will not be updated to `input_x`.
-
-    Inputs:
-        - **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
-        - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
-          The rank must be at least 2.
-        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
-          and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
-
-    Outputs:
-        Tensor, has the same shape and type as `input_x`.
-
-    Raises:
-        TypeError: If dtype of `indices` is neither int32 nor int64.
-        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
+    Refer to :func:`mindspore.ops.tensor_scatter_max` for more details.
 
     Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
@@ -3536,6 +3536,12 @@ test_case_other_ops = [
         Tensor(np.array([[0, 1], [1, 2]], np.int32)),
         Tensor(np.ones([2, 5], np.float32) * 99)),
         'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
+    ('TensorScatterMax', {
+        'block': P.TensorScatterMax(),
+        'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
+                        Tensor(np.array([[0, 1], [1, 2]], np.int32)),
+                        Tensor(np.ones([2, 5], np.float32) * 99)),
+        'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
     ('ScatterMaxUseLocking', {
         'block': ScatterMax(use_locking=True),
         'desc_inputs': (Tensor(np.array([1, 0], np.int32)),