fix grad error of AbstractList to AbstractTuple in split

This commit is contained in:
ZhidanLiu 2023-02-25 09:36:19 +08:00
parent d4bfde846b
commit 280313da21
1 changed files with 1 additions and 4 deletions

View File

@ -5234,7 +5234,7 @@ def _tensor_split_sub_tensors(x, indices_or_sections, axis):
end[axis] = idx
sliced_tensor = strided_slice(x, tuple(begin), tuple(end), strides)
sub_tensors.append(sliced_tensor)
return sub_tensors
return tuple(sub_tensors)
def _tensor_split_sub_int(x, indices_or_sections, axis):
@ -5321,9 +5321,6 @@ def tensor_split(x, indices_or_sections, axis=0):
for item in indices_or_sections:
if type(item) is not int:
raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
if item < 0:
raise TypeError(f"Each element in 'indices_or_sections' should be non-negative,"
f" but got {indices_or_sections}.")
res = _tensor_split_sub_tensors(x, indices_or_sections, axis)
else:
raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), " \