diff --git a/docs/api/api_python/nn/mindspore.nn.ResizeBilinear.rst b/docs/api/api_python/nn/mindspore.nn.ResizeBilinear.rst index e6f8d3c1bf0..21127566aea 100644 --- a/docs/api/api_python/nn/mindspore.nn.ResizeBilinear.rst +++ b/docs/api/api_python/nn/mindspore.nn.ResizeBilinear.rst @@ -5,6 +5,10 @@ mindspore.nn.ResizeBilinear 使用双线性插值调整输入Tensor为指定的大小。 + **参数:** + + - **half_pixel_centers** (bool) - 是否几何中心对齐。如果设置为True, 那么`scale_factor`应该设置为False。默认值:False。 + **输入:** - **x** (Tensor) - ResizeBilinear的输入,四维的Tensor,其shape为 :math:`(batch, channels, height, width)` ,数据类型为float16或float32。 diff --git a/mindspore/python/mindspore/nn/layer/basic.py b/mindspore/python/mindspore/nn/layer/basic.py index c2cfcd9b2c1..c948efd077c 100644 --- a/mindspore/python/mindspore/nn/layer/basic.py +++ b/mindspore/python/mindspore/nn/layer/basic.py @@ -32,8 +32,8 @@ from mindspore._checkparam import Rel, Validator from ..cell import Cell from .activation import get_activation -__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', - 'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll'] +__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu', + 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll'] class L1Regularizer(Cell): @@ -851,6 +851,10 @@ class ResizeBilinear(Cell): r""" Samples the input tensor to the given size or scale_factor by using bilinear interpolate. + Args: + half_pixel_centers (bool): Whether half pixel center. If set to True, `align_corners` should be False. + Default: False. + Inputs: - **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape :math:`(batch, channels, height, width)`, with data type of float16 or float32. @@ -862,8 +866,6 @@ class ResizeBilinear(Cell): - **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`. Default: False. - - **half_pixel_centers** (bool): Whether half pixel center. If set to True, `align_corners` should be False. - Default: False. Outputs: Resized tensor. diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py new file mode 100644 index 00000000000..602527cd720 --- /dev/null +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py new file mode 100644 index 00000000000..73b99e87cfe --- /dev/null +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -0,0 +1,100 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Operators for function.""" + +from mindspore.ops.primitive import constexpr +from mindspore.ops import operations as P + + +@constexpr +def get_x_shape(x_shape): + s = 1 + for i in x_shape: + s = s * i + return (s,) + + +def unique(x): + """ + Returns the unique elements of input tensor and also return a tensor containing the index of each value of input + tensor corresponding to the output unique tensor. + + The output contains Tensor `y` and Tensor `idx`, the format is probably similar to (`y`, `idx`). + The shape of Tensor `y` and Tensor `idx` is different in most cases, because Tensor `y` will be deduplicated, + and the shape of Tensor `idx` is consistent with the input. + + To get the same shape between `idx` and `y`, please ref to :class:'mindspore.ops.UniqueWithPad' operator. + + .. warning:: + This module is in beta. + + Args: + x (Tensor): The input tensor. + The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions. + + Returns: + Tuple, containing Tensor objects `(y, idx), `y` is a tensor with the + same type as `input_x`, and contains the unique elements in `x`, sorted in + ascending order. `idx` is a tensor containing indices of elements in + the input corresponding to the output tensor, have the same shape with `input_x`. + + Raises: + TypeError: If `input_x` is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore + >>> import numpy as np + >>> from mindspore import Tensor, nn + >>> from mindspore import ops + >>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32) + >>> output = ops.unique(input_x) + >>> print(output) + (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) + >>> y = output[0] + >>> print(y) + [1 2 5] + >>> idx = output[1] + >>> print(idx) + [0 1 2 1] + >>> # As can be seen from the above, y and idx shape + >>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode. + >>> class UniqueNet(nn.Cell): + ... def __init__(self): + ... super(UniqueNet, self).__init__() + ... + ... def construct(self, x): + ... output, indices = ops.unique(x) + ... return output, indices + ... + >>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32) + >>> net = UniqueNet() + >>> output = net(input_x) + >>> print(output) + (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) + """ + + unique_op = P.Unique() + reshape_op = P.Reshape() + + shape_x = x.shape + length_x = get_x_shape(shape_x) + x = reshape_op(x, length_x) + y, idx = unique_op(x) + idx = reshape_op(idx, shape_x) + return y, idx diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 04b8afcc786..267eb845774 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -19,6 +19,7 @@ Primitive operator classes. A collection of operators to build neural networks or to compute functions. """ +from ..function.array_func import (unique) from . import _quant_ops from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter, MapUniform, DynamicAssign, PadAndShift) @@ -127,6 +128,7 @@ from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, L ConstrainForceVirial, ConstrainForce, Constrain) __all__ = [ + 'unique', 'HSVToRGB', 'CeLU', 'Ger', diff --git a/tests/st/dynamic_shape/test_unique.py b/tests/st/dynamic_shape/test_unique.py index 65461f1a83f..e3d1a51ae10 100644 --- a/tests/st/dynamic_shape/test_unique.py +++ b/tests/st/dynamic_shape/test_unique.py @@ -15,6 +15,7 @@ import numpy as np import pytest import mindspore.context as context +from mindspore import ops import mindspore.nn as nn from mindspore import Tensor import mindspore.common.dtype as mstype @@ -22,6 +23,7 @@ from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -30,6 +32,16 @@ class Net(nn.Cell): def construct(self, x): return self.unique(x) + +class NetFunc(nn.Cell): + def __init__(self): + super(NetFunc, self).__init__() + self.unique = ops.unique + + def construct(self, x): + return self.unique(x) + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -42,3 +54,41 @@ def test_unqiue(): expect2 = np.array([0, 0, 1, 1, 2, 2]) assert (output[0].asnumpy() == expect1).all() assert (output[1].asnumpy() == expect2).all() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unqiue_func_1d(): + """ + Feature: Test unique function + Description: Input 1D Tensor + Expectation: Successful execution. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32) + unique = NetFunc() + output = unique(x) + expect1 = np.array([1, 2, 3]) + expect2 = np.array([0, 0, 1, 1, 2, 2]) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unqiue_func_2d(): + """ + Feature: Test unique function + Description: Input 2D Tensor + Expectation: Successful execution. + """ + x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32) + unique = NetFunc() + output = unique(x) + expect1 = np.array([1, 2, 3]) + expect2 = np.array([[0, 0, 1], [1, 2, 2]]) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() diff --git a/tests/st/dynamic_shape/test_unique_cpu.py b/tests/st/dynamic_shape/test_unique_cpu.py index 6e2b57ab099..446f1cd7dfb 100644 --- a/tests/st/dynamic_shape/test_unique_cpu.py +++ b/tests/st/dynamic_shape/test_unique_cpu.py @@ -15,6 +15,7 @@ import numpy as np import pytest import mindspore.context as context +from mindspore import ops import mindspore.nn as nn from mindspore import Tensor import mindspore.common.dtype as mstype @@ -32,6 +33,15 @@ class Net(nn.Cell): return self.unique(x) +class NetFunc(nn.Cell): + def __init__(self): + super(NetFunc, self).__init__() + self.unique = ops.unique + + def construct(self, x): + return self.unique(x) + + class UniqueSquare(nn.Cell): def __init__(self): super(UniqueSquare, self).__init__() @@ -67,3 +77,41 @@ def test_unique_square(): output = net(x) expect1 = np.array([1, 4, 9]) assert (output.asnumpy() == expect1).all() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unqiue_func_1d(): + """ + Feature: Test unique function + Description: Input 1D Tensor + Expectation: Successful execution. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32) + unique = NetFunc() + output = unique(x) + expect1 = np.array([1, 2, 3]) + expect2 = np.array([0, 0, 1, 1, 2, 2]) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unqiue_func_2d(): + """ + Feature: Test unique function + Description: Input 2D Tensor + Expectation: Successful execution. + """ + x = Tensor(np.array([[1, 1, 2], [2, 3, 3]]), mstype.int32) + unique = NetFunc() + output = unique(x) + expect1 = np.array([1, 2, 3]) + expect2 = np.array([[0, 0, 1], [1, 2, 2]]) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index 2c971e1af99..9fa0c7d3a98 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -21,6 +21,7 @@ import mindspore.context as context from mindspore import Tensor from mindspore.common import dtype as mstype from mindspore.nn import Cell +from mindspore import ops from mindspore.ops import operations as P from mindspore.ops import prim_attr_register from mindspore.ops.operations import _inner_ops as inner @@ -329,6 +330,26 @@ class TensorShapeNet(Cell): return self.shape(x) +class UniqueFunc1(Cell): + def __init__(self): + super(UniqueFunc1, self).__init__() + self.unique = ops.unique + + def construct(self, x): + y, idx = self.unique(x) + return y, idx + + +class UniqueFunc2(Cell): + def __init__(self): + super(UniqueFunc2, self).__init__() + self.unique = ops.unique + + def construct(self, x): + y, idx = self.unique(x) + return y, idx + + class RangeNet(Cell): def __init__(self): super(RangeNet, self).__init__() @@ -348,6 +369,12 @@ test_case_array_ops = [ ('CustNet3', { 'block': CustNet3(), 'desc_inputs': []}), + ('Unique', { + 'block': UniqueFunc1(), + 'desc_inputs': [Tensor(np.array([2, 2, 1]), dtype=ms.int32)]}), + ('Unique', { + 'block': UniqueFunc2(), + 'desc_inputs': [Tensor(np.array([[2, 2], [1, 3]]), dtype=ms.int32)]}), ('MathBinaryNet1', { 'block': MathBinaryNet1(), 'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),