fix grad error of AbstractList to AbstractTuple in split
This commit is contained in:
parent
d4bfde846b
commit
280313da21
|
@ -5234,7 +5234,7 @@ def _tensor_split_sub_tensors(x, indices_or_sections, axis):
|
|||
end[axis] = idx
|
||||
sliced_tensor = strided_slice(x, tuple(begin), tuple(end), strides)
|
||||
sub_tensors.append(sliced_tensor)
|
||||
return sub_tensors
|
||||
return tuple(sub_tensors)
|
||||
|
||||
|
||||
def _tensor_split_sub_int(x, indices_or_sections, axis):
|
||||
|
@ -5321,9 +5321,6 @@ def tensor_split(x, indices_or_sections, axis=0):
|
|||
for item in indices_or_sections:
|
||||
if type(item) is not int:
|
||||
raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
|
||||
if item < 0:
|
||||
raise TypeError(f"Each element in 'indices_or_sections' should be non-negative,"
|
||||
f" but got {indices_or_sections}.")
|
||||
res = _tensor_split_sub_tensors(x, indices_or_sections, axis)
|
||||
else:
|
||||
raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), " \
|
||||
|
|
Loading…
Reference in New Issue