forked from mindspore-Ecosystem/mindspore
unique support multi-dim tensor
This commit is contained in:
parent
438a4081fb
commit
28f89f36b4
|
@ -5,6 +5,10 @@ mindspore.nn.ResizeBilinear
|
|||
|
||||
使用双线性插值调整输入Tensor为指定的大小。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **half_pixel_centers** (bool) - 是否几何中心对齐。如果设置为True, 那么`scale_factor`应该设置为False。默认值:False。
|
||||
|
||||
**输入:**
|
||||
|
||||
- **x** (Tensor) - ResizeBilinear的输入,四维的Tensor,其shape为 :math:`(batch, channels, height, width)` ,数据类型为float16或float32。
|
||||
|
|
|
@ -32,8 +32,8 @@ from mindspore._checkparam import Rel, Validator
|
|||
from ..cell import Cell
|
||||
from .activation import get_activation
|
||||
|
||||
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
|
||||
'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
|
||||
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
|
||||
'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
|
||||
|
||||
|
||||
class L1Regularizer(Cell):
|
||||
|
@ -851,6 +851,10 @@ class ResizeBilinear(Cell):
|
|||
r"""
|
||||
Samples the input tensor to the given size or scale_factor by using bilinear interpolate.
|
||||
|
||||
Args:
|
||||
half_pixel_centers (bool): Whether half pixel center. If set to True, `align_corners` should be False.
|
||||
Default: False.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape
|
||||
:math:`(batch, channels, height, width)`, with data type of float16 or float32.
|
||||
|
@ -862,8 +866,6 @@ class ResizeBilinear(Cell):
|
|||
- **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly
|
||||
aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`.
|
||||
Default: False.
|
||||
- **half_pixel_centers** (bool): Whether half pixel center. If set to True, `align_corners` should be False.
|
||||
Default: False.
|
||||
|
||||
Outputs:
|
||||
Resized tensor.
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Operators for function."""
|
||||
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_x_shape(x_shape):
|
||||
s = 1
|
||||
for i in x_shape:
|
||||
s = s * i
|
||||
return (s,)
|
||||
|
||||
|
||||
def unique(x):
|
||||
"""
|
||||
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
|
||||
tensor corresponding to the output unique tensor.
|
||||
|
||||
The output contains Tensor `y` and Tensor `idx`, the format is probably similar to (`y`, `idx`).
|
||||
The shape of Tensor `y` and Tensor `idx` is different in most cases, because Tensor `y` will be deduplicated,
|
||||
and the shape of Tensor `idx` is consistent with the input.
|
||||
|
||||
To get the same shape between `idx` and `y`, please ref to :class:'mindspore.ops.UniqueWithPad' operator.
|
||||
|
||||
.. warning::
|
||||
This module is in beta.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
||||
|
||||
Returns:
|
||||
Tuple, containing Tensor objects `(y, idx), `y` is a tensor with the
|
||||
same type as `input_x`, and contains the unique elements in `x`, sorted in
|
||||
ascending order. `idx` is a tensor containing indices of elements in
|
||||
the input corresponding to the output tensor, have the same shape with `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor, nn
|
||||
>>> from mindspore import ops
|
||||
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
|
||||
>>> output = ops.unique(input_x)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
|
||||
>>> y = output[0]
|
||||
>>> print(y)
|
||||
[1 2 5]
|
||||
>>> idx = output[1]
|
||||
>>> print(idx)
|
||||
[0 1 2 1]
|
||||
>>> # As can be seen from the above, y and idx shape
|
||||
>>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode.
|
||||
>>> class UniqueNet(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(UniqueNet, self).__init__()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... output, indices = ops.unique(x)
|
||||
... return output, indices
|
||||
...
|
||||
>>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
|
||||
>>> net = UniqueNet()
|
||||
>>> output = net(input_x)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
|
||||
"""
|
||||
|
||||
unique_op = P.Unique()
|
||||
reshape_op = P.Reshape()
|
||||
|
||||
shape_x = x.shape
|
||||
length_x = get_x_shape(shape_x)
|
||||
x = reshape_op(x, length_x)
|
||||
y, idx = unique_op(x)
|
||||
idx = reshape_op(idx, shape_x)
|
||||
return y, idx
|
|
@ -19,6 +19,7 @@ Primitive operator classes.
|
|||
A collection of operators to build neural networks or to compute functions.
|
||||
"""
|
||||
|
||||
from ..function.array_func import (unique)
|
||||
from . import _quant_ops
|
||||
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
|
||||
MapUniform, DynamicAssign, PadAndShift)
|
||||
|
@ -127,6 +128,7 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L
|
|||
ConstrainForceVirial, ConstrainForce, Constrain)
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'HSVToRGB',
|
||||
'CeLU',
|
||||
'Ger',
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -22,6 +23,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -30,6 +32,16 @@ class Net(nn.Cell):
|
|||
def construct(self, x):
|
||||
return self.unique(x)
|
||||
|
||||
|
||||
class NetFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetFunc, self).__init__()
|
||||
self.unique = ops.unique
|
||||
|
||||
def construct(self, x):
|
||||
return self.unique(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -42,3 +54,41 @@ def test_unqiue():
|
|||
expect2 = np.array([0, 0, 1, 1, 2, 2])
|
||||
assert (output[0].asnumpy() == expect1).all()
|
||||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_unqiue_func_1d():
|
||||
"""
|
||||
Feature: Test unique function
|
||||
Description: Input 1D Tensor
|
||||
Expectation: Successful execution.
|
||||
"""
|
||||
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
|
||||
unique = NetFunc()
|
||||
output = unique(x)
|
||||
expect1 = np.array([1, 2, 3])
|
||||
expect2 = np.array([0, 0, 1, 1, 2, 2])
|
||||
assert (output[0].asnumpy() == expect1).all()
|
||||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_unqiue_func_2d():
|
||||
"""
|
||||
Feature: Test unique function
|
||||
Description: Input 2D Tensor
|
||||
Expectation: Successful execution.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
|
||||
unique = NetFunc()
|
||||
output = unique(x)
|
||||
expect1 = np.array([1, 2, 3])
|
||||
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
|
||||
assert (output[0].asnumpy() == expect1).all()
|
||||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -32,6 +33,15 @@ class Net(nn.Cell):
|
|||
return self.unique(x)
|
||||
|
||||
|
||||
class NetFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetFunc, self).__init__()
|
||||
self.unique = ops.unique
|
||||
|
||||
def construct(self, x):
|
||||
return self.unique(x)
|
||||
|
||||
|
||||
class UniqueSquare(nn.Cell):
|
||||
def __init__(self):
|
||||
super(UniqueSquare, self).__init__()
|
||||
|
@ -67,3 +77,41 @@ def test_unique_square():
|
|||
output = net(x)
|
||||
expect1 = np.array([1, 4, 9])
|
||||
assert (output.asnumpy() == expect1).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_unqiue_func_1d():
|
||||
"""
|
||||
Feature: Test unique function
|
||||
Description: Input 1D Tensor
|
||||
Expectation: Successful execution.
|
||||
"""
|
||||
x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
|
||||
unique = NetFunc()
|
||||
output = unique(x)
|
||||
expect1 = np.array([1, 2, 3])
|
||||
expect2 = np.array([0, 0, 1, 1, 2, 2])
|
||||
assert (output[0].asnumpy() == expect1).all()
|
||||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_unqiue_func_2d():
|
||||
"""
|
||||
Feature: Test unique function
|
||||
Description: Input 2D Tensor
|
||||
Expectation: Successful execution.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32)
|
||||
unique = NetFunc()
|
||||
output = unique(x)
|
||||
expect1 = np.array([1, 2, 3])
|
||||
expect2 = np.array([[0, 0, 1], [1, 2, 2]])
|
||||
assert (output[0].asnumpy() == expect1).all()
|
||||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
|
|
@ -21,6 +21,7 @@ import mindspore.context as context
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import prim_attr_register
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
|
@ -329,6 +330,26 @@ class TensorShapeNet(Cell):
|
|||
return self.shape(x)
|
||||
|
||||
|
||||
class UniqueFunc1(Cell):
|
||||
def __init__(self):
|
||||
super(UniqueFunc1, self).__init__()
|
||||
self.unique = ops.unique
|
||||
|
||||
def construct(self, x):
|
||||
y, idx = self.unique(x)
|
||||
return y, idx
|
||||
|
||||
|
||||
class UniqueFunc2(Cell):
|
||||
def __init__(self):
|
||||
super(UniqueFunc2, self).__init__()
|
||||
self.unique = ops.unique
|
||||
|
||||
def construct(self, x):
|
||||
y, idx = self.unique(x)
|
||||
return y, idx
|
||||
|
||||
|
||||
class RangeNet(Cell):
|
||||
def __init__(self):
|
||||
super(RangeNet, self).__init__()
|
||||
|
@ -348,6 +369,12 @@ test_case_array_ops = [
|
|||
('CustNet3', {
|
||||
'block': CustNet3(),
|
||||
'desc_inputs': []}),
|
||||
('Unique', {
|
||||
'block': UniqueFunc1(),
|
||||
'desc_inputs': [Tensor(np.array([2, 2, 1]), dtype=ms.int32)]}),
|
||||
('Unique', {
|
||||
'block': UniqueFunc2(),
|
||||
'desc_inputs': [Tensor(np.array([[2, 2], [1, 3]]), dtype=ms.int32)]}),
|
||||
('MathBinaryNet1', {
|
||||
'block': MathBinaryNet1(),
|
||||
'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),
|
||||
|
|
Loading…
Reference in New Issue