add function and tensor api of unsqueeze
This commit is contained in:
parent
c451641481
commit
c167b4aa92
|
@ -389,6 +389,7 @@ Array操作
|
|||
mindspore.ops.unsorted_segment_min
|
||||
mindspore.ops.unsorted_segment_prod
|
||||
mindspore.ops.unsorted_segment_sum
|
||||
mindspore.ops.unsqueeze
|
||||
mindspore.ops.unstack
|
||||
|
||||
类型转换
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.unsqueeze
|
||||
============================
|
||||
|
||||
.. py:method:: mindspore.Tensor.unsqueeze(dim)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.unsqueeze`。
|
|
@ -200,6 +200,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.unsorted_segment_max
|
||||
mindspore.Tensor.unsorted_segment_min
|
||||
mindspore.Tensor.unsorted_segment_prod
|
||||
mindspore.Tensor.unsqueeze
|
||||
mindspore.Tensor.var
|
||||
mindspore.Tensor.view
|
||||
mindspore.Tensor.xdivy
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.ops.unsqueeze
|
||||
=========================
|
||||
|
||||
.. py:function:: mindspore.ops.unsqueeze(input_x, dim)
|
||||
|
||||
对输入 `input_x` 在给定维上添加额外维度。
|
||||
|
||||
扩展后的Tensor中位置对应 `dim` 的维度为插入的新维度。
|
||||
|
||||
.. note::
|
||||
如果指定的 `dim` 是负数,那么它会从后往前,从1开始计算index。
|
||||
|
||||
参数:
|
||||
- **input_x** (Tensor) - 输入Tensor,shape为 :math:`(x_1, x_2, ..., x_R)`。
|
||||
- **dim** (int) - 新插入的维度的位置。 `dim` 的值必须在范围 `[-input_x.ndim-1, input_x.ndim]` 内。仅接受常量输入。
|
||||
|
||||
返回:
|
||||
Tensor,维度在指定轴扩展之后的Tensor,与 `input_x` 的数据类型相同。如果 `dim` 是0,那么它的shape为 :math:`(1, x_1, x_2, ..., x_R)`。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `dim` 不是int。
|
||||
- **ValueError** - 如果 `dim` 超出了 :math:`[-input_x.ndim-1, input_x.ndim]` 的范围。
|
|
@ -206,6 +206,7 @@
|
|||
mindspore.Tensor.unsorted_segment_max
|
||||
mindspore.Tensor.unsorted_segment_min
|
||||
mindspore.Tensor.unsorted_segment_prod
|
||||
mindspore.Tensor.unsqueeze
|
||||
mindspore.Tensor.var
|
||||
mindspore.Tensor.view
|
||||
mindspore.Tensor.xdivy
|
||||
|
|
|
@ -389,6 +389,7 @@ Array Operation
|
|||
mindspore.ops.unsorted_segment_min
|
||||
mindspore.ops.unsorted_segment_prod
|
||||
mindspore.ops.unsorted_segment_sum
|
||||
mindspore.ops.unsqueeze
|
||||
mindspore.ops.unstack
|
||||
|
||||
Type Conversion
|
||||
|
|
|
@ -222,6 +222,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"unbind", std::string("unbind")}, // P.Unstack()
|
||||
{"unsqueeze", std::string("unsqueeze")}, // P.expand_dims()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
{"median", std::string("median")}, // P.median()
|
||||
{"cumsum", std::string("cumsum")}, // P.cumsum()
|
||||
|
|
|
@ -2358,12 +2358,20 @@ def broadcast_to(x, shape):
|
|||
|
||||
def expand_dims(x, axis):
|
||||
"""
|
||||
Insert a dimension of shape 1 at the specified axis of Tensor
|
||||
Insert a dimension of shape 1 at the specified axis of Tensor.
|
||||
"""
|
||||
check_is_int(axis, 'axis')
|
||||
return P.ExpandDims()(x, axis)
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
"""
|
||||
Insert a dimension of shape 1 at the specified axis of Tensor.
|
||||
"""
|
||||
check_is_int(dim, 'dim')
|
||||
return P.ExpandDims()(x, dim)
|
||||
|
||||
|
||||
def masked_fill(x, mask, value):
|
||||
"""
|
||||
Fills elements of Tensor with value where mask is True.
|
||||
|
|
|
@ -1746,6 +1746,15 @@ class Tensor(Tensor_):
|
|||
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.unsqueeze`.
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_is_int(dim, 'dim')
|
||||
validator.check_int_range(dim, -self.ndim - 1, self.ndim + 1, Rel.INC_LEFT, 'dim')
|
||||
return tensor_operator_registry.get('unsqueeze')(self, dim)
|
||||
|
||||
def expand_dims(self, axis):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.expand_dims`.
|
||||
|
|
|
@ -59,6 +59,7 @@ from .array_func import (
|
|||
tuple_to_array,
|
||||
expand_dims,
|
||||
squeeze,
|
||||
unsqueeze,
|
||||
transpose,
|
||||
scatter_nd,
|
||||
scatter_nd_add,
|
||||
|
|
|
@ -1520,6 +1520,41 @@ def expand_dims(input_x, axis):
|
|||
return expand_dims_(input_x, axis)
|
||||
|
||||
|
||||
def unsqueeze(input_x, dim):
|
||||
"""
|
||||
Adds an additional dimension to `input_x` at the given dim.
|
||||
|
||||
Note:
|
||||
If the specified dim is a negative number, the index is counted
|
||||
backward from the end and starts at 1.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
dim (int): Specifies the dimension index at which to expand
|
||||
the shape of `input_x`. The value of `dim` must be in the range
|
||||
`[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.
|
||||
|
||||
Returns:
|
||||
Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
|
||||
value of `dim` is 0. It has the same data type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dim` is not an int.
|
||||
ValueError: If `dim` is not in the valid range :math:`[-input_x.ndim-1, input_x.ndim]`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
|
||||
>>> output = ops.unsqueeze(input_tensor, dim=0)
|
||||
>>> print(output)
|
||||
[[[2. 2.]
|
||||
[2. 2.]]]
|
||||
"""
|
||||
return expand_dims_(input_x, dim)
|
||||
|
||||
|
||||
def squeeze(input_x, axis=()):
|
||||
"""
|
||||
Return the Tensor after deleting the dimension of size 1 in the specified `axis`.
|
||||
|
@ -4928,6 +4963,7 @@ __all__ = [
|
|||
'tuple_to_array',
|
||||
'expand_dims',
|
||||
'squeeze',
|
||||
'unsqueeze',
|
||||
'transpose',
|
||||
'scatter_nd',
|
||||
'scatter_nd_add',
|
||||
|
|
|
@ -441,6 +441,7 @@ tensor_operator_registry.register('gt', P.Greater)
|
|||
tensor_operator_registry.register('ge', P.GreaterEqual)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
tensor_operator_registry.register('squeeze', squeeze)
|
||||
tensor_operator_registry.register('unsqueeze', unsqueeze)
|
||||
tensor_operator_registry.register('expand_dims', expand_dims)
|
||||
# support GE backend for no compare operators
|
||||
tensor_operator_registry.register('cast', cast)
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, dim):
|
||||
out = ops.unsqueeze(x, dim)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_unsqueeze_normal(mode):
|
||||
"""
|
||||
Feature: unqueeze
|
||||
Description: Verify the result of unqueeze
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.arange(2 * 3).reshape((2, 3)), dtype=ms.float32)
|
||||
out = net(x, dim=0)
|
||||
expect_out = np.array([np.arange(2 * 3).reshape((2, 3))])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
Loading…
Reference in New Issue