solve the bug of res error when axis or shift is a number

This commit is contained in:
shenwei41 2023-01-30 16:19:31 +08:00
parent a3d5d94c4b
commit c53a167834
3 changed files with 12 additions and 4 deletions

View File

@ -8,11 +8,11 @@ mindspore.ops.roll
参数:
- **x** (Tensor) - 输入Tensor。
- **shifts** (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为正数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值None。如果 `dims` 为None则会将输入Tensor展平后再进行roll计算然后将计算结果reshape为输入的shape。
返回:
Tensorshape和数据类型与输入 `x` 相同。
异常:
- **TypeError** - `shifts` 不是int、tuple或list。
- **TypeError** - `dims` 不是int、tuple或list。
- **TypeError** - `dims` 不是int、tuple或list。

View File

@ -6133,7 +6133,8 @@ def roll(x, shifts, dims=None):
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
in the opposite direction.
dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled.
Default: None.
Default: None. If `dims` is None, the Tensor will be flattened before rolling and then restored to the
original shape.
Returns:
Tensor, has the same shape and type as `x`.
@ -6155,7 +6156,10 @@ def roll(x, shifts, dims=None):
>>> print(output)
[3. 4. 0. 1. 2.]
"""
dims = dims if dims is not None else 0
_shape = x.shape
if dims is None:
flatten_x = x.reshape(-1)
return Roll(shifts, 0)(flatten_x).reshape(_shape)
return Roll(shifts, dims)(x)

View File

@ -7797,7 +7797,11 @@ class Roll(Primitive):
"""Initialize Roll"""
if context.get_context("device_target") == "GPU":
validator.check_value_type("shift", shift, [int, tuple, list], self.name)
if not isinstance(shift, (list, tuple)):
self.add_prim_attr('shift', [shift])
validator.check_value_type("axis", axis, [int, tuple, list], self.name)
if not isinstance(axis, (list, tuple)):
self.add_prim_attr('axis', [axis])
else:
if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
validator.check_equal_int(len(shift), 1, "shift size", self.name)