!46332 add function and Tensor api: split.

Merge pull request !46332 from ZhidanLiu/split
This commit is contained in:
i-robot 2022-12-06 13:35:45 +00:00 committed by Gitee
commit 07a7dbf69e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 353 additions and 97 deletions

View File

@ -1,6 +1,6 @@
mindspore.Tensor.split
======================
=======================
.. py:method:: mindspore.Tensor.split(axis=0, output_num=1)
.. py:method:: mindspore.Tensor.split(split_size_or_sections, axis=0)
详情请参考 :func:`mindspore.ops.split`

View File

@ -1,21 +1,25 @@
mindspore.ops.split
====================
.. py:function:: mindspore.ops.split(input_x, axis=0, output_num=1)
.. py:function:: mindspore.ops.split(x, split_size_or_sections, axis=0)
根据指定的轴和分割数量对输入Tensor进行分割。
`input_x` Tensor将被分割为相同shape的子Tensor。要求 `input_x.shape(axis)` 可被 `output_num` 整除。
根据指定的轴将输入Tensor切分成块。
参数:
- **input_x** (Tensor) - Tensor的shape为 :math:`(x_1, x_2, ..., x_R)`
- **x** (Tensor) - Tensor的shape为 :math:`(x_1, x_2, ..., x_R)`
- **split_size_or_sections** (Union[int, tuple(int), list(int)]) - 如果 `split_size_or_sections` 是int类型
`x` 将被均匀的切分成块,每块的大小为 `split_size_or_sections` ,若 `x.shape[axis]` 不能被 `split_size_or_sections` 整除,最后一块大小将小于 `split_size_or_sections`
如果 `split_size_or_sections` 是个list类型`x` 将沿 `axis` 轴被切分成 `len(split_size_or_sections)` 块,大小为 `split_size_or_sections`
- **axis** (int) - 指定分割轴。默认值0。
- **output_num** (int) - 指定分割数量。其值为正整数。默认值1。
返回:
tuple[Tensor]每个输出Tensor的shape相同:math:`(y_1, y_2, ..., y_S)` 。数据类型与 `input_x` 的相同
tuple[Tensor]。
异常:
- **TypeError** - `axis``output_num` 不是int。
- **ValueError** - `axis` 超出[-len(`input_x.shape`), len(`input_x.shape`))范围。或 `output_num` 小于或等于0。
- **ValueError** - `input_x.shape(axis)` 不可被 `output_num` 整除。
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `axis` 不是int类型。
- **ValueError** - 参数 `axis` 超出 :math:`-x.dim, x.dim)` 范围。
- **TypeError** - `split_size_or_sections` 中的每个元素不是int类型
- **TypeError** - `split_size_or_sections` 不是inttuple(int)或list(int)。
- **ValueError** - `split_size_or_sections` 的和不等于x.shape[axis]。

View File

@ -8,8 +8,8 @@ mindspore.ops.tensor_split
参数:
- **x** (Tensor) - 待分割的Tensor。
- **indices_or_sections** (Union[int, tuple(int), list(int)]) -
如果`indices_or_sections`是整数类型n输入将沿`axis`轴分割成n份。如果输入沿着`axis`轴能被n整除那么每个切片的大小相同为 :math:`input.size(axis) / n` 。如果不能被n整除那么前 :math:`input.size(axis) % n` 个切片的大小为 :math:`input.size(axis) // n + 1` ,其余切片的大小为 :math:`input.size(axis) // n`
如果`indices_or_sections`是由int组成list或者tuple那么输入将沿着`axis`轴在tuple或list中的索引处切分。例如:math:`indices_or_sections=[2, 3]`:math:`axis=0` 将得到切片 :math:`x[:2]` :math:`x[2:3]` ,和 :math:`x[3:]` .
如果 `indices_or_sections` 是整数类型n输入将沿`axis`轴分割成n份。如果输入沿着`axis`轴能被n整除那么每个切片的大小相同为 :math:`input.size(axis) / n` 。如果不能被n整除那么前 :math:`input.size(axis) % n` 个切片的大小为 :math:`input.size(axis) // n + 1` ,其余切片的大小为 :math:`input.size(axis) // n`
如果 `indices_or_sections` 是由int组成list或者tuple那么输入将沿着`axis`轴在tuple或list中的索引处切分。例如:math:`indices_or_sections=[2, 3]`:math:`axis=0` 将得到切片 :math:`x[:2]` :math:`x[2:3]` ,和 :math:`x[3:]` .
- **axis** (int) - 指定分割轴。默认值0。
返回:
@ -18,7 +18,6 @@ mindspore.ops.tensor_split
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `axis` 不是int类型。
- **TypeError** - `axis` 不是int类型。
- **ValueError** - 参数 `axis` 超出 :math:`[-x.dim, x.dim)` 范围。
- **TypeError** - `indices_or_sections` 中的每个元素不是int类型
- **TypeError** - `indices_or_sections` 不是inttuple(int)或list(int)。

View File

@ -348,7 +348,7 @@ BuiltInTypeMap &GetMethodMap() {
{"to_coo", std::string("to_coo")}, // dense_to_sparse_coo()
{"to_csr", std::string("to_csr")}, // dense_to_sparse_csr()
{"col2im", std::string("col2im")}, // P.Col2Im
{"split", std::string("split")}, // P.Split()
{"split", std::string("split")}, // split
{"tensor_split", std::string("tensor_split")}, // tensor_split
{"vsplit", std::string("vsplit")}, // vsplit
{"hsplit", std::string("hsplit")}, // hsplit

View File

@ -3514,12 +3514,12 @@ def gather(input_x, input_indices, axis):
return F.gather(input_x, input_indices, axis)
def split(input_x, axis=0, output_num=1):
def split(x, split_size_or_sections, axis=0):
"""
Splits the input tensor into output_num of tensors along the given axis and output numbers.
Splits the Tensor into chunks along the given axis.
Refer to :func:`mindspore.ops.split` for more detail.
"""
return F.split(input_x, axis, output_num)
return F.split(x, split_size_or_sections, axis)
def tensor_split(x, indices_or_sections, axis=0):

View File

@ -3581,11 +3581,11 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get("xdivy")()(self, y)
def split(self, axis=0, output_num=1):
def split(self, split_size_or_sections, axis=0):
"""
For details, please refer to :func:`mindspore.ops.split`.
"""
return tensor_operator_registry.get('split')(axis, output_num)(self)
return tensor_operator_registry.get('split')(self, split_size_or_sections, axis)
def tensor_split(self, indices_or_sections, axis=0):
"""
@ -3598,6 +3598,7 @@ class Tensor(Tensor_):
"""
For details, please refer to :func:`mindspore.ops.vsplit`.
"""
self._init_check()
return tensor_operator_registry.get('vsplit')(self, indices_or_sections)

View File

@ -4545,57 +4545,107 @@ def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
return c2i(input_x, output_size)
def split(input_x, axis=0, output_num=1):
r"""
Splits the input tensor into output_num of tensors along the given axis and output numbers.
def _split_int(x, split_size_or_sections, axis):
"""
Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `split_size_or_sections`
with int type.
"""
arr_shape = x.shape
length_along_dim = arr_shape[axis]
if split_size_or_sections > length_along_dim:
res = P.Split(axis, length_along_dim)(x)
elif length_along_dim % split_size_or_sections == 0:
sections = length_along_dim // split_size_or_sections
res = P.Split(axis, sections)(x)
else:
num_sections = length_along_dim // split_size_or_sections
length1 = num_sections * split_size_or_sections
length2 = length_along_dim - length1
start1 = _list_comprehensions(rank(x), 0, True)
size1 = _tuple_setitem(arr_shape, axis, length1)
start2 = _tuple_setitem(start1, axis, length1)
size2 = _tuple_setitem(arr_shape, axis, length2)
res = P.Split(axis, num_sections)(tensor_slice(x, start1, size1)) + \
P.Split(axis, 1)(tensor_slice(x, start2, size2))
return res
The `input_x` tensor will be split into equally sized sub-tensors.
This requires that `input_x.shape(axis)` is divisible by `output_num`.
def _split_sub_tensors(x, split_size_or_sections, axis):
"""
Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `split_size_or_sections`
with type of tuple or list.
"""
new_indices = [0]
for i, split_size in enumerate(split_size_or_sections):
new_indices.append(new_indices[i] + split_size)
new_indices = new_indices[1:]
sub_tensors = []
strides = _list_comprehensions(x.ndim, 1, True)
begin = _list_comprehensions(x.ndim, 0)
end = _list_comprehensions(x.shape)
for i, idx in enumerate(new_indices):
begin[axis] = 0 if i == 0 else new_indices[i - 1]
end[axis] = idx
sliced_tensor = strided_slice(x, tuple(begin), tuple(end), strides)
sub_tensors.append(sliced_tensor)
return sub_tensors
def split(x, split_size_or_sections, axis=0):
"""
Splits the Tensor into chunks along the given axis.
Args:
input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
axis (int): Index of the split position. Default: 0.
output_num (int): The number of output tensors. Must be positive int. Default: 1.
x (Tensor): A Tensor to be divided.
split_size_or_sections (Union[int, tuple(int), list(int)]):
If `split_size_or_sections` is an int type, `x` will be split into equally sized chunks, each chunk with
size `split_size_or_sections`. Last chunk will be smaller than `split_size_or_sections` if `x.shape[axis]`
is not divisible by `split_size_or_sections`.
If `split_size_or_sections` is a list type, then `x` will be split into len(split_size_or_sections)
chunks with sizes `split_size_or_sections` along the given `axis`.
axis (int): The axis along which to split. Default: 0.
Returns:
tuple[Tensor], the shape of each output tensor is the same, which is
:math:`(y_1, y_2, ..., y_S)`. And the data type is the same with `input_x`.
A tuple of sub-tensors.
Raises:
TypeError: If `axis` or `output_num` is not an int.
ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
or if the `output_num` is less than or equal to 0.
ValueError: If `input_x.shape(axis)` is not divisible by `output_num`.
TypeError: If argument `x` is not Tensor.
TypeError: If argument `axis` is not Tensor.
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` .
TypeError: If each element in 'split_size_or_sections' is not integer.
TypeError: If argument `indices_or_sections` is not int, tuple(int) or list(int).
ValueError: The sum of 'split_size_or_sections' is not equal to x.shape[axis].
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
>>> print(x)
[[1 1 1 1]
[2 2 2 2]]
>>> output = ops.split(x, 1, 2)
>>> input_x = np.arange(9).astype("float32")
>>> output = ops.split(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
[2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
[[1, 1],
[2, 2]]))
>>> output = ops.split(x, 1, 4)
>>> print(output)
(Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
[2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
[2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
[2]]), Tensor(shape=[2, 1], dtype=Int32, value=
[[1],
[2]]))
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
"""
split_ = _get_cache_prim(P.Split)(axis, output_num)
return split_(input_x)
if not isinstance(x, Tensor):
raise TypeError(f'expect `x` is a Tensor, but got {type(x)}')
_ = validator.check_axis_type(axis, True, False, False)
axis = _canonicalize_axis(axis, x.ndim)
if isinstance(split_size_or_sections, int):
res = _split_int(x, split_size_or_sections, axis)
elif isinstance(split_size_or_sections, (list, tuple)):
for item in split_size_or_sections:
if not isinstance(item, int):
raise TypeError(f"Each element in 'split_size_or_sections' should be integer, but got {type(item)}.")
if sum(split_size_or_sections) != x.shape[axis]:
raise ValueError(f"The sum of 'split_size_or_sections' should be equal to {x.shape[axis]}, "
f"but got {sum(split_size_or_sections)}.")
res = _split_sub_tensors(x, split_size_or_sections, axis)
else:
raise TypeError(f"Type of Argument `split_size_or_sections` should be integer, tuple(int) or list(int), " \
f"but got {type(split_size_or_sections)}")
return res
@constexpr
@ -4606,6 +4656,7 @@ def _canonicalize_axis(axis, ndim):
Args:
axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
ndim (int): The number of dimensions of the tensor.
Return:
Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
"""
@ -4672,6 +4723,7 @@ def _tensor_split_sub_tensors(x, indices_or_sections, axis):
length_along_dim = x.shape[axis]
indices_or_sections = tuple(indices_or_sections)
indices_or_sections += (length_along_dim,)
sub_tensors = []
strides = _list_comprehensions(x.ndim, 1, True)
begin = _list_comprehensions(x.ndim, 0)
@ -4735,7 +4787,7 @@ def tensor_split(x, indices_or_sections, axis=0):
Raises:
TypeError: If argument `x` is not Tensor.
TypeError: If argument `axis` is not Tensor.
TypeError: If argument `axis` is not int.
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` .
TypeError: If each element in 'indices_or_sections' is not integer.
TypeError: If argument `indices_or_sections` is not int, tuple(int) or list(int).
@ -4747,12 +4799,9 @@ def tensor_split(x, indices_or_sections, axis=0):
>>> input_x = np.arange(9).astype("float32")
>>> output = ops.tensor_split(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32,
value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
Tensor(shape=[3], dtype=Float32,
value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
Tensor(shape=[3], dtype=Float32,
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
"""
if not isinstance(x, Tensor):
raise TypeError(f'expect `x` is a Tensor, but got {type(x)}')
@ -4793,12 +4842,9 @@ def vsplit(x, indices_or_sections):
>>> input_x = np.arange(9).reshape((3, 3)).astype('float32')
>>> output = ops.vsplit(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[1, 3], dtype=Float32,
value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]),
Tensor(shape=[1, 3], dtype=Float32,
value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]),
Tensor(shape=[1, 3], dtype=Float32,
value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]]))
(Tensor(shape=[1, 3], dtype=Float32, value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]),
Tensor(shape=[1, 3], dtype=Float32, value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]),
Tensor(shape=[1, 3], dtype=Float32, value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]]))
"""
return tensor_split(x, indices_or_sections, 0)
@ -4822,15 +4868,9 @@ def hsplit(x, indices_or_sections):
>>> input_x = np.arange(6).reshape((2, 3)).astype('float32')
>>> output = ops.hsplit(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[2, 1], dtype=Float32,
value=[[ 0.00000000e+00],
[ 3.00000000e+00]]),
Tensor(shape=[2, 1], dtype=Float32,
value=[[ 1.00000000e+00],
[ 4.00000000e+00]]),
Tensor(shape=[2, 1], dtype=Float32,
value=[[ 2.00000000e+00],
[ 5.00000000e+00]]))
(Tensor(shape=[2, 1], dtype=Float32, value=[[ 0.00000000e+00], [ 3.00000000e+00]]),
Tensor(shape=[2, 1], dtype=Float32, value=[[ 1.00000000e+00], [ 4.00000000e+00]]),
Tensor(shape=[2, 1], dtype=Float32, value=[[ 2.00000000e+00], [ 5.00000000e+00]]))
"""
return tensor_split(x, indices_or_sections, 1)
@ -4854,15 +4894,9 @@ def dsplit(x, indices_or_sections):
>>> input_x = np.arange(6).reshape((1, 2, 3)).astype('float32')
>>> output = ops.dsplit(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 0.00000000e+00],
[ 3.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 1.00000000e+00],
[ 4.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32,
value=[[[ 2.00000000e+00],
[ 5.00000000e+00]]]))
(Tensor(shape=[1, 2, 1], dtype=Float32, value=[[[ 0.00000000e+00], [ 3.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32, value=[[[ 1.00000000e+00], [ 4.00000000e+00]]]),
Tensor(shape=[1, 2, 1], dtype=Float32, value=[[[ 2.00000000e+00], [ 5.00000000e+00]]]))
"""
return tensor_split(x, indices_or_sections, 2)

View File

@ -200,7 +200,7 @@ tensor_operator_registry.register('fill', P.Fill)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logit', logit)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('split', split)
tensor_operator_registry.register('tensor_split', tensor_split)
tensor_operator_registry.register('vsplit', vsplit)
tensor_operator_registry.register('hsplit', hsplit)

View File

@ -21,7 +21,7 @@ import mindspore.nn as nn
import mindspore.ops as ops
class SplitNet(nn.Cell):
class TensorSplitNet(nn.Cell):
def construct(self, x, indices_or_sections, axis=0):
out = ops.tensor_split(x, indices_or_sections, axis)
return out
@ -42,7 +42,7 @@ def test_f_tensor_split_int(mode):
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
net = TensorSplitNet()
a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
indices_or_sections = 3
@ -67,7 +67,7 @@ def test_f_tensor_split_list(mode):
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
net = TensorSplitNet()
a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
indices_or_sections = [2, 4]
@ -77,6 +77,106 @@ def test_f_tensor_split_list(mode):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list2(mode):
"""
Feature: tensor_split
Description: Verify the result of tensor_split when `indices_or_sections` is out of normal length.
Expectation: success
"""
ms.set_context(mode=mode)
a = np.arange(10).reshape((5, 2))
indices_or_sections = [1, 4, 7]
net = TensorSplitNet()
x = ms.Tensor(a, dtype=ms.int64)
out = net(x, indices_or_sections)
expect = np.array_split(a, indices_or_sections)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list3(mode):
"""
Feature: tensor_split
Description: Verify the result of tensor_split when `indices_or_sections` has negative.
Expectation: success
"""
ms.set_context(mode=mode)
a = np.arange(10).reshape((5, 2))
indices_or_sections = [-5, 4, 3, 7]
net = TensorSplitNet()
x = ms.Tensor(a, dtype=ms.int64)
out = net(x, indices_or_sections)
expect = np.array_split(a, indices_or_sections)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list4(mode):
"""
Feature: tensor_split
Description: Verify the result of tensor_split when `indices_or_sections` has negative number and out of range.
Expectation: success
"""
ms.set_context(mode=mode)
a = np.arange(12)
indices_or_sections = [-18, -14, -10]
net = TensorSplitNet()
x = ms.Tensor(a, dtype=ms.int64)
out = net(x, indices_or_sections)
expect = np.array_split(a, indices_or_sections)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list5(mode):
"""
Feature: tensor_split
Description: Verify the result of tensor_split when `indices_or_sections` has special order.
Expectation: success
"""
ms.set_context(mode=mode)
a = np.arange(12)
indices_or_sections = [-18, -10, -14, 2]
net = TensorSplitNet()
x = ms.Tensor(a, dtype=ms.int64)
out = net(x, indices_or_sections)
expect = np.array_split(a, indices_or_sections)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
class VSplitNet(nn.Cell):
def construct(self, x, indices_or_sections):
out = ops.vsplit(x, indices_or_sections)
@ -155,7 +255,7 @@ def test_f_hsplit_int(mode):
"""
ms.set_context(mode=mode)
net = HSplitNet()
a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
indices_or_sections = 3
out = net(x, indices_or_sections)
@ -243,3 +343,62 @@ def test_f_dsplit_list(mode):
expect = np.array_split(a, indices_or_sections, axis=2)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
class SplitNet(nn.Cell):
def construct(self, x, split_size_or_sections, axis=0):
out = ops.split(x, split_size_or_sections, axis)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_int(mode):
"""
Feature: split
Description: Verify the result of split.
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
split_size_or_sections = 5
out = net(x, split_size_or_sections)
expect = [np.array(np.arange(10).reshape((5, 2)), dtype=np.float32),
np.array(np.arange(10, 20).reshape((5, 2)), dtype=np.float32)]
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_list(mode):
"""
Feature: split
Description: Verify the result of split.
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
split_size_or_sections = [2, 3, 5]
out = net(x, split_size_or_sections, axis=1)
expect = [np.array([[0, 1], [10, 11]], dtype=np.float32),
np.array([[2, 3, 4], [12, 13, 14]], dtype=np.float32),
np.array([[5, 6, 7, 8, 9], [15, 16, 17, 18, 19]], dtype=np.float32)]
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)

View File

@ -20,7 +20,7 @@ import mindspore as ms
import mindspore.nn as nn
class SplitNet(nn.Cell):
class TensorSplitNet(nn.Cell):
def construct(self, x, indices_or_sections, axis=0):
out = x.tensor_split(indices_or_sections, axis)
return out
@ -41,7 +41,7 @@ def test_f_tensor_split_int(mode):
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
net = TensorSplitNet()
a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
indices_or_sections = 3
@ -66,7 +66,7 @@ def test_f_tensor_split_list(mode):
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
net = TensorSplitNet()
a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
indices_or_sections = [2, 4]
@ -242,3 +242,62 @@ def test_f_dsplit_list(mode):
expect = np.array_split(a, indices_or_sections, axis=2)
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
class SplitNet(nn.Cell):
def construct(self, x, split_size_or_sections, axis=0):
out = x.split(split_size_or_sections, axis)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_int(mode):
"""
Feature: split
Description: Verify the result of split.
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
split_size_or_sections = 5
out = net(x, split_size_or_sections)
expect = [np.array(np.arange(10).reshape((5, 2)), dtype=np.float32),
np.array(np.arange(10, 20).reshape((5, 2)), dtype=np.float32)]
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_list(mode):
"""
Feature: split
Description: Verify the result of split.
Expectation: success
"""
ms.set_context(mode=mode)
net = SplitNet()
a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
x = ms.Tensor(a, dtype=ms.float32)
split_size_or_sections = [2, 3, 5]
out = net(x, split_size_or_sections, axis=1)
expect = [np.array([[0, 1], [10, 11]], dtype=np.float32),
np.array([[2, 3, 4], [12, 13, 14]], dtype=np.float32),
np.array([[5, 6, 7, 8, 9], [15, 16, 17, 18, 19]], dtype=np.float32)]
for res, exp in zip(out, expect):
assert np.allclose(res.asnumpy(), exp)