add split series api
This commit is contained in:
parent
4118dbcd4d
commit
6875db69f0
@ -372,6 +372,7 @@ Array操作
    mindspore.ops.diag
    mindspore.ops.diagonal
    mindspore.ops.dsplit
    mindspore.ops.dyn_shape
    mindspore.ops.expand
    mindspore.ops.expand_dims
    mindspore.ops.flip

@ -382,6 +383,7 @@ Array操作
    mindspore.ops.gather_d
    mindspore.ops.gather_elements
    mindspore.ops.gather_nd
    mindspore.ops.hsplit
    mindspore.ops.index_add
    mindspore.ops.index_fill
    mindspore.ops.inplace_add

@ -428,6 +430,7 @@ Array操作
    mindspore.ops.tensor_scatter_mul
    mindspore.ops.tensor_scatter_sub
    mindspore.ops.tensor_scatter_elements
    mindspore.ops.tensor_split
    mindspore.ops.tile
    mindspore.ops.top_k
    mindspore.ops.transpose

@ -442,6 +445,7 @@ Array操作
    mindspore.ops.unsorted_segment_sum
    mindspore.ops.unsqueeze
    mindspore.ops.unstack
    mindspore.ops.vsplit
    mindspore.ops.where

类型转换
@ -0,0 +1,6 @@
mindspore.Tensor.dsplit
========================

.. py:method:: mindspore.Tensor.dsplit(indices_or_sections)

    For details, please refer to :func:`mindspore.ops.dsplit`.
@ -0,0 +1,6 @@
mindspore.Tensor.hsplit
========================

.. py:method:: mindspore.Tensor.hsplit(indices_or_sections)

    For details, please refer to :func:`mindspore.ops.hsplit`.
@ -0,0 +1,6 @@
mindspore.Tensor.tensor_split
==============================

.. py:method:: mindspore.Tensor.tensor_split(indices_or_sections, axis=0)

    For details, please refer to :func:`mindspore.ops.tensor_split`.
@ -0,0 +1,6 @@
mindspore.Tensor.vsplit
========================

.. py:method:: mindspore.Tensor.vsplit(indices_or_sections)

    For details, please refer to :func:`mindspore.ops.vsplit`.
@ -85,6 +85,7 @@ mindspore.Tensor
    mindspore.Tensor.diagonal
    mindspore.Tensor.div
    mindspore.Tensor.divide
    mindspore.Tensor.dsplit
    mindspore.Tensor.dtype
    mindspore.Tensor.equal
    mindspore.Tensor.erf

@ -118,6 +119,7 @@ mindspore.Tensor
    mindspore.Tensor.hardshrink
    mindspore.Tensor.has_init
    mindspore.Tensor.heaviside
    mindspore.Tensor.hsplit
    mindspore.Tensor.hypot
    mindspore.Tensor.i0
    mindspore.Tensor.igamma

@ -230,6 +232,7 @@ mindspore.Tensor
    mindspore.Tensor.take
    mindspore.Tensor.tan
    mindspore.Tensor.tanh
    mindspore.Tensor.tensor_split
    mindspore.Tensor.tile
    mindspore.Tensor.to
    mindspore.Tensor.to_coo

@ -250,6 +253,7 @@ mindspore.Tensor
    mindspore.Tensor.unsqueeze
    mindspore.Tensor.var
    mindspore.Tensor.view
    mindspore.Tensor.vsplit
    mindspore.Tensor.where
    mindspore.Tensor.xdivy
    mindspore.Tensor.xlogy
@ -0,0 +1,13 @@
mindspore.ops.dsplit
=====================

.. py:function:: mindspore.ops.dsplit(x, indices_or_sections)

    Splits the input Tensor into multiple sub-tensors along the 3rd axis. It is equivalent to `ops.tensor_split` with `axis=2` .

    Parameters:
        - **x** (Tensor) - A Tensor to be split.
        - **indices_or_sections** (Union[int, tuple(int), list(int)]) - See :func:`mindspore.ops.tensor_split`.

    Returns:
        tuple[Tensor].
@ -0,0 +1,13 @@
mindspore.ops.hsplit
=====================

.. py:function:: mindspore.ops.hsplit(x, indices_or_sections)

    Splits the input Tensor horizontally into multiple sub-tensors. It is equivalent to `ops.tensor_split` with `axis=1` .

    Parameters:
        - **x** (Tensor) - A Tensor to be split.
        - **indices_or_sections** (Union[int, tuple(int), list(int)]) - See :func:`mindspore.ops.tensor_split`.

    Returns:
        tuple[Tensor].
@ -0,0 +1,26 @@
mindspore.ops.tensor_split
===========================

.. py:function:: mindspore.ops.tensor_split(x, indices_or_sections, axis=0)

    Splits the input Tensor into multiple sub-tensors along the given axis.

    Parameters:
        - **x** (Tensor) - A Tensor to be split.
        - **indices_or_sections** (Union[int, tuple(int), list(int)]) -
          If `indices_or_sections` is an integer n, the input is split into n sections along `axis`. If the input is divisible by n along `axis`, each section has the same size :math:`input.size(axis) / n` . Otherwise, the first :math:`input.size(axis) % n` sections have size :math:`input.size(axis) // n + 1` and the remaining sections have size :math:`input.size(axis) // n` .
          If `indices_or_sections` is a list or tuple of ints, the input is split along `axis` at each index in the list or tuple. For example, :math:`indices_or_sections=[2, 3]` and :math:`axis=0` produce the slices :math:`x[:2]` , :math:`x[2:3]` and :math:`x[3:]` .
        - **axis** (int) - The axis along which to split. Default: 0.

    Returns:
        tuple[Tensor].

    Raises:
        - **TypeError** - `x` is not a Tensor.
        - **TypeError** - `axis` is not an int.
        - **ValueError** - `axis` is out of the range :math:`[-x.dim, x.dim)` .
        - **TypeError** - An element of `indices_or_sections` is not an int.
        - **TypeError** - `indices_or_sections` is not an int, tuple(int) or list(int).
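The chunk-size rule documented above (the first `size(axis) % n` chunks get one extra element) is the same convention NumPy's `np.array_split` follows, which the commit's own tests use as the reference. A minimal cross-check sketch (the helper `split_sizes` is ours, not a MindSpore API):

```python
import numpy as np

def split_sizes(length, n):
    # First length % n chunks get length // n + 1 elements, the rest length // n.
    base, extra = divmod(length, n)
    return [base + 1] * extra + [base] * (n - extra)

print(split_sizes(10, 3))                                  # [4, 3, 3]
print([len(c) for c in np.array_split(np.arange(10), 3)])  # [4, 3, 3]
```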
@ -0,0 +1,13 @@
mindspore.ops.vsplit
=====================

.. py:function:: mindspore.ops.vsplit(x, indices_or_sections)

    Splits the input Tensor vertically into multiple sub-tensors. It is equivalent to `ops.tensor_split` with `axis=0` .

    Parameters:
        - **x** (Tensor) - A Tensor to be split.
        - **indices_or_sections** (Union[int, tuple(int), list(int)]) - See :func:`mindspore.ops.tensor_split`.

    Returns:
        tuple[Tensor].
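The three convenience wrappers documented here only pin the axis of `tensor_split` to 0, 1 and 2; `np.array_split` expresses the same relationship, so the resulting shapes can be sanity-checked in plain NumPy (a sketch, not MindSpore code):

```python
import numpy as np

x = np.arange(8).reshape(2, 2, 2)
v = np.array_split(x, 2, axis=0)  # vsplit: split along axis 0
h = np.array_split(x, 2, axis=1)  # hsplit: split along axis 1
d = np.array_split(x, 2, axis=2)  # dsplit: split along axis 2
print(v[0].shape, h[0].shape, d[0].shape)  # (1, 2, 2) (2, 1, 2) (2, 2, 1)
```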
@ -91,6 +91,7 @@
    mindspore.Tensor.diagonal
    mindspore.Tensor.div
    mindspore.Tensor.divide
    mindspore.Tensor.dsplit
    mindspore.Tensor.dtype
    mindspore.Tensor.equal
    mindspore.Tensor.erf

@ -124,6 +125,7 @@
    mindspore.Tensor.hardshrink
    mindspore.Tensor.has_init
    mindspore.Tensor.heaviside
    mindspore.Tensor.hsplit
    mindspore.Tensor.hypot
    mindspore.Tensor.i0
    mindspore.Tensor.igamma

@ -236,6 +238,7 @@
    mindspore.Tensor.take
    mindspore.Tensor.tan
    mindspore.Tensor.tanh
    mindspore.Tensor.tensor_split
    mindspore.Tensor.tile
    mindspore.Tensor.to
    mindspore.Tensor.to_coo

@ -259,6 +262,7 @@
    mindspore.Tensor.vsplit
    mindspore.Tensor.where
    mindspore.Tensor.xdivy
    mindspore.Tensor.xlogy

{% elif fullname=="mindspore.nn.Cell" %}
{{ fullname | underline }}
@ -371,6 +371,7 @@ Array Operation
    mindspore.ops.count_nonzero
    mindspore.ops.diag
    mindspore.ops.diagonal
    mindspore.ops.dsplit
    mindspore.ops.dyn_shape
    mindspore.ops.expand
    mindspore.ops.expand_dims

@ -382,6 +383,7 @@ Array Operation
    mindspore.ops.gather_d
    mindspore.ops.gather_elements
    mindspore.ops.gather_nd
    mindspore.ops.hsplit
    mindspore.ops.index_add
    mindspore.ops.index_fill
    mindspore.ops.inplace_add

@ -428,6 +430,7 @@ Array Operation
    mindspore.ops.tensor_scatter_mul
    mindspore.ops.tensor_scatter_sub
    mindspore.ops.tensor_scatter_elements
    mindspore.ops.tensor_split
    mindspore.ops.tile
    mindspore.ops.top_k
    mindspore.ops.transpose

@ -442,6 +445,7 @@ Array Operation
    mindspore.ops.unsorted_segment_sum
    mindspore.ops.unsqueeze
    mindspore.ops.unstack
    mindspore.ops.vsplit
    mindspore.ops.where

Type Conversion
@ -332,6 +332,10 @@ BuiltInTypeMap &GetMethodMap() {
      {"to_csr", std::string("to_csr")},                          // dense_to_sparse_csr()
      {"col2im", std::string("col2im")},                          // P.Col2Im
      {"split", std::string("split")},                            // P.Split()
      {"tensor_split", std::string("tensor_split")},              // tensor_split
      {"vsplit", std::string("vsplit")},                          // vsplit
      {"hsplit", std::string("hsplit")},                          // hsplit
      {"dsplit", std::string("dsplit")},                          // dsplit
      {"random_categorical", std::string("random_categorical")},  // P.RandomCategorical
      {"xlogy", std::string("xlogy")},                            // P.Xlogy()
      {"erf", std::string("erf")},                                // P.Erf()
@ -3462,6 +3462,39 @@ def split(input_x, axis=0, output_num=1):
    return F.split(input_x, axis, output_num)


def tensor_split(x, indices_or_sections, axis=0):
    """
    Splits a tensor into multiple sub-tensors along the given axis.
    Refer to :func:`mindspore.ops.tensor_split` for more details.
    """
    return F.tensor_split(x, indices_or_sections, axis=axis)


def vsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors vertically. It is equivalent to `ops.tensor_split` with :math:`axis=0` .
    Refer to :func:`mindspore.ops.vsplit` for more details.
    """
    return F.vsplit(x, indices_or_sections)


def hsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors horizontally. It is equivalent to `ops.tensor_split` with :math:`axis=1` .
    Refer to :func:`mindspore.ops.hsplit` for more details.
    """
    return F.hsplit(x, indices_or_sections)


def dsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors along the 3rd axis.
    It is equivalent to `ops.tensor_split` with :math:`axis=2` .
    Refer to :func:`mindspore.ops.dsplit` for more details.
    """
    return F.dsplit(x, indices_or_sections)


def xlogy(x, y):
    r"""
    Computes the first input tensor multiplied by the logarithm of second input tensor element-wise.
@ -3541,6 +3541,34 @@ class Tensor(Tensor_):
        """
        return tensor_operator_registry.get('split')(axis, output_num)(self)

    def tensor_split(self, indices_or_sections, axis=0):
        """
        For details, please refer to :func:`mindspore.ops.tensor_split`.
        """
        self._init_check()
        return tensor_operator_registry.get('tensor_split')(self, indices_or_sections, axis)

    def vsplit(self, indices_or_sections):
        """
        For details, please refer to :func:`mindspore.ops.vsplit`.
        """
        self._init_check()
        return tensor_operator_registry.get('vsplit')(self, indices_or_sections)

    def hsplit(self, indices_or_sections):
        """
        For details, please refer to :func:`mindspore.ops.hsplit`.
        """
        self._init_check()
        return tensor_operator_registry.get('hsplit')(self, indices_or_sections)

    def dsplit(self, indices_or_sections):
        """
        For details, please refer to :func:`mindspore.ops.dsplit`.
        """
        self._init_check()
        return tensor_operator_registry.get('dsplit')(self, indices_or_sections)

    def xlogy(self, y):
        r"""
        For details, please refer to :func:`mindspore.ops.xlogy`.
@ -113,6 +113,10 @@ from .array_func import (
    unsorted_segment_sum,
    col2im,
    split,
    tensor_split,
    vsplit,
    hsplit,
    dsplit,
    index_fill,
    max,
    argmax,
@ -4476,6 +4476,275 @@ def split(input_x, axis=0, output_num=1):
    return split_(input_x)


@constexpr
def _canonicalize_axis(axis, ndim):
    """
    Check axes are within the number of dimensions of tensor x and normalize the negative axes.

    Args:
        axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Returns:
        Axis (Union[int, tuple(int)]). If the input is an integer, an integer is returned, else a tuple.
    """
    if isinstance(axis, int):
        axis = [axis]
    for ax in axis:
        if not isinstance(ax, int):
            raise TypeError(f'axis should be integers, not {type(ax)}')
        if not -ndim <= ax < ndim:
            raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}')

    def canonicalizer(ax):
        return ax + ndim if ax < 0 else ax

    axis = tuple([canonicalizer(ax) for ax in axis])
    if all(axis.count(el) <= 1 for el in axis):
        return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
    raise ValueError(f"duplicate axis in {axis}.")
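The normalization above can be exercised standalone; here is a pure-Python replica (dropping the `@constexpr` graph-compile decorator, names illustrative) that shows the bounds check, the negative-axis wrap-around, and the duplicate rejection:

```python
def canonicalize_axis(axis, ndim):
    # Wrap negative axes into [0, ndim) and reject out-of-range or duplicate axes.
    axes = [axis] if isinstance(axis, int) else list(axis)
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for dimension {ndim}")
    axes = tuple(sorted(ax + ndim if ax < 0 else ax for ax in axes))
    if len(set(axes)) != len(axes):
        raise ValueError(f"duplicate axis in {axes}")
    return axes[0] if len(axes) == 1 else axes

print(canonicalize_axis(-1, 3))       # 2
print(canonicalize_axis((0, -1), 3))  # (0, 2)
```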
@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False):
    """
    Generates a new list or tuple by list comprehension.

    Args:
        obj (Union[int, list, tuple]):
            If integer, it will be the length of the returned tuple/list.
        item: The value to be filled. Default: None.
            If None, the values in the new list/tuple are the same as obj
            or range(obj) when obj is integer.
        return_tuple (bool): If true, returns tuple, else returns list.

    Returns:
        List or tuple.
    """
    lst = obj
    if isinstance(obj, int):
        lst = np.arange(obj)
    if item is None:
        res = list(lst)
    else:
        res = [item for _ in lst]
    if return_tuple:
        return tuple(res)
    return res


@constexpr
def _tuple_setitem(tup, idx, value):
    """
    Returns a tuple with the element at the specified `idx` set to `value`.
    """
    tup = list(tup)
    tup[idx] = value
    return tuple(tup)
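Outside the graph-compilation context the two helpers above behave like ordinary Python; a plain replica (names are ours) shows what they produce:

```python
import numpy as np

def list_comprehensions(obj, item=None, return_tuple=False):
    # Fill a list/tuple of len(obj) (or obj when it is an int) with `item`,
    # or copy obj / range(obj) when item is None.
    lst = np.arange(obj) if isinstance(obj, int) else obj
    res = list(lst) if item is None else [item for _ in lst]
    return tuple(res) if return_tuple else res

def tuple_setitem(tup, idx, value):
    # Rebuild a tuple with one element replaced.
    lst = list(tup)
    lst[idx] = value
    return tuple(lst)

print(list_comprehensions(3, 1, True))  # (1, 1, 1)
print(tuple_setitem((2, 3, 4), 1, 7))   # (2, 7, 4)
```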
def _tensor_split_sub_tensors(x, indices_or_sections, axis):
    """
    Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `indices_or_sections`
    with type of tuple or list.
    """
    length_along_dim = x.shape[axis]
    indices_or_sections = tuple(indices_or_sections)
    indices_or_sections += (length_along_dim,)
    sub_tensors = []
    strides = _list_comprehensions(x.ndim, 1, True)
    begin = _list_comprehensions(x.ndim, 0)
    end = _list_comprehensions(x.shape)
    for i, idx in enumerate(indices_or_sections):
        begin[axis] = 0 if i == 0 else indices_or_sections[i - 1]
        end[axis] = idx
        sliced_tensor = strided_slice(x, tuple(begin), tuple(end), strides)
        sub_tensors.append(sliced_tensor)
    return sub_tensors


def _tensor_split_sub_int(x, indices_or_sections, axis):
    """
    Splits the input tensor `x` into multiple sub-tensors along the axis according to the given `indices_or_sections`
    with type of int.
    """
    arr_shape = x.shape
    length_along_dim = arr_shape[axis]
    if indices_or_sections > length_along_dim:
        res = P.Split(axis, length_along_dim)(x)
        indices_or_sections_n = [i for i in np.arange(length_along_dim, indices_or_sections)]
        res2 = _tensor_split_sub_tensors(x, indices_or_sections_n, axis)
        res += tuple(res2)[1:]
    elif length_along_dim % indices_or_sections == 0:
        res = P.Split(axis, indices_or_sections)(x)
    else:
        num_long_tensor = length_along_dim % indices_or_sections
        num_short_tensor = indices_or_sections - num_long_tensor
        length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
        length2 = length_along_dim - length1
        start1 = _list_comprehensions(rank(x), 0, True)
        size1 = _tuple_setitem(arr_shape, axis, length1)
        start2 = _tuple_setitem(start1, axis, length1)
        size2 = _tuple_setitem(arr_shape, axis, length2)
        res = P.Split(axis, num_long_tensor)(tensor_slice(x, start1, size1)) + \
              P.Split(axis, num_short_tensor)(tensor_slice(x, start2, size2))
    return res
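The uneven-split branch above carves the axis into `num_long_tensor` chunks of size `base + 1` followed by `num_short_tensor` chunks of size `base`; the arithmetic can be checked directly against `np.array_split`, which uses the same convention:

```python
import numpy as np

length, n = 10, 3
base = length // n
num_long = length % n            # chunks that get one extra element
num_short = n - num_long
length1 = num_long * (base + 1)  # elements covered by the long chunks
sizes = [base + 1] * num_long + [base] * num_short
print(num_long, num_short, length1, sizes)                        # 1 2 4 [4, 3, 3]
print([len(c) for c in np.array_split(np.arange(length), n)])     # [4, 3, 3]
```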
def tensor_split(x, indices_or_sections, axis=0):
    """
    Splits a tensor into multiple sub-tensors along the given axis.

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]):
            If `indices_or_sections` is an integer n, input is split into
            n sections along dimension `axis`. If input is divisible by n along dimension `axis`, each section will
            be of equal size, :math:`input.size(axis) / n` . If input is not divisible by n, the first
            :math:`input.size(axis) % n` sections will have size :math:`input.size(axis) // n + 1` , and the rest
            will have size :math:`input.size(axis) // n` .
            If `indices_or_sections` is a list or tuple of ints, then input is split
            along dimension `axis` at each of the indices in the list or tuple. For instance,
            :math:`indices_or_sections=[2, 3]` and :math:`axis=0` would result in the tensors :math:`x[:2]` ,
            :math:`x[2:3]` , and :math:`x[3:]` .
        axis (int): The axis along which to split. Default: 0.

    Returns:
        A tuple of sub-tensors.

    Raises:
        TypeError: If argument `x` is not Tensor.
        TypeError: If argument `axis` is not int.
        ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` .
        TypeError: If each element in `indices_or_sections` is not integer.
        TypeError: If argument `indices_or_sections` is not int, tuple(int) or list(int).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = np.arange(9).astype("float32")
        >>> output = ops.tensor_split(Tensor(input_x), 3)
        >>> print(output)
        (Tensor(shape=[3], dtype=Float32,
        value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
        Tensor(shape=[3], dtype=Float32,
        value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
        Tensor(shape=[3], dtype=Float32,
        value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
    """
    if not isinstance(x, Tensor):
        raise TypeError(f'expect `x` is a Tensor, but got {type(x)}')

    _ = validator.check_axis_type(axis, True, False, False)
    axis = _canonicalize_axis(axis, x.ndim)
    if isinstance(indices_or_sections, int):
        res = _tensor_split_sub_int(x, indices_or_sections, axis)
    elif isinstance(indices_or_sections, (list, tuple)):
        for item in indices_or_sections:
            if not isinstance(item, int):
                raise TypeError(f"Each element in 'indices_or_sections' should be integer, but got {type(item)}.")
        res = _tensor_split_sub_tensors(x, indices_or_sections, axis)
    else:
        raise TypeError(f"Type of Argument `indices_or_sections` should be integer, tuple(int) or list(int), "
                        f"but got {type(indices_or_sections)}")
    return res
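The index-list branch above follows the same slicing convention as `np.array_split` (which the commit's tests use as the expected reference): indices `[2, 3]` on axis 0 yield `x[:2]`, `x[2:3]`, `x[3:]`. A minimal NumPy check:

```python
import numpy as np

x = np.arange(5)
parts = np.array_split(x, [2, 3])  # split at indices 2 and 3
print([p.tolist() for p in parts])  # [[0, 1], [2], [3, 4]]
```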
def vsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors vertically.
    It is equivalent to `ops.tensor_split` with :math:`axis=0` .

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`.

    Returns:
        A tuple of sub-tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = np.arange(9).reshape((3, 3)).astype('float32')
        >>> output = ops.vsplit(Tensor(input_x), 3)
        >>> print(output)
        (Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]),
        Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]),
        Tensor(shape=[1, 3], dtype=Float32,
        value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]]))
    """
    return tensor_split(x, indices_or_sections, 0)
def hsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors horizontally.
    It is equivalent to `ops.tensor_split` with :math:`axis=1` .

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`.

    Returns:
        A tuple of sub-tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = np.arange(6).reshape((2, 3)).astype('float32')
        >>> output = ops.hsplit(Tensor(input_x), 3)
        >>> print(output)
        (Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 0.00000000e+00],
        [ 3.00000000e+00]]),
        Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 1.00000000e+00],
        [ 4.00000000e+00]]),
        Tensor(shape=[2, 1], dtype=Float32,
        value=[[ 2.00000000e+00],
        [ 5.00000000e+00]]))
    """
    return tensor_split(x, indices_or_sections, 1)
def dsplit(x, indices_or_sections):
    """
    Splits a tensor into multiple sub-tensors along the 3rd axis.
    It is equivalent to `ops.tensor_split` with :math:`axis=2` .

    Args:
        x (Tensor): A Tensor to be divided.
        indices_or_sections (Union[int, tuple(int), list(int)]): See argument in :func:`mindspore.ops.tensor_split`.

    Returns:
        A tuple of sub-tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = np.arange(6).reshape((1, 2, 3)).astype('float32')
        >>> output = ops.dsplit(Tensor(input_x), 3)
        >>> print(output)
        (Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 0.00000000e+00],
        [ 3.00000000e+00]]]),
        Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 1.00000000e+00],
        [ 4.00000000e+00]]]),
        Tensor(shape=[1, 2, 1], dtype=Float32,
        value=[[[ 2.00000000e+00],
        [ 5.00000000e+00]]]))
    """
    return tensor_split(x, indices_or_sections, 2)
def max(x, axis=0, keep_dims=False):
    """
    Calculates the maximum value with the corresponding index.
@ -5278,6 +5547,10 @@ __all__ = [
    'broadcast_to',
    'col2im',
    'split',
    'tensor_split',
    'vsplit',
    'hsplit',
    'dsplit',
    'index_fill',
    'max',
    'argmax',
@ -191,6 +191,10 @@ tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logit', logit)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('tensor_split', tensor_split)
tensor_operator_registry.register('vsplit', vsplit)
tensor_operator_registry.register('hsplit', hsplit)
tensor_operator_registry.register('dsplit', dsplit)
tensor_operator_registry.register('select', P.Select)
tensor_operator_registry.register('zeros_like', P.ZerosLike)
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)

@ -216,7 +220,6 @@ tensor_operator_registry.register('unique_with_pad', P.UniqueWithPad)
tensor_operator_registry.register('inplace_update', P.InplaceUpdate)
tensor_operator_registry.register('col2im', col2im)
tensor_operator_registry.register('standard_laplace', P.StandardLaplace)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('erf', P.Erf)
tensor_operator_registry.register('erfc', P.Erfc)
tensor_operator_registry.register('standard_normal', P.StandardNormal)
@ -0,0 +1,245 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class SplitNet(nn.Cell):
    def construct(self, x, indices_or_sections, axis=0):
        out = ops.tensor_split(x, indices_or_sections, axis)
        return out

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_int(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or list(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)

class VSplitNet(nn.Cell):
    def construct(self, x, indices_or_sections):
        out = ops.vsplit(x, indices_or_sections)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_int(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.arange(20).reshape((10, 2))
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_list(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is tuple(int) or list(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)

class HSplitNet(nn.Cell):
    def construct(self, x, indices_or_sections):
        out = ops.hsplit(x, indices_or_sections)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_int(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_list(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is tuple(int) or list(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(10).reshape((2, 5)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)

class DSplitNet(nn.Cell):
|
||||
def construct(self, x, indices_or_sections):
|
||||
out = ops.dsplit(x, indices_or_sections)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_f_dsplit_int(mode):
|
||||
"""
|
||||
Feature: dsplit
|
||||
Description: Verify the result of tensor_split when the type of `indices_or_sections` is int.
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = DSplitNet()
|
||||
a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
|
||||
x = ms.Tensor(a, dtype=ms.float32)
|
||||
indices_or_sections = 3
|
||||
out = net(x, indices_or_sections)
|
||||
expect = np.array_split(a, indices_or_sections, axis=2)
|
||||
for res, exp in zip(out, expect):
|
||||
assert np.allclose(res.asnumpy(), exp)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_f_dsplit_list(mode):
|
||||
"""
|
||||
Feature: dsplit
|
||||
Description: Verify the result of tensor_split when the type of `indices_or_sections` is tuple(int) or tuple(int).
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = DSplitNet()
|
||||
a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
|
||||
x = ms.Tensor(a, dtype=ms.float32)
|
||||
indices_or_sections = [2, 4]
|
||||
out = net(x, indices_or_sections)
|
||||
expect = np.array_split(a, indices_or_sections, axis=2)
|
||||
for res, exp in zip(out, expect):
|
||||
assert np.allclose(res.asnumpy(), exp)
|
|
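The expectations above are built with `np.array_split`, which, unlike `np.split`, accepts a section count that does not evenly divide the axis: the leading `length % n` chunks each receive one extra element. A minimal NumPy-only sketch of that convention (no MindSpore required):

```python
import numpy as np

# Splitting 10 elements into 3 sections: 10 % 3 == 1, so the first
# chunk gets one extra element and the sizes come out as [4, 3, 3].
a = np.arange(10)
chunks = np.array_split(a, 3)
sizes = [c.shape[0] for c in chunks]
print(sizes)  # [4, 3, 3]
```

This is the same uneven-split rule the `tensor_split`-family tests rely on when `indices_or_sections` is an int.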
@@ -0,0 +1,244 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn


class SplitNet(nn.Cell):
    def construct(self, x, indices_or_sections, axis=0):
        out = x.tensor_split(indices_or_sections, axis)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_int(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


class VSplitNet(nn.Cell):
    def construct(self, x, indices_or_sections):
        out = x.vsplit(indices_or_sections)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_int(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.arange(20).reshape((10, 2))
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_list(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


class HSplitNet(nn.Cell):
    def construct(self, x, indices_or_sections):
        out = x.hsplit(indices_or_sections)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_int(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_list(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


class DSplitNet(nn.Cell):
    def construct(self, x, indices_or_sections):
        out = x.dsplit(indices_or_sections)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_dsplit_int(mode):
    """
    Feature: dsplit
    Description: Verify the result of dsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = DSplitNet()
    a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=2)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_dsplit_list(mode):
    """
    Feature: dsplit
    Description: Verify the result of dsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = DSplitNet()
    a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=2)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)
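As the method tests suggest, `vsplit`, `hsplit`, and `dsplit` behave like `tensor_split` pinned to axis 0, 1, and 2 respectively. A NumPy-only sketch of the reference behaviour the tests compare against (no MindSpore required):

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))

# Each *split fixes the axis that tensor_split takes as a parameter.
v = np.array_split(a, 2, axis=0)  # mirrors vsplit: along rows
h = np.array_split(a, 3, axis=1)  # mirrors hsplit: along columns
d = np.array_split(a, 2, axis=2)  # mirrors dsplit: along depth

print([b.shape for b in v])  # [(1, 3, 4), (1, 3, 4)]
print([b.shape for b in d])  # [(2, 3, 2), (2, 3, 2)]
```

Note that `dsplit` only applies to arrays with at least three dimensions, which is why the `dsplit` tests above use a `(1, 2, 10)` input.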