forked from mindspore-Ecosystem/mindspore
fix roll function dims none
This commit is contained in:
parent
2ff013c3cc
commit
8b30f7e094
|
@ -8,7 +8,7 @@ mindspore.ops.roll
|
|||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **shifts** (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为正数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。
|
||||
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值:None。
|
||||
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值:None。如果dims为None,则会将输入Tensor展平后再进行roll计算,然后将计算结果reshape为输入的shape。
|
||||
|
||||
返回:
|
||||
Tensor,shape和数据类型与输入 `x` 相同。
|
||||
|
|
|
@ -8109,7 +8109,8 @@ def roll(x, shifts, dims=None):
|
|||
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
|
||||
in the opposite direction.
|
||||
dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled.
|
||||
Default: None.
|
||||
Default: None. If dims is None, the Tensor will be flattened before rolling and then restored to the
|
||||
original shape.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and type as `x`.
|
||||
|
@ -8131,7 +8132,10 @@ def roll(x, shifts, dims=None):
|
|||
>>> print(output)
|
||||
[3. 4. 0. 1. 2.]
|
||||
"""
|
||||
dims = dims if dims is not None else 0
|
||||
_shape = x.shape
|
||||
if dims is None:
|
||||
flatten_x = x.reshape(-1)
|
||||
return Roll(shifts, 0)(flatten_x).reshape(_shape)
|
||||
return Roll(shifts, dims)(x)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue