fix roll function dims none

This commit is contained in:
fengyihang 2023-01-16 16:58:53 +08:00
parent 2ff013c3cc
commit 8b30f7e094
2 changed files with 7 additions and 3 deletions

View File

@ -8,7 +8,7 @@ mindspore.ops.roll
参数:
- **x** (Tensor) - 输入Tensor。
- **shifts** (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为正数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。如果dims为None则会将输入Tensor展平后再进行roll计算然后将计算结果reshape为输入的shape。
返回:
Tensorshape和数据类型与输入 `x` 相同。

View File

@ -8109,7 +8109,8 @@ def roll(x, shifts, dims=None):
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
in the opposite direction.
dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled.
Default: None.
Default: None. If dims is None, the Tensor will be flattened before rolling and then restored to the
original shape.
Returns:
Tensor, has the same shape and type as `x`.
@ -8131,7 +8132,10 @@ def roll(x, shifts, dims=None):
>>> print(output)
[3. 4. 0. 1. 2.]
"""
dims = dims if dims is not None else 0
_shape = x.shape
if dims is None:
flatten_x = x.reshape(-1)
return Roll(shifts, 0)(flatten_x).reshape(_shape)
return Roll(shifts, dims)(x)