!47920 fix roll function when dims is None

Merge pull request !47920 from 冯一航/fix_roll_function_dims_none
This commit is contained in:
i-robot 2023-01-17 01:45:59 +00:00 committed by Gitee
commit a023825aae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 7 additions and 3 deletions

View File

@ -8,7 +8,7 @@ mindspore.ops.roll
参数:
- **x** (Tensor) - 输入Tensor。
- **shifts** (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为正数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。如果dims为None则会将输入Tensor展平后再进行roll计算然后将计算结果reshape为输入的shape。
返回:
Tensorshape和数据类型与输入 `x` 相同。

View File

@ -8360,7 +8360,8 @@ def roll(x, shifts, dims=None):
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
in the opposite direction.
dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled.
Default: None.
Default: None. If dims is None, the Tensor will be flattened before rolling and then restored to the
original shape.
Returns:
Tensor, has the same shape and type as `x`.
@ -8382,7 +8383,10 @@ def roll(x, shifts, dims=None):
>>> print(output)
[3. 4. 0. 1. 2.]
"""
dims = dims if dims is not None else 0
_shape = x.shape
if dims is None:
flatten_x = x.reshape(-1)
return Roll(shifts, 0)(flatten_x).reshape(_shape)
return Roll(shifts, dims)(x)