fix tensor_scatter_add param error and add ut/st cases

This commit is contained in:
yuchaojie 2022-05-06 16:35:03 +08:00
parent 02f58183c5
commit 8d172e4973
3 changed files with 6 additions and 3 deletions

View File

@ -690,6 +690,9 @@ mindspore.Tensor
根据指定的更新值和输入索引通过相加运算更新本Tensor的值。当同一索引有不同值时更新的结果将是所有值的总和。
.. note::
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新到 `input_x` ,而不是抛出索引错误。
**参数:**
- **indices** (Tensor) - Tensor的索引数据类型为int32或int64的。其rank必须至少为2。

View File

@ -8,7 +8,7 @@
`indices` 的最后一个轴是每个索引向量的深度。对于每个索引向量, `updates` 中必须有相应的值。 `updates` 的shape应该等于 `input_x[indices]` 的shape。有关更多详细信息请参见使用用例。
.. note::
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新 `input_x` ,而不是抛出索引错误。
如果 `indices` 的某些值超出范围,则相应的 `updates` 不会更新 `input_x` ,而不是抛出索引错误。
**参数:**

View File

@ -1506,7 +1506,7 @@ class Tensor(Tensor_):
[ 0.4 0.5 -3.2]]
"""
self._init_check()
return tensor_operator_registry.get('tensor_scatter_add')()(self, indices, updates)
return tensor_operator_registry.get("tensor_scatter_add")()(self, indices, updates)
def fill(self, value):
"""
@ -2726,7 +2726,7 @@ class COOTensor(COOTensor_):
``GPU``
"""
zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
return tensor_operator_registry.get("tensor_scatter_add")(
return tensor_operator_registry.get("tensor_scatter_add")()(
zeros_tensor, self.indices, self.values)
@property