!44858 Tensor api unbind

Merge pull request !44858 from ZhidanLiu/unbind_master
This commit is contained in:
i-robot 2022-11-01 09:00:45 +00:00 committed by Gitee
commit 5f603ffb39
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 242 additions and 0 deletions

View File

@ -380,6 +380,7 @@ Array操作
mindspore.ops.tile
mindspore.ops.top_k
mindspore.ops.transpose
mindspore.ops.unbind
mindspore.ops.unfold
mindspore.ops.unique
mindspore.ops.unique_consecutive

View File

@ -0,0 +1,19 @@
mindspore.Tensor.unbind
========================
.. py:method:: mindspore.Tensor.unbind(dim=0)
根据指定轴对输入矩阵进行分解。
若输入Tensor在指定的轴上的rank为 `R` 则输出Tensor的rank为 `(R-1)`
给定一个shape为 :math:`(x_1, x_2, ..., x_R)` 的Tensor。如果存在 :math:`0 \le axis` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`
参数:
- **dim** (int) - 指定矩阵分解的轴。取值范围为[-R, R)默认值0。
返回:
Tensor对象组成的tuple。每个Tensor对象的shape相同。
异常:
- **ValueError** - `dim` 超出[-R, R)范围。

View File

@ -193,6 +193,7 @@ mindspore.Tensor
mindspore.Tensor.transpose
mindspore.Tensor.triu
mindspore.Tensor.true_divide
mindspore.Tensor.unbind
mindspore.Tensor.unfold
mindspore.Tensor.unique_consecutive
mindspore.Tensor.unique_with_pad

View File

@ -0,0 +1,20 @@
mindspore.ops.unbind
========================
.. py:function:: mindspore.ops.unbind(x, dim=0)
根据指定轴对输入矩阵进行分解。
若输入Tensor在指定的轴上的rank为 `R` 则输出Tensor的rank为 `(R-1)`
给定一个shape为 :math:`(x_1, x_2, ..., x_R)` 的Tensor。如果存在 :math:`0 \le axis` 则输出Tensor的shape为 :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`
参数:
- **x** (Tensor) - 输入Tensor其shape为 :math:`(x_1, x_2, ..., x_R)` 。rank必须大于0。
- **dim** (int) - 指定矩阵分解的轴。取值范围为[-R, R)默认值0。
返回:
Tensor对象组成的tuple。每个Tensor对象的shape相同。
异常:
- **ValueError** - `dim` 超出[-R, R)范围。

View File

@ -199,6 +199,7 @@
mindspore.Tensor.transpose
mindspore.Tensor.triu
mindspore.Tensor.true_divide
mindspore.Tensor.unbind
mindspore.Tensor.unfold
mindspore.Tensor.unique_consecutive
mindspore.Tensor.unique_with_pad

View File

@ -380,6 +380,7 @@ Array Operation
mindspore.ops.tile
mindspore.ops.top_k
mindspore.ops.transpose
mindspore.ops.unbind
mindspore.ops.unfold
mindspore.ops.unique
mindspore.ops.unique_consecutive

View File

@ -221,6 +221,7 @@ BuiltInTypeMap &GetMethodMap() {
{"nonzero", std::string("nonzero")}, // nonzero()
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"unbind", std::string("unbind")}, // P.Unstack()
{"astype", std::string("astype")}, // P.cast()
{"median", std::string("median")}, // P.median()
{"cumsum", std::string("cumsum")}, // P.cumsum()

View File

@ -704,6 +704,39 @@ def squeeze(x, axis=None):
return F.reshape(x, new_shape)
def unbind(x, dim=0):
r"""
Removes a tensor dimension in specified axis.
Unstacks a tensor of rank `R` along axis dimension, and output tensors will have rank `(R-1)`.
Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
Args:
x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`.
A tensor to be unstacked and the rank of the tensor must be greater than 0.
dim (int): Dimension along which to unpack. Negative values wrap around. The range is [-R, R). Default: 0.
Returns:
A tuple of tensors, the shape of each objects is the same.
Raises:
ValueError: If axis is out of the range [-R, R).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
>>> output = x.unbind()
>>> print(output)
(Tensor(shape=[3], dtype=Int64, value=[1, 2, 3]), Tensor(shape=[3], dtype=Int64, value=[4, 5, 6]),
Tensor(shape=[3], dtype=Int64, value=[7, 8, 9]))
"""
return P.Unstack(axis=dim)(x)
def argmax(x, axis=None):
"""
Returns the indices of the maximum values along an axis.

View File

@ -3321,6 +3321,37 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('dense_to_sparse_csr')(self)
def unbind(self, dim=0):
r"""
Removes a tensor dimension in specified axis.
Unstack a tensor of rank `R` along axis dimension, and output tensors will have rank `(R-1)`.
Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
Args:
dim (int): Dimension along which to unpack. Negative values wrap around. The range is [-R, R). Default: 0.
Returns:
A tuple of tensors, the shape of each objects is the same.
Raises:
ValueError: If `dim` is out of the range [-R, R).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
>>> output = x.unbind()
>>> print(output)
(Tensor(shape=[3], dtype=Int64, value=[1, 2, 3]), Tensor(shape=[3], dtype=Int64, value=[4, 5, 6]),
Tensor(shape=[3], dtype=Int64, value=[7, 8, 9]))
"""
self._init_check()
return tensor_operator_registry.get('unbind')(dim)(self)
def unsorted_segment_min(self, segment_ids, num_segments):
r"""
For details, please refer to :func:`mindspore.ops.unsorted_segment_min`.

View File

@ -50,6 +50,7 @@ from .array_func import (
flatten,
concat,
stack,
unbind,
unstack,
tensor_slice,
strided_slice,

View File

@ -1451,6 +1451,40 @@ def unstack(input_x, axis=0):
return _unstack(input_x)
def unbind(x, dim=0):
r"""
Removes a tensor dimension in specified axis.
Unstacks a tensor of rank `R` along axis dimension, and output tensors will have rank `(R-1)`.
Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
Args:
x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`.
A tensor to be unstacked and the rank of the tensor must be greater than 0.
dim (int): Dimension along which to unpack. Negative values wrap around. The range is [-R, R). Default: 0.
Returns:
A tuple of tensors, the shape of each objects is the same.
Raises:
ValueError: If axis is out of the range [-R, R).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
>>> output = ops.unbind(x, dim=0)
>>> print(output)
(Tensor(shape=[3], dtype=Int64, value=[1, 2, 3]), Tensor(shape=[3], dtype=Int64, value=[4, 5, 6]),
Tensor(shape=[3], dtype=Int64, value=[7, 8, 9]))
"""
_unstack = _get_cache_prim(P.Unstack)(dim)
return _unstack(x)
def expand_dims(input_x, axis):
"""
Adds an additional dimension to `input_x` at the given axis.
@ -4885,6 +4919,7 @@ __all__ = [
'slice',
'concat',
'stack',
'unbind',
'unstack',
'scalar_cast',
'scalar_to_tensor',

View File

@ -458,6 +458,7 @@ tensor_operator_registry.register('gather_elements', gather_elements)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', stack)
tensor_operator_registry.register('unstack', unstack)
tensor_operator_registry.register('unbind', P.Unstack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('lerp', lerp)
tensor_operator_registry.register('floor', floor)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops unbind """
import numpy as np
import pytest
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore import nn
class UnbindNet(nn.Cell):
def construct(self, x, dim):
return ops.unbind(x, dim)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_unbind(mode):
"""
Feature: unbind
Description: Verify the result of unbind
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
dim = 0
unbind = UnbindNet()
output = unbind(x, dim)
for i in range(len(x)):
assert np.allclose(output[i].asnumpy(), x[i].asnumpy(), rtol=0.0001)

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_unbind """
import numpy as np
import pytest
from mindspore import context, Tensor
from mindspore import nn
class UnbindNet(nn.Cell):
def construct(self, x, dim):
return x.unbind(dim)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_unbind(mode):
"""
Feature: unbind
Description: Verify the result of unbind
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
dim = 0
unbind = UnbindNet()
output = unbind(x, dim)
for i in range(len(x)):
assert np.allclose(output[i].asnumpy(), x[i].asnumpy(), rtol=0.0001)