unique support multi-dim tensor

This commit is contained in:
jiangzhenguang 2022-02-07 19:58:25 +08:00
parent 438a4081fb
commit 28f89f36b4
8 changed files with 251 additions and 4 deletions

View File

@ -5,6 +5,10 @@ mindspore.nn.ResizeBilinear
使用双线性插值调整输入Tensor为指定的大小。 使用双线性插值调整输入Tensor为指定的大小。
**参数:**
- **half_pixel_centers** (bool) - 是否几何中心对齐。如果设置为True, 那么`scale_factor`应该设置为False。默认值False。
**输入:** **输入:**
- **x** (Tensor) - ResizeBilinear的输入四维的Tensor其shape为 :math:`(batch, channels, height, width)` 数据类型为float16或float32。 - **x** (Tensor) - ResizeBilinear的输入四维的Tensor其shape为 :math:`(batch, channels, height, width)` 数据类型为float16或float32。

View File

@ -32,8 +32,8 @@ from mindspore._checkparam import Rel, Validator
from ..cell import Cell from ..cell import Cell
from .activation import get_activation from .activation import get_activation
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll'] 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
class L1Regularizer(Cell): class L1Regularizer(Cell):
@ -851,6 +851,10 @@ class ResizeBilinear(Cell):
r""" r"""
Samples the input tensor to the given size or scale_factor by using bilinear interpolate. Samples the input tensor to the given size or scale_factor by using bilinear interpolate.
Args:
half_pixel_centers (bool): Whether half pixel center. If set to True, `align_corners` should be False.
Default: False.
Inputs: Inputs:
- **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape - **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape
:math:`(batch, channels, height, width)`, with data type of float16 or float32. :math:`(batch, channels, height, width)`, with data type of float16 or float32.
@ -862,8 +866,6 @@ class ResizeBilinear(Cell):
- **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly - **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly
aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`. aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`.
Default: False. Default: False.
- **half_pixel_centers** (bool): Whether half pixel center. If set to True, `align_corners` should be False.
Default: False.
Outputs: Outputs:
Resized tensor. Resized tensor.

View File

@ -0,0 +1,14 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,100 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for function."""
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
@constexpr
def get_x_shape(x_shape):
s = 1
for i in x_shape:
s = s * i
return (s,)
def unique(x):
"""
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
tensor corresponding to the output unique tensor.
The output contains Tensor `y` and Tensor `idx`, the format is probably similar to (`y`, `idx`).
The shape of Tensor `y` and Tensor `idx` is different in most cases, because Tensor `y` will be deduplicated,
and the shape of Tensor `idx` is consistent with the input.
To get the same shape between `idx` and `y`, please ref to :class:'mindspore.ops.UniqueWithPad' operator.
.. warning::
This module is in beta.
Args:
x (Tensor): The input tensor.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Returns:
Tuple, containing Tensor objects `(y, idx), `y` is a tensor with the
same type as `input_x`, and contains the unique elements in `x`, sorted in
ascending order. `idx` is a tensor containing indices of elements in
the input corresponding to the output tensor, have the same shape with `input_x`.
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> from mindspore import ops
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> output = ops.unique(input_x)
>>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
>>> y = output[0]
>>> print(y)
[1 2 5]
>>> idx = output[1]
>>> print(idx)
[0 1 2 1]
>>> # As can be seen from the above, y and idx shape
>>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode.
>>> class UniqueNet(nn.Cell):
... def __init__(self):
... super(UniqueNet, self).__init__()
...
... def construct(self, x):
... output, indices = ops.unique(x)
... return output, indices
...
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> net = UniqueNet()
>>> output = net(input_x)
>>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
"""
unique_op = P.Unique()
reshape_op = P.Reshape()
shape_x = x.shape
length_x = get_x_shape(shape_x)
x = reshape_op(x, length_x)
y, idx = unique_op(x)
idx = reshape_op(idx, shape_x)
return y, idx

View File

@ -19,6 +19,7 @@ Primitive operator classes.
A collection of operators to build neural networks or to compute functions. A collection of operators to build neural networks or to compute functions.
""" """
from ..function.array_func import (unique)
from . import _quant_ops from . import _quant_ops
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter, from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
MapUniform, DynamicAssign, PadAndShift) MapUniform, DynamicAssign, PadAndShift)
@ -127,6 +128,7 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L
ConstrainForceVirial, ConstrainForce, Constrain) ConstrainForceVirial, ConstrainForce, Constrain)
__all__ = [ __all__ = [
'unique',
'HSVToRGB', 'HSVToRGB',
'CeLU', 'CeLU',
'Ger', 'Ger',

View File

@ -15,6 +15,7 @@
import numpy as np import numpy as np
import pytest import pytest
import mindspore.context as context import mindspore.context as context
from mindspore import ops
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@ -22,6 +23,7 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -30,6 +32,16 @@ class Net(nn.Cell):
def construct(self, x): def construct(self, x):
return self.unique(x) return self.unique(x)
class NetFunc(nn.Cell):
def __init__(self):
super(NetFunc, self).__init__()
self.unique = ops.unique
def construct(self, x):
return self.unique(x)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@ -42,3 +54,41 @@ def test_unqiue():
expect2 = np.array([0, 0, 1, 1, 2, 2]) expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all() assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all() assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_1d():
"""
Feature: Test unique function
Description: Input 1D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_2d():
"""
Feature: Test unique function
Description: Input 2D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()

View File

@ -15,6 +15,7 @@
import numpy as np import numpy as np
import pytest import pytest
import mindspore.context as context import mindspore.context as context
from mindspore import ops
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@ -32,6 +33,15 @@ class Net(nn.Cell):
return self.unique(x) return self.unique(x)
class NetFunc(nn.Cell):
def __init__(self):
super(NetFunc, self).__init__()
self.unique = ops.unique
def construct(self, x):
return self.unique(x)
class UniqueSquare(nn.Cell): class UniqueSquare(nn.Cell):
def __init__(self): def __init__(self):
super(UniqueSquare, self).__init__() super(UniqueSquare, self).__init__()
@ -67,3 +77,41 @@ def test_unique_square():
output = net(x) output = net(x)
expect1 = np.array([1, 4, 9]) expect1 = np.array([1, 4, 9])
assert (output.asnumpy() == expect1).all() assert (output.asnumpy() == expect1).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_1d():
"""
Feature: Test unique function
Description: Input 1D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([0, 0, 1, 1, 2, 2])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unqiue_func_2d():
"""
Feature: Test unique function
Description: Input 2D Tensor
Expectation: Successful execution.
"""
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
unique = NetFunc()
output = unique(x)
expect1 = np.array([1, 2, 3])
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()

View File

@ -21,6 +21,7 @@ import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore import ops
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register from mindspore.ops import prim_attr_register
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
@ -329,6 +330,26 @@ class TensorShapeNet(Cell):
return self.shape(x) return self.shape(x)
class UniqueFunc1(Cell):
def __init__(self):
super(UniqueFunc1, self).__init__()
self.unique = ops.unique
def construct(self, x):
y, idx = self.unique(x)
return y, idx
class UniqueFunc2(Cell):
def __init__(self):
super(UniqueFunc2, self).__init__()
self.unique = ops.unique
def construct(self, x):
y, idx = self.unique(x)
return y, idx
class RangeNet(Cell): class RangeNet(Cell):
def __init__(self): def __init__(self):
super(RangeNet, self).__init__() super(RangeNet, self).__init__()
@ -348,6 +369,12 @@ test_case_array_ops = [
('CustNet3', { ('CustNet3', {
'block': CustNet3(), 'block': CustNet3(),
'desc_inputs': []}), 'desc_inputs': []}),
('Unique', {
'block': UniqueFunc1(),
'desc_inputs': [Tensor(np.array([2, 2, 1]), dtype=ms.int32)]}),
('Unique', {
'block': UniqueFunc2(),
'desc_inputs': [Tensor(np.array([[2, 2], [1, 3]]), dtype=ms.int32)]}),
('MathBinaryNet1', { ('MathBinaryNet1', {
'block': MathBinaryNet1(), 'block': MathBinaryNet1(),
'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}), 'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),