From 6875db69f077d54725a806973b14d5ef14de0988 Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Tue, 22 Nov 2022 14:40:11 +0800 Subject: [PATCH] add split series api --- docs/api/api_python/mindspore.ops.rst | 4 + .../Tensor/mindspore.Tensor.dsplit.rst | 6 + .../Tensor/mindspore.Tensor.hsplit.rst | 6 + .../Tensor/mindspore.Tensor.tensor_split.rst | 6 + .../Tensor/mindspore.Tensor.vsplit.rst | 6 + .../api_python/mindspore/mindspore.Tensor.rst | 4 + .../ops/mindspore.ops.func_dsplit.rst | 13 + .../ops/mindspore.ops.func_hsplit.rst | 13 + .../ops/mindspore.ops.func_tensor_split.rst | 26 ++ .../ops/mindspore.ops.func_vsplit.rst | 13 + docs/api/api_python_en/Tensor_list.rst | 4 + docs/api/api_python_en/mindspore.ops.rst | 4 + mindspore/ccsrc/pipeline/jit/resource.cc | 4 + .../_extends/parse/standard_method.py | 33 +++ mindspore/python/mindspore/common/tensor.py | 28 ++ .../python/mindspore/ops/function/__init__.py | 4 + .../mindspore/ops/function/array_func.py | 273 ++++++++++++++++++ mindspore/python/mindspore/ops/functional.py | 5 +- tests/st/ops/test_split.py | 245 ++++++++++++++++ tests/st/tensor/test_split.py | 244 ++++++++++++++++ 20 files changed, 940 insertions(+), 1 deletion(-) create mode 100644 docs/api/api_python/mindspore/Tensor/mindspore.Tensor.dsplit.rst create mode 100644 docs/api/api_python/mindspore/Tensor/mindspore.Tensor.hsplit.rst create mode 100644 docs/api/api_python/mindspore/Tensor/mindspore.Tensor.tensor_split.rst create mode 100644 docs/api/api_python/mindspore/Tensor/mindspore.Tensor.vsplit.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_dsplit.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_hsplit.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_tensor_split.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_vsplit.rst create mode 100644 tests/st/ops/test_split.py create mode 100644 tests/st/tensor/test_split.py diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index 3c7b8827c3e..1276c18b566 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -372,6 +372,7 @@ Array操作 mindspore.ops.diag mindspore.ops.diagonal mindspore.ops.dyn_shape + mindspore.ops.dsplit mindspore.ops.expand mindspore.ops.expand_dims mindspore.ops.flip @@ -382,6 +383,7 @@ Array操作 mindspore.ops.gather_d mindspore.ops.gather_elements mindspore.ops.gather_nd + mindspore.ops.hsplit mindspore.ops.index_add mindspore.ops.index_fill mindspore.ops.inplace_add @@ -428,6 +430,7 @@ Array操作 mindspore.ops.tensor_scatter_mul mindspore.ops.tensor_scatter_sub mindspore.ops.tensor_scatter_elements + mindspore.ops.tensor_split mindspore.ops.tile mindspore.ops.top_k mindspore.ops.transpose @@ -442,6 +445,7 @@ Array操作 mindspore.ops.unsorted_segment_sum mindspore.ops.unsqueeze mindspore.ops.unstack + mindspore.ops.vsplit mindspore.ops.where 类型转换 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.dsplit.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.dsplit.rst new file mode 100644 index 00000000000..5b1b6211365 --- /dev/null +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.dsplit.rst @@ -0,0 +1,6 @@ +mindspore.Tensor.dsplit +======================== + +.. py:method:: mindspore.Tensor.dsplit(indices_or_sections) + + 详情请参考 :func:`mindspore.ops.dsplit`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.hsplit.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.hsplit.rst new file mode 100644 index 00000000000..29410d69b89 --- /dev/null +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.hsplit.rst @@ -0,0 +1,6 @@ +mindspore.Tensor.hsplit +======================== + +.. py:method:: mindspore.Tensor.hsplit(indices_or_sections) + + 详情请参考 :func:`mindspore.ops.hsplit`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.tensor_split.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.tensor_split.rst new file mode 100644 index 00000000000..d6c2626d2f9 --- /dev/null +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.tensor_split.rst @@ -0,0 +1,6 @@ +mindspore.Tensor.tensor_split +============================== + +.. py:method:: mindspore.Tensor.tensor_split(indices_or_sections, axis=0) + + 详情请参考 :func:`mindspore.ops.tensor_split`。 diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.vsplit.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.vsplit.rst new file mode 100644 index 00000000000..6f671f05a3e --- /dev/null +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.vsplit.rst @@ -0,0 +1,6 @@ +mindspore.Tensor.vsplit +======================== + +.. py:method:: mindspore.Tensor.vsplit(indices_or_sections) + + 详情请参考 :func:`mindspore.ops.vsplit`。 diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst index 45f9b408831..a0d3af3371f 100644 --- a/docs/api/api_python/mindspore/mindspore.Tensor.rst +++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst @@ -85,6 +85,7 @@ mindspore.Tensor mindspore.Tensor.diagonal mindspore.Tensor.div mindspore.Tensor.divide + mindspore.Tensor.dsplit mindspore.Tensor.dtype mindspore.Tensor.equal mindspore.Tensor.erf @@ -118,6 +119,7 @@ mindspore.Tensor mindspore.Tensor.hardshrink mindspore.Tensor.has_init mindspore.Tensor.heaviside + mindspore.Tensor.hsplit mindspore.Tensor.hypot mindspore.Tensor.i0 mindspore.Tensor.igamma @@ -230,6 +232,7 @@ mindspore.Tensor mindspore.Tensor.take mindspore.Tensor.tan mindspore.Tensor.tanh + mindspore.Tensor.tensor_split mindspore.Tensor.tile mindspore.Tensor.to mindspore.Tensor.to_coo @@ -250,6 +253,7 @@ mindspore.Tensor mindspore.Tensor.unsqueeze mindspore.Tensor.var mindspore.Tensor.view + mindspore.Tensor.vsplit mindspore.Tensor.where mindspore.Tensor.xdivy mindspore.Tensor.xlogy diff --git a/docs/api/api_python/ops/mindspore.ops.func_dsplit.rst b/docs/api/api_python/ops/mindspore.ops.func_dsplit.rst new file mode 100644 index 00000000000..4a22eb68367 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_dsplit.rst @@ -0,0 +1,13 @@ +mindspore.ops.dsplit +===================== + +.. py:function:: mindspore.ops.dsplit(x, indices_or_sections) + + 沿着第三轴将输入Tensor分割成多个子Tensor。等同于 `axis=2` 时的 `ops.tensor_split` 。 + + 参数: + - **x** (Tensor) - 待分割的Tensor。 + - **indices_or_sections** (Union[int, tuple(int), list(int)]) - 参考 :func:`mindspore.ops.tensor_split`. + + 返回: + tuple[Tensor]。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_hsplit.rst b/docs/api/api_python/ops/mindspore.ops.func_hsplit.rst new file mode 100644 index 00000000000..5849251969a --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_hsplit.rst @@ -0,0 +1,13 @@ +mindspore.ops.hsplit +===================== + +.. py:function:: mindspore.ops.hsplit(x, indices_or_sections) + + 水平地将输入Tensor分割成多个子Tensor。等同于 `axis=1` 时的 `ops.tensor_split` 。 + + 参数: + - **x** (Tensor) - 待分割的Tensor。 + - **indices_or_sections** (Union[int, tuple(int), list(int)]) - 参考 :func:`mindspore.ops.tensor_split`. + + 返回: + tuple[Tensor]。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_split.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_split.rst new file mode 100644 index 00000000000..2b4a957048e --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_split.rst @@ -0,0 +1,26 @@ +mindspore.ops.tensor_split +=========================== + +.. py:function:: mindspore.ops.tensor_split(x, indices_or_sections, axis=0) + + 根据指定的轴将输入Tensor进行分割成多个子tensor。 + + 参数: + - **x** (Tensor) - 待分割的Tensor。 + - **indices_or_sections** (Union[int, tuple(int), list(int)]) - + 如果`indices_or_sections`是整数类型n,输入将沿`axis`轴分割成n份。如果输入沿着`axis`轴能被n整除,那么每个切片的大小相同为 :math:`input.size(axis) / n` 。如果不能被n整除,那么前 :math:`input.size(axis) % n` 个切片的大小为 :math:`input.size(axis) // n + 1` ,其余切片的大小为 :math:`input.size(axis) // n` 。 + 如果`indices_or_sections`是由int组成list或者tuple,那么输入将沿着`axis`轴在tuple或list中的索引处切分。例如::math:`indices_or_sections=[2, 3]` 和 :math:`axis=0` 将得到切片 :math:`x[:2]` , :math:`x[2:3]` ,和 :math:`x[3:]` . + - **axis** (int) - 指定分割轴。默认值:0。 + + 返回: + tuple[Tensor]。 + + 异常: + - **TypeError** - `x` 不是Tensor。 + - **TypeError** - `axis` 不是int类型。 + - **TypeError** - `axis` 不是int类型。 + - **ValueError** - 参数 `axis` 超出 :math:`-x.dim, x.dim)` 范围。 + - **TypeError** - `indices_or_sections` 中的每个元素不是int类型 + - **TypeError** - `indices_or_sections` 不是int,tuple(int)或list(int)。 + + diff --git a/docs/api/api_python/ops/mindspore.ops.func_vsplit.rst b/docs/api/api_python/ops/mindspore.ops.func_vsplit.rst new file mode 100644 index 00000000000..63868ce40b7 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_vsplit.rst @@ -0,0 +1,13 @@ +mindspore.ops.vsplit +===================== + +.. py:function:: mindspore.ops.vsplit(x, indices_or_sections) + + 垂直地将输入Tensor分割成多个子Tensor。等同于 `axis=0` 时的 `ops.tensor_split` 。 + + 参数: + - **x** (Tensor) - 待分割的Tensor。 + - **indices_or_sections** (Union[int, tuple(int), list(int)]) - 参考 :func:`mindspore.ops.tensor_split`. + + 返回: + tuple[Tensor]。 diff --git a/docs/api/api_python_en/Tensor_list.rst b/docs/api/api_python_en/Tensor_list.rst index 80e0272af45..2161f091d5a 100644 --- a/docs/api/api_python_en/Tensor_list.rst +++ b/docs/api/api_python_en/Tensor_list.rst @@ -91,6 +91,7 @@ mindspore.Tensor.diagonal mindspore.Tensor.div mindspore.Tensor.divide + mindspore.Tensor.dsplit mindspore.Tensor.dtype mindspore.Tensor.equal mindspore.Tensor.erf @@ -124,6 +125,7 @@ mindspore.Tensor.hardshrink mindspore.Tensor.has_init mindspore.Tensor.heaviside + mindspore.Tensor.hsplit mindspore.Tensor.hypot mindspore.Tensor.i0 mindspore.Tensor.igamma @@ -236,6 +238,7 @@ mindspore.Tensor.take mindspore.Tensor.tan mindspore.Tensor.tanh + mindspore.Tensor.tensor_split mindspore.Tensor.tile mindspore.Tensor.to mindspore.Tensor.to_coo @@ -259,6 +262,7 @@ mindspore.Tensor.where mindspore.Tensor.xdivy mindspore.Tensor.xlogy + mindspore.Tensor.vsplit {% elif fullname=="mindspore.nn.Cell" %} {{ fullname | underline }} diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index 4c8e81d42a6..2549b1808e3 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -371,6 +371,7 @@ Array Operation mindspore.ops.count_nonzero mindspore.ops.diag mindspore.ops.diagonal + mindspore.ops.dsplit mindspore.ops.dyn_shape mindspore.ops.expand mindspore.ops.expand_dims @@ -382,6 +383,7 @@ Array Operation mindspore.ops.gather_d mindspore.ops.gather_elements mindspore.ops.gather_nd + mindspore.ops.hsplit mindspore.ops.index_add mindspore.ops.index_fill mindspore.ops.inplace_add @@ -428,6 +430,7 @@ Array Operation mindspore.ops.tensor_scatter_mul mindspore.ops.tensor_scatter_sub mindspore.ops.tensor_scatter_elements + mindspore.ops.tensor_split mindspore.ops.tile mindspore.ops.top_k mindspore.ops.transpose @@ -442,6 +445,7 @@ Array Operation mindspore.ops.unsorted_segment_sum mindspore.ops.unsqueeze mindspore.ops.unstack + mindspore.ops.vsplit mindspore.ops.where Type Conversion diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 9b56e3141aa..3ff74adbbdc 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -332,6 +332,10 @@ BuiltInTypeMap &GetMethodMap() { {"to_csr", std::string("to_csr")}, // dense_to_sparse_csr() {"col2im", std::string("col2im")}, // P.Col2Im {"split", std::string("split")}, // P.Split() + {"tensor_split", std::string("tensor_split")}, // tensor_split + {"vsplit", std::string("vsplit")}, // vsplit + {"hsplit", std::string("hsplit")}, // hsplit + {"dsplit", std::string("dsplit")}, // dplit {"random_categorical", std::string("random_categorical")}, // P.RandomCategorical {"xlogy", std::string("xlogy")}, // P.Xlogy() {"erf", std::string("erf")}, // P.Erf() diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 1f79b96f10c..c6a6e0b81e2 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -3462,6 +3462,39 @@ def split(input_x, axis=0, output_num=1): return F.split(input_x, axis, output_num) +def tensor_split(x, indices_or_sections, axis=0): + """ + Splits a tensor into multiple sub-tensors along the given axis. + Refer to :func:`mindspore.ops.tensor_split` for more detail. + """ + return F.tensor_split(x, indices_or_sections, axis=axis) + + +def vsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors vertically. It is equivalent to `ops.tensor_split` with :math:`axis=0` . + Refer to :func:`mindspore.ops.vsplit` for more detail. + """ + return F.vsplit(x, indices_or_sections) + + +def hsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors horizontally. It is equivalent to `ops.tensor_split` with :math:`axis=1` . + Refer to :func:`mindspore.ops.hsplit` for more detail. + """ + return F.hsplit(x, indices_or_sections) + + +def dsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors along the 3rd axis. + It is equivalent to `ops.tensor_split` with :math:`axis=2` . + Refer to :func:`mindspore.ops.tensor_split` for more detail. + """ + return F.dsplit(x, indices_or_sections) + + def xlogy(x, y): r""" Computes the first input tensor multiplied by the logarithm of second input tensor element-wise. diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index c8e0acabb77..e7007ea0ccf 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -3541,6 +3541,34 @@ class Tensor(Tensor_): """ return tensor_operator_registry.get('split')(axis, output_num)(self) + def tensor_split(self, indices_or_sections, axis=0): + """ + For details, please refer to :func:`mindspore.ops.tensor_split`. + """ + self._init_check() + return tensor_operator_registry.get('tensor_split')(self, indices_or_sections, axis) + + def vsplit(self, indices_or_sections): + """ + For details, please refer to :func:`mindspore.ops.vsplit`. + """ + self._init_check() + return tensor_operator_registry.get('vsplit')(self, indices_or_sections) + + def hsplit(self, indices_or_sections): + """ + For details, please refer to :func:`mindspore.ops.hsplit`. + """ + self._init_check() + return tensor_operator_registry.get('hsplit')(self, indices_or_sections) + + def dsplit(self, indices_or_sections): + """ + For details, please refer to :func:`mindspore.ops.dsplit`. + """ + self._init_check() + return tensor_operator_registry.get('dsplit')(self, indices_or_sections) + def xlogy(self, y): r""" For details, please refer to :func:`mindspore.ops.xlogy`. diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index d054f308b48..9ef3b3d88d0 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -113,6 +113,10 @@ from .array_func import ( unsorted_segment_sum, col2im, split, + tensor_split, + vsplit, + hsplit, + dsplit, index_fill, max, argmax, diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 64a512011a8..2477512ec12 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -4476,6 +4476,275 @@ def split(input_x, axis=0, output_num=1): return split_(input_x) +@constexpr +def _canonicalize_axis(axis, ndim): + """ + Check axes are within the number of dimensions of tensor x and normalize the negative axes. + + Args: + axis (Union[int, tuple(int), list(int)]): Axes of the tensor. + ndim (int): The number of dimensions of the tensor. + Return: + Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple. + """ + if isinstance(axis, int): + axis = [axis] + for ax in axis: + if not isinstance(ax, int): + raise TypeError(f'axis should be integers, not {type(ax)}') + if not -ndim <= ax < ndim: + raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}') + + def canonicalizer(ax): + return ax + ndim if ax < 0 else ax + + axis = tuple([canonicalizer(ax) for ax in axis]) + if all(axis.count(el) <= 1 for el in axis): + return tuple(sorted(axis)) if len(axis) > 1 else axis[0] + raise ValueError(f"duplicate axis in {axis}.") + + +@constexpr +def _list_comprehensions(obj, item=None, return_tuple=False): + """ + Generates a new list or tuple by list comprehension. + + Args: + obj (Union[int, list, tuple]): + If integer, it will be the length of the returned tuple/list. + item: The value to be filled. Default: None. + If None, the values in the new list/tuple are the same as obj + or range(obj) when obj is integer. + return_tuple(bool): If true, returns tuple, else returns list. + + Returns: + List or tuple. + """ + lst = obj + if isinstance(obj, int): + lst = np.arange(obj) + if item is None: + res = list(lst) + else: + res = [item for _ in lst] + if return_tuple: + return tuple(res) + return res + + +@constexpr +def _tuple_setitem(tup, idx, value): + """ + Returns a tuple with specified `idx` set to `value`. + """ + tup = list(tup) + tup[idx] = value + return tuple(tup) + + +def _tensor_split_sub_tensors(x, indices_or_sections, axis): + """ + Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `indices_or_sections` + with type of tuple or list. + """ + length_along_dim = x.shape[axis] + indices_or_sections = tuple(indices_or_sections) + indices_or_sections += (length_along_dim,) + sub_tensors = [] + strides = _list_comprehensions(x.ndim, 1, True) + begin = _list_comprehensions(x.ndim, 0) + end = _list_comprehensions(x.shape) + for i, idx in enumerate(indices_or_sections): + begin[axis] = 0 if i == 0 else indices_or_sections[i - 1] + end[axis] = idx + sliced_tensor = strided_slice(x, tuple(begin), tuple(end), strides) + sub_tensors.append(sliced_tensor) + return sub_tensors + + +def _tensor_split_sub_int(x, indices_or_sections, axis): + """ + Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `indices_or_sections` + with type if int. + """ + arr_shape = x.shape + length_along_dim = arr_shape[axis] + if indices_or_sections > length_along_dim: + res = P.Split(axis, length_along_dim)(x) + indices_or_sections_n = [i for i in np.arange(length_along_dim, indices_or_sections)] + res2 = _tensor_split_sub_tensors(x, indices_or_sections_n, axis) + res += tuple(res2)[1:] + elif length_along_dim % indices_or_sections == 0: + res = P.Split(axis, indices_or_sections)(x) + else: + num_long_tensor = length_along_dim % indices_or_sections + num_short_tensor = indices_or_sections - num_long_tensor + length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1) + length2 = length_along_dim - length1 + start1 = _list_comprehensions(rank(x), 0, True) + size1 = _tuple_setitem(arr_shape, axis, length1) + start2 = _tuple_setitem(start1, axis, length1) + size2 = _tuple_setitem(arr_shape, axis, length2) + res = P.Split(axis, num_long_tensor)(tensor_slice(x, start1, size1)) + \ + P.Split(axis, num_short_tensor)(tensor_slice(x, start2, size2)) + return res + + +def tensor_split(x, indices_or_sections, axis=0): + """ + Splits a tensor into multiple sub-tensors along the given axis. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If `indices_or_sections` is an integer n, input is split into + n sections along dimension `axis`. If input is divisible by n along dimension `axis`, each section will be + of equal size, :math:`input.size(axis) / n` . If input is not divisible by n, the sizes of the first + :math:`input.size(axis) % n` sections will have size :math:`input.size(axis) // n + 1` , and the rest will + have size :math:`input.size(axis) // n` . + If `indices_or_sections` is a list or tuple of ints, then input is split + along dimension `axis` at each of the indices in the list, tuple. For instance, + :math:`indices_or_sections=[2, 3]` and :math:`axis=0` would result in the tensors :math:`x[:2]` , + :math:`x[2:3]` , and :math:`x[3:]` . + axis (int): The axis along which to split. Default: 0. + + Returns: + A tuple of sub-tensors. + + Raises: + TypeError: If argument `x` is not Tensor. + TypeError: If argument `axis` is not Tensor. + ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` . + TypeError: If each element in 'indices_or_sections' is not integer. + TypeError: If argument `indices_or_sections` is not int, tuple(int) or list(int). + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> input_x = np.arange(9).astype("float32") + >>> output = ops.tensor_split(Tensor(input_x), 3) + >>> print(output) + (Tensor(shape=[3], dtype=Float32, + value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]), + Tensor(shape=[3], dtype=Float32, + value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]), + Tensor(shape=[3], dtype=Float32, + value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00])) + """ + if not isinstance(x, Tensor): + raise TypeError(f'expect `x` is a Tensor, but got {type(x)}') + + _ = validator.check_axis_type(axis, True, False, False) + axis = _canonicalize_axis(axis, x.ndim) + if isinstance(indices_or_sections, int): + res = _tensor_split_sub_int(x, indices_or_sections, axis) + + elif isinstance(indices_or_sections, (list, tuple)): + for item in indices_or_sections: + if not isinstance(item, int): + raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.") + res = _tensor_split_sub_tensors(x, indices_or_sections, axis) + else: + raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), " \ + f"but got {type(indices_or_sections)}") + + return res + + +def vsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors vertically. + It is equivalent to `ops.tensor_split` with :math:`axis=0` . + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`. + + Returns: + A list of sub-tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> input_x = np.arange(9).reshape((3, 3)).astype('float32') + >>> output = ops.vsplit(Tensor(input_x), 3) + >>> print(output) + (Tensor(shape=[1, 3], dtype=Float32, + value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]), + Tensor(shape=[1, 3], dtype=Float32, + value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]), + Tensor(shape=[1, 3], dtype=Float32, + value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])) + """ + return tensor_split(x, indices_or_sections, 0) + + +def hsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors horizontally. + It is equivalent to `ops.tensor_split` with :math:`axis=1` . + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`. + + Returns: + A list of sub-tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> input_x = np.arange(6).reshape((2, 3)).astype('float32') + >>> output = ops.hsplit(Tensor(input_x), 3) + >>> print(output) + (Tensor(shape=[2, 1], dtype=Float32, + value=[[ 0.00000000e+00], + [ 3.00000000e+00]]), + Tensor(shape=[2, 1], dtype=Float32, + value=[[ 1.00000000e+00], + [ 4.00000000e+00]]), + Tensor(shape=[2, 1], dtype=Float32, + value=[[ 2.00000000e+00], + [ 5.00000000e+00]])) + """ + return tensor_split(x, indices_or_sections, 1) + + +def dsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors along the 3rd axis. + It is equivalent to `ops.tensor_split` with :math:`axis=2` . + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`. + + Returns: + A list of sub-tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> input_x = np.arange(6).reshape((1, 2, 3)).astype('float32') + >>> output = ops.dsplit(Tensor(input_x), 3) + >>> print(output) + (Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 0.00000000e+00], + [ 3.00000000e+00]]]), + Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 1.00000000e+00], + [ 4.00000000e+00]]]), + Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 2.00000000e+00], + [ 5.00000000e+00]]])) + """ + return tensor_split(x, indices_or_sections, 2) + + def max(x, axis=0, keep_dims=False): """ Calculates the maximum value with the corresponding index. @@ -5278,6 +5547,10 @@ __all__ = [ 'broadcast_to', 'col2im', 'split', + 'tensor_split', + 'vsplit', + 'hsplit', + 'dsplit', "index_fill", 'max', 'argmax', diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index 05ab6703a3e..1360bee34ac 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -191,6 +191,10 @@ tensor_operator_registry.register('tile', P.Tile) tensor_operator_registry.register('logit', logit) tensor_operator_registry.register('sum', P.ReduceSum) tensor_operator_registry.register('split', P.Split) +tensor_operator_registry.register('tensor_split', tensor_split) +tensor_operator_registry.register('vsplit', vsplit) +tensor_operator_registry.register('hsplit', hsplit) +tensor_operator_registry.register('dsplit', dsplit) tensor_operator_registry.register('select', P.Select) tensor_operator_registry.register('zeros_like', P.ZerosLike) tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor) @@ -216,7 +220,6 @@ tensor_operator_registry.register('unique_with_pad', P.UniqueWithPad) tensor_operator_registry.register('inplace_update', P.InplaceUpdate) tensor_operator_registry.register('col2im', col2im) tensor_operator_registry.register('standard_laplace', P.StandardLaplace) -tensor_operator_registry.register('split', P.Split) tensor_operator_registry.register('erf', P.Erf) tensor_operator_registry.register('erfc', P.Erfc) tensor_operator_registry.register('standard_normal', P.StandardNormal) diff --git a/tests/st/ops/test_split.py b/tests/st/ops/test_split.py new file mode 100644 index 00000000000..09afed58c98 --- /dev/null +++ b/tests/st/ops/test_split.py @@ -0,0 +1,245 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class SplitNet(nn.Cell): + def construct(self, x, indices_or_sections, axis=0): + out = ops.tensor_split(x, indices_or_sections, axis) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_tensor_split_int(mode): + """ + Feature: tensor_split + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = SplitNet() + a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_tensor_split_list(mode): + """ + Feature: tensor_split + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = SplitNet() + a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class VSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = ops.vsplit(x, indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_vsplit_int(mode): + """ + Feature: vsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = VSplitNet() + a = np.arange(20).reshape((10, 2)) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=0) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_vsplit_list(mode): + """ + Feature: vsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = VSplitNet() + a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=0) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class HSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = ops.hsplit(x, indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_hsplit_int(mode): + """ + Feature: hsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = HSplitNet() + a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=1) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_hsplit_list(mode): + """ + Feature: hsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = HSplitNet() + a = np.array(np.arange(10).reshape((2, 5)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=1) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class DSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = ops.dsplit(x, indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_dsplit_int(mode): + """ + Feature: dsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = DSplitNet() + a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=2) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_dsplit_list(mode): + """ + Feature: dsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = DSplitNet() + a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=2) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) diff --git a/tests/st/tensor/test_split.py b/tests/st/tensor/test_split.py new file mode 100644 index 00000000000..abd72b243cd --- /dev/null +++ b/tests/st/tensor/test_split.py @@ -0,0 +1,244 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn + + +class SplitNet(nn.Cell): + def construct(self, x, indices_or_sections, axis=0): + out = x.tensor_split(indices_or_sections, axis) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_tensor_split_int(mode): + """ + Feature: tensor_split + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = SplitNet() + a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_tensor_split_list(mode): + """ + Feature: tensor_split + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = SplitNet() + a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class VSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = x.vsplit(indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_vsplit_int(mode): + """ + Feature: vsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = VSplitNet() + a = np.arange(20).reshape((10, 2)) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=0) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_vsplit_list(mode): + """ + Feature: vsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = VSplitNet() + a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=0) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class HSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = x.hsplit(indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_hsplit_int(mode): + """ + Feature: hsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = HSplitNet() + a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=1) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_hsplit_list(mode): + """ + Feature: hsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = HSplitNet() + a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=1) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +class DSplitNet(nn.Cell): + def construct(self, x, indices_or_sections): + out = x.dsplit(indices_or_sections) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_dsplit_int(mode): + """ + Feature: dsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is int. + Expectation: success + """ + ms.set_context(mode=mode) + net = DSplitNet() + a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = 3 + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=2) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_f_dsplit_list(mode): + """ + Feature: dsplit + Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int). + Expectation: success + """ + ms.set_context(mode=mode) + net = DSplitNet() + a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32) + x = ms.Tensor(a, dtype=ms.float32) + indices_or_sections = [2, 4] + out = net(x, indices_or_sections) + expect = np.array_split(a, indices_or_sections, axis=2) + for res, exp in zip(out, expect): + assert np.allclose(res.asnumpy(), exp)