forked from mindspore-Ecosystem/mindspore
!33951 fix tensor_scatter_add param error
Merge pull request !33951 from yuchaojie/op_dev
This commit is contained in:
commit
4899e19a36
|
@ -755,6 +755,9 @@ mindspore.Tensor
|
|||
|
||||
根据指定的更新值和输入索引,通过相加运算更新本Tensor的值。当同一索引有不同值时,更新的结果将是所有值的总和。
|
||||
|
||||
.. note::
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,而不是抛出索引错误。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **indices** (Tensor) - Tensor的索引,数据类型为int32或int64的。其rank必须至少为2。
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息,请参见使用用例。
|
||||
|
||||
.. note::
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新为 `input_x` ,而不是抛出索引错误。
|
||||
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,而不是抛出索引错误。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
|
|
@ -1506,7 +1506,7 @@ class Tensor(Tensor_):
|
|||
[ 0.4 0.5 -3.2]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('tensor_scatter_add')()(self, indices, updates)
|
||||
return tensor_operator_registry.get("tensor_scatter_add")()(self, indices, updates)
|
||||
|
||||
def fill(self, value):
|
||||
"""
|
||||
|
@ -2768,7 +2768,7 @@ class COOTensor(COOTensor_):
|
|||
``GPU``
|
||||
"""
|
||||
zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
|
||||
return tensor_operator_registry.get("tensor_scatter_add")(
|
||||
return tensor_operator_registry.get("tensor_scatter_add")()(
|
||||
zeros_tensor, self.indices, self.values)
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue