!33951 fix tensor_scatter_add param error

Merge pull request !33951 from yuchaojie/op_dev
This commit is contained in:
i-robot 2022-05-07 09:30:10 +00:00 committed by Gitee
commit 4899e19a36
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 6 additions and 3 deletions

View File

@ -755,6 +755,9 @@ mindspore.Tensor
根据指定的更新值和输入索引通过相加运算更新本Tensor的值。当同一索引有不同值时更新的结果将是所有值的总和。
.. note::
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,而不是抛出索引错误。
**参数:**
- **indices** (Tensor) - Tensor的索引数据类型为int32或int64的。其rank必须至少为2。

View File

@ -8,7 +8,7 @@
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息请参见使用用例。
.. note::
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新 `input_x` ,而不是抛出索引错误。
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新 `input_x` ,而不是抛出索引错误。
**参数:**

View File

@ -1506,7 +1506,7 @@ class Tensor(Tensor_):
[ 0.4 0.5 -3.2]]
"""
self._init_check()
return tensor_operator_registry.get('tensor_scatter_add')()(self, indices, updates)
return tensor_operator_registry.get("tensor_scatter_add")()(self, indices, updates)
def fill(self, value):
"""
@ -2768,7 +2768,7 @@ class COOTensor(COOTensor_):
``GPU``
"""
zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
return tensor_operator_registry.get("tensor_scatter_add")(
return tensor_operator_registry.get("tensor_scatter_add")()(
zeros_tensor, self.indices, self.values)
@property