diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index a0df517b6d4..c3b1a2de813 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -5777,20 +5777,18 @@ def remainder(x, y): return out -def accumulate_n(*x): +def accumulate_n(x): r""" Computes accumulation of all input tensors element-wise. - AccumulateNV2 is similar to AddN, but there is a significant difference - among them: AccumulateNV2 will not wait for all of its inputs to be ready - before summing. That is to say, AccumulateNV2 is able to save - memory when inputs are ready at different time since the minimum temporary - storage is proportional to the output size rather than the input size. + AccumulateNV2 is similar to AddN, but there is a significant difference among them: AccumulateNV2 will not wait + for all of its inputs to be ready before summing. That is to say, AccumulateNV2 is able to save memory when inputs + are ready at different time since the minimum temporary storage is proportional to the output size rather than the + input size. Args: - - **x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list - is made up of multiple tensors whose dtype is number to be added together. - Each element of tuple or list should have the same shape. + x (Union(tuple[Tensor], list[Tensor])): The input tuple or list is made up of multiple tensors whose dtype is + number to be added together. Each element of tuple or list should have the same shape. Returns: Tensor, has the same shape and dtype as each entry of the `x`. @@ -5805,7 +5803,7 @@ def accumulate_n(*x): Examples: >>> x = Tensor(np.array([1, 2, 3]), mindspore.float32) >>> y = Tensor(np.array([4, 5, 6]), mindspore.float32) - >>> output = ops.accumulate_n(x, y, x, y) + >>> output = ops.accumulate_n([x, y, x, y]) >>> print(output) [10. 14. 18.] """ diff --git a/tests/st/ops/ascend/test_accumulate_nv2.py b/tests/st/ops/ascend/test_accumulate_nv2.py index a446cb8f6e9..bcd47c1f967 100644 --- a/tests/st/ops/ascend/test_accumulate_nv2.py +++ b/tests/st/ops/ascend/test_accumulate_nv2.py @@ -26,7 +26,7 @@ def accumulate_n_forward_functional(nptype): input_x = Tensor(np.array([1, 2, 3]).astype(nptype)) input_y = Tensor(np.array([4, 5, 6]).astype(nptype)) - output = F.accumulate_n(input_x, input_y, input_x, input_y) + output = F.accumulate_n([input_x, input_y, input_x, input_y]) expected = np.array([10., 14., 18.]) np.testing.assert_array_almost_equal(output.asnumpy(), expected)