forked from mindspore-Ecosystem/mindspore
!42017 [ST][MS][OPS] Update accumulate_n Functional API and ST.
Merge pull request !42017 from alashkari/update_accumulate_n
This commit is contained in:
commit
466cb6c7ee
|
@ -5777,20 +5777,18 @@ def remainder(x, y):
|
|||
return out
|
||||
|
||||
|
||||
def accumulate_n(*x):
|
||||
def accumulate_n(x):
|
||||
r"""
|
||||
Computes accumulation of all input tensors element-wise.
|
||||
|
||||
AccumulateNV2 is similar to AddN, but there is a significant difference
|
||||
among them: AccumulateNV2 will not wait for all of its inputs to be ready
|
||||
before summing. That is to say, AccumulateNV2 is able to save
|
||||
memory when inputs are ready at different time since the minimum temporary
|
||||
storage is proportional to the output size rather than the input size.
|
||||
AccumulateNV2 is similar to AddN, but there is a significant difference among them: AccumulateNV2 will not wait
|
||||
for all of its inputs to be ready before summing. That is to say, AccumulateNV2 is able to save memory when inputs
|
||||
are ready at different time since the minimum temporary storage is proportional to the output size rather than the
|
||||
input size.
|
||||
|
||||
Args:
|
||||
- **x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
|
||||
is made up of multiple tensors whose dtype is number to be added together.
|
||||
Each element of tuple or list should have the same shape.
|
||||
x (Union(tuple[Tensor], list[Tensor])): The input tuple or list is made up of multiple tensors whose dtype is
|
||||
number to be added together. Each element of tuple or list should have the same shape.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and dtype as each entry of the `x`.
|
||||
|
@ -5805,7 +5803,7 @@ def accumulate_n(*x):
|
|||
Examples:
|
||||
>>> x = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> y = Tensor(np.array([4, 5, 6]), mindspore.float32)
|
||||
>>> output = ops.accumulate_n(x, y, x, y)
|
||||
>>> output = ops.accumulate_n([x, y, x, y])
|
||||
>>> print(output)
|
||||
[10. 14. 18.]
|
||||
"""
|
||||
|
|
|
@ -26,7 +26,7 @@ def accumulate_n_forward_functional(nptype):
|
|||
input_x = Tensor(np.array([1, 2, 3]).astype(nptype))
|
||||
input_y = Tensor(np.array([4, 5, 6]).astype(nptype))
|
||||
|
||||
output = F.accumulate_n(input_x, input_y, input_x, input_y)
|
||||
output = F.accumulate_n([input_x, input_y, input_x, input_y])
|
||||
expected = np.array([10., 14., 18.])
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
|
Loading…
Reference in New Issue