solve the bug of res error when axis or shift is a number
This commit is contained in:
parent
a3d5d94c4b
commit
c53a167834
|
@ -8,11 +8,11 @@ mindspore.ops.roll
|
|||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **shifts** (Union[list(int), tuple(int), int]) - 指定元素移动方式,如果为正数,则元素沿指定维度正向移动(朝向较大的索引)的位置数。负偏移将向相反的方向滚动元素。
|
||||
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值:None。
|
||||
- **dims** (Union[list(int), tuple(int), int], optional) - 指定需移动维度的轴。默认值:None。如果 `dims` 为None,则会将输入Tensor展平后再进行roll计算,然后将计算结果reshape为输入的shape。
|
||||
|
||||
返回:
|
||||
Tensor,shape和数据类型与输入 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `shifts` 不是int、tuple或list。
|
||||
- **TypeError** - `dims` 不是int、tuple或list。
|
||||
- **TypeError** - `dims` 不是int、tuple或list。
|
||||
|
|
|
@ -6133,7 +6133,8 @@ def roll(x, shifts, dims=None):
|
|||
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
|
||||
in the opposite direction.
|
||||
dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled.
|
||||
Default: None.
|
||||
Default: None. If `dims` is None, the Tensor will be flattened before rolling and then restored to the
|
||||
original shape.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and type as `x`.
|
||||
|
@ -6155,7 +6156,10 @@ def roll(x, shifts, dims=None):
|
|||
>>> print(output)
|
||||
[3. 4. 0. 1. 2.]
|
||||
"""
|
||||
dims = dims if dims is not None else 0
|
||||
_shape = x.shape
|
||||
if dims is None:
|
||||
flatten_x = x.reshape(-1)
|
||||
return Roll(shifts, 0)(flatten_x).reshape(_shape)
|
||||
return Roll(shifts, dims)(x)
|
||||
|
||||
|
||||
|
|
|
@ -7797,7 +7797,11 @@ class Roll(Primitive):
|
|||
"""Initialize Roll"""
|
||||
if context.get_context("device_target") == "GPU":
|
||||
validator.check_value_type("shift", shift, [int, tuple, list], self.name)
|
||||
if not isinstance(shift, (list, tuple)):
|
||||
self.add_prim_attr('shift', [shift])
|
||||
validator.check_value_type("axis", axis, [int, tuple, list], self.name)
|
||||
if not isinstance(axis, (list, tuple)):
|
||||
self.add_prim_attr('axis', [axis])
|
||||
else:
|
||||
if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
|
||||
validator.check_equal_int(len(shift), 1, "shift size", self.name)
|
||||
|
|
Loading…
Reference in New Issue