fix bugs of split and tensor_split

This commit is contained in:
ZhidanLiu 2023-02-14 11:47:39 +08:00
parent 0bff42957e
commit fc3c4d0df7
5 changed files with 26 additions and 17 deletions

View File

@ -76,6 +76,7 @@
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "not-callable"
"mindspore/mindspore/python/mindspore/hypercomplex" "useless-return"
"mindspore/mindspore/python" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/function/array_func.py" "unidiomatic-typecheck"
# MindData
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"

View File

@ -18,7 +18,7 @@ mindspore.ops.max_unpool1d
- **indices** (Tensor) - 最大值的索引。shape必须与输入 `x` 相同。取值范围需满足 :math:`[0, H_{in} - 1]`
数据类型必须是int32或int64。
- **kernel_size** (Union[int, tuple[int]]) - 池化核尺寸大小。
- **stride** (Union[int, tuple[int]]) - 池化操作的移动步长,若取值为 'None' `stride` 值与 `kernel_size`
- **stride** (Union[int, tuple[int]]) - 池化操作的移动步长,若取值为 '0' '(0)' 或 'None' `stride` 值与 `kernel_size`
相同。默认值None。
- **padding** (Union[int, tuple[int]]) - 填充值。默认值0。
- **output_size** (tuple[int], 可选) - 输出shape。默认值None。

View File

@ -7,7 +7,7 @@ mindspore.ops.tensor_split
参数:
- **x** (Tensor) - 待分割的Tensor。
- **indices_or_sections** (Union[int, tuple(int), list(int)]) - 如果 `indices_or_sections` 是整数类型n输入将沿 `axis` 轴分割成n份。如果输入沿着 `axis` 轴能被n整除那么每个切片的大小相同为 :math:`input.size(axis) / n` 。如果不能被n整除那么前 :math:`input.size(axis) % n` 个切片的大小为 :math:`input.size(axis) // n + 1` ,其余切片的大小为 :math:`input.size(axis) // n` 。
- **indices_or_sections** (Union[int, tuple(int), list(int)]) - 如果 `indices_or_sections` 是整数类型n输入将沿 `axis` 轴分割成n份。如果输入沿着 `axis` 轴能被n整除那么每个切片的大小相同为 :math:`x.size(axis) / n` 。如果不能被n整除那么前 :math:`x.size(axis) % n` 个切片的大小为 :math:`x.size(axis) // n + 1` ,其余切片的大小为 :math:`x.size(axis) // n` 。
如果 `indices_or_sections` 是由int组成list或者tuple那么输入将沿着 `axis` 轴在tuple或list中的索引处切分。例如:math:`indices\_or\_sections=[2, 3]`:math:`axis=0` 将得到切片 :math:`x[:2]` :math:`x[2:3]` ,和 :math:`x[3:]` .
- **axis** (int) - 指定分割轴。默认值0。

View File

@ -4960,7 +4960,7 @@ def _split_int(x, split_size_or_sections, axis):
arr_shape = x.shape
length_along_dim = arr_shape[axis]
if split_size_or_sections > length_along_dim:
res = P.Split(axis, length_along_dim)(x)
res = P.Split(axis, 1)(x)
elif length_along_dim % split_size_or_sections == 0:
sections = length_along_dim // split_size_or_sections
res = P.Split(axis, sections)(x)
@ -5036,12 +5036,16 @@ def split(x, split_size_or_sections, axis=0):
"""
if not isinstance(x, Tensor):
raise TypeError(f'expect `x` is a Tensor, but got {type(x)}')
if not isinstance(axis, int):
if type(axis) is not int:
raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
axis = _canonicalize_axis(axis, x.ndim)
if isinstance(split_size_or_sections, int):
res = _split_int(x, split_size_or_sections, axis)
if type(split_size_or_sections) is int:
if split_size_or_sections > 0:
res = _split_int(x, split_size_or_sections, axis)
else:
raise ValueError(f"For split, the value of 'split_size_or_sections' must be more than zero, "
f"but got {split_size_or_sections}.")
elif isinstance(split_size_or_sections, (list, tuple)):
for item in split_size_or_sections:
if not isinstance(item, int):
@ -5053,7 +5057,7 @@ def split(x, split_size_or_sections, axis=0):
else:
raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " \
f"but got {type(split_size_or_sections)}")
return res
return tuple(res)
def tril(input_x, diagonal=0): # pylint: disable=redefined-outer-name
@ -5212,9 +5216,10 @@ def _tensor_split_sub_int(x, indices_or_sections, axis):
length_along_dim = arr_shape[axis]
if indices_or_sections > length_along_dim:
res = P.Split(axis, length_along_dim)(x)
indices_or_sections_n = [i for i in np.arange(length_along_dim, indices_or_sections)]
indices_or_sections_n = [length_along_dim, length_along_dim + 1]
res2 = _tensor_split_sub_tensors(x, indices_or_sections_n, axis)
res += tuple(res2)[1:]
for _ in np.arange(length_along_dim, indices_or_sections):
res += tuple(res2)[1:]
elif length_along_dim % indices_or_sections == 0:
res = P.Split(axis, indices_or_sections)(x)
else:
@ -5240,9 +5245,9 @@ def tensor_split(x, indices_or_sections, axis=0):
indices_or_sections (Union[int, tuple(int), list(int)]):
If `indices_or_sections` is an integer n, input is split into
n sections along dimension `axis`. If input is divisible by n along dimension `axis`, each section will be
of equal size, :math:`input.size(axis) / n` . If input is not divisible by n, the sizes of the first
:math:`input.size(axis) % n` sections will have size :math:`input.size(axis) // n + 1` , and the rest will
have size :math:`input.size(axis) // n` .
of equal size, :math:`x.size(axis) / n` . If input is not divisible by n, the sizes of the first
:math:`x.size(axis) % n` sections will have size :math:`x.size(axis) // n + 1` , and the rest will
have size :math:`x.size(axis) // n` .
If `indices_or_sections` is a list or tuple of ints, then input is split
along dimension `axis` at each of the indices in the list, tuple. For instance,
:math:`indices\_or\_sections=[2, 3]` and :math:`axis=0` would result in the tensors :math:`x[:2]` ,
@ -5273,12 +5278,15 @@ def tensor_split(x, indices_or_sections, axis=0):
if not isinstance(x, Tensor):
raise TypeError(f'expect `x` is a Tensor, but got {type(x)}')
if not isinstance(axis, int):
if type(axis) is not int:
raise TypeError(f"Type of Argument `axis` should be integer but got {type(axis)}")
axis = _canonicalize_axis(axis, x.ndim)
if isinstance(indices_or_sections, int):
res = _tensor_split_sub_int(x, indices_or_sections, axis)
if type(indices_or_sections) is int:
if indices_or_sections > 0:
res = _tensor_split_sub_int(x, indices_or_sections, axis)
else:
raise ValueError(f"For tensor_split, the value of 'indices_or_sections' must be more than zero "
f"but got {indices_or_sections}")
elif isinstance(indices_or_sections, (list, tuple)):
for item in indices_or_sections:
if not isinstance(item, int):

View File

@ -703,7 +703,7 @@ def max_unpool1d(x, indices, kernel_size, stride=None, padding=0, output_size=No
Data type must be in int32 or int64.
kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value.
stride (Union[int, tuple[int]]): The distance of kernel moving,
If stride is None, then stride equal to kernel_size. Default: None.
If stride is 0, (0) or None, then stride equal to kernel_size. Default: None.
padding (Union[int, tuple[int]]): The pad value to be filled. Default: 0.
output_size (tuple[int], optional): The output shape. Default: None.
If output_size == (), then the shape of output computed by `kernel_size`, `stride` and `padding`.