unique support multi-dim tensor

This commit is contained in:
jiangzhenguang 2022-02-07 19:58:25 +08:00
parent 438a4081fb
commit 28f89f36b4
8 changed files with 251 additions and 4 deletions

View File

@ -5,6 +5,10 @@ mindspore.nn.ResizeBilinear
使用双线性插值调整输入Tensor为指定的大小。
**参数:**
- **half_pixel_centers** (bool) - 是否几何中心对齐。如果设置为True, 那么`scale_factor`应该设置为False。默认值False。
**输入:**
- **x** (Tensor) - ResizeBilinear的输入四维的Tensor其shape为 :math:`(batch, channels, height, width)` 数据类型为float16或float32。

View File

@ -32,8 +32,8 @@ from mindspore._checkparam import Rel, Validator
from ..cell import Cell
from .activation import get_activation
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
class L1Regularizer(Cell):
@ -851,6 +851,10 @@ class ResizeBilinear(Cell):
r"""
Samples the input tensor to the given size or scale_factor by using bilinear interpolate.
Args:
half_pixel_centers (bool): Whether half pixel center. If set to True, `align_corners` should be False.
Default: False.
Inputs:
- **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape
:math:`(batch, channels, height, width)`, with data type of float16 or float32.
@ -862,8 +866,6 @@ class ResizeBilinear(Cell):
- **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly
aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`.
Default: False.
- **half_pixel_centers** (bool): Whether half pixel center. If set to True, `align_corners` should be False.
Default: False.
Outputs:
Resized tensor.

View File

@ -0,0 +1,14 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,100 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for function."""
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
@constexpr
def get_x_shape(x_shape):
s = 1
for i in x_shape:
s = s * i
return (s,)
def unique(x):
"""
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
tensor corresponding to the output unique tensor.
The output contains Tensor `y` and Tensor `idx`, the format is probably similar to (`y`, `idx`).
The shape of Tensor `y` and Tensor `idx` is different in most cases, because Tensor `y` will be deduplicated,
and the shape of Tensor `idx` is consistent with the input.
To get the same shape between `idx` and `y`, please ref to :class:'mindspore.ops.UniqueWithPad' operator.
.. warning::
This module is in beta.
Args:
x (Tensor): The input tensor.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Returns:
Tuple, containing Tensor objects `(y, idx), `y` is a tensor with the
same type as `input_x`, and contains the unique elements in `x`, sorted in
ascending order. `idx` is a tensor containing indices of elements in
the input corresponding to the output tensor, have the same shape with `input_x`.
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> from mindspore import ops
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> output = ops.unique(input_x)
>>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
>>> y = output[0]
>>> print(y)
[1 2 5]
>>> idx = output[1]
>>> print(idx)
[0 1 2 1]
>>> # As can be seen from the above, y and idx shape
>>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode.
>>> class UniqueNet(nn.Cell):
... def __init__(self):
... super(UniqueNet, self).__init__()
...
... def construct(self, x):
... output, indices = ops.unique(x)
... return output, indices
...
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> net = UniqueNet()
>>> output = net(input_x)
>>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
"""
unique_op = P.Unique()
reshape_op = P.Reshape()
shape_x = x.shape
length_x = get_x_shape(shape_x)
x = reshape_op(x, length_x)
y, idx = unique_op(x)
idx = reshape_op(idx, shape_x)
return y, idx

View File

@ -19,6 +19,7 @@ Primitive operator classes.
A collection of operators to build neural networks or to compute functions.
"""
from ..function.array_func import (unique)
from . import _quant_ops
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
MapUniform, DynamicAssign, PadAndShift)
@ -127,6 +128,7 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L
ConstrainForceVirial, ConstrainForce, Constrain)
__all__ = [
'unique',
'HSVToRGB',
'CeLU',
'Ger',

View File

@ -15,6 +15,7 @@
import numpy as np
import pytest
import mindspore.context as context
from mindspore import ops
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
@ -22,6 +23,7 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -30,6 +32,16 @@ class Net(nn.Cell):
def construct(self, x):
return self.unique(x)
class NetFunc(nn.Cell):
def __init__(self):
super(NetFunc, self).__init__()
self.unique = ops.unique
def construct(self, x):
return self.unique(x)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -42,3 +54,41 @@ def test_unqiue():
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_1d():
"""
Feature: Test unique function
Description: Input 1D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_2d():
"""
Feature: Test unique function
Description: Input 2D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()

View File

@ -15,6 +15,7 @@
import numpy as np
import pytest
import mindspore.context as context
from mindspore import ops
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
@ -32,6 +33,15 @@ class Net(nn.Cell):
return self.unique(x)
class NetFunc(nn.Cell):
def __init__(self):
super(NetFunc, self).__init__()
self.unique = ops.unique
def construct(self, x):
return self.unique(x)
class UniqueSquare(nn.Cell):
def __init__(self):
super(UniqueSquare, self).__init__()
@ -67,3 +77,41 @@ def test_unique_square():
output = net(x)
expect1 = np.array([1, 4, 9])
assert (output.asnumpy() == expect1).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_1d():
"""
Feature: Test unique function
Description: Input 1D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_2d():
"""
Feature: Test unique function
Description: Input 2D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()

View File

@ -21,6 +21,7 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
from mindspore import ops
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register
from mindspore.ops.operations import _inner_ops as inner
@ -329,6 +330,26 @@ class TensorShapeNet(Cell):
return self.shape(x)
class UniqueFunc1(Cell):
def __init__(self):
super(UniqueFunc1, self).__init__()
self.unique = ops.unique
def construct(self, x):
y, idx = self.unique(x)
return y, idx
class UniqueFunc2(Cell):
def __init__(self):
super(UniqueFunc2, self).__init__()
self.unique = ops.unique
def construct(self, x):
y, idx = self.unique(x)
return y, idx
class RangeNet(Cell):
def __init__(self):
super(RangeNet, self).__init__()
@ -348,6 +369,12 @@ test_case_array_ops = [
('CustNet3', {
'block': CustNet3(),
'desc_inputs': []}),
('Unique', {
'block': UniqueFunc1(),
'desc_inputs': [Tensor(np.array([2, 2, 1]), dtype=ms.int32)]}),
('Unique', {
'block': UniqueFunc2(),
'desc_inputs': [Tensor(np.array([[2, 2], [1, 3]]), dtype=ms.int32)]}),
('MathBinaryNet1', {
'block': MathBinaryNet1(),
'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),