From 8031beb260aff8ef1d7a746e5da730eecdc620fd Mon Sep 17 00:00:00 2001 From: guozhibin Date: Thu, 12 Jan 2023 15:11:50 +0800 Subject: [PATCH] add ops.swapaxes and ops.swapdims --- docs/api/api_python/mindspore.ops.rst | 2 + .../ops/mindspore.ops.func_swapaxes.rst | 19 ++++ .../ops/mindspore.ops.func_swapdims.rst | 20 +++++ docs/api/api_python_en/mindspore.ops.rst | 2 + mindspore/python/mindspore/_checkparam.py | 6 +- mindspore/python/mindspore/common/tensor.py | 17 +--- .../python/mindspore/ops/function/__init__.py | 2 + .../mindspore/ops/function/array_func.py | 90 +++++++++++++++++++ mindspore/python/mindspore/ops/functional.py | 1 + tests/st/ops/test_func_swapaxes.py | 63 +++++++++++++ tests/st/ops/test_func_swapdims.py | 63 +++++++++++++ 11 files changed, 266 insertions(+), 19 deletions(-) create mode 100644 docs/api/api_python/ops/mindspore.ops.func_swapaxes.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_swapdims.rst create mode 100644 tests/st/ops/test_func_swapaxes.py create mode 100644 tests/st/ops/test_func_swapdims.py diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index 5030664f813..61edecc4f49 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -502,6 +502,8 @@ Array操作 mindspore.ops.stack mindspore.ops.strided_slice mindspore.ops.sum + mindspore.ops.swapaxes + mindspore.ops.swapdims mindspore.ops.tensor_scatter_add mindspore.ops.tensor_scatter_div mindspore.ops.tensor_scatter_max diff --git a/docs/api/api_python/ops/mindspore.ops.func_swapaxes.rst b/docs/api/api_python/ops/mindspore.ops.func_swapaxes.rst new file mode 100644 index 00000000000..4be7603cc22 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_swapaxes.rst @@ -0,0 +1,19 @@ +mindspore.ops.swapaxes +======================= + +.. py:function:: mindspore.ops.swapaxes(x, axis0, axis1) + + 交换Tensor的两个维度。 + + 参数: + - **x** (Tensor) - 输入Tensor。 + - **axis0** (int) - 第一个维度。 + - **axis1** (int) - 第二个维度。 + + 返回: + 转化后的Tensor,与输入具有相同的数据类型。 + + 异常: + - **TypeError** - `x` 不是Tensor类型。 + - **TypeError** - `axis0` 或 `axis1` 不是整数。 + - **ValueError** - `axis0` 或 `axis1` 不在 `[-ndim, ndim-1]` 范围内。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.func_swapdims.rst b/docs/api/api_python/ops/mindspore.ops.func_swapdims.rst new file mode 100644 index 00000000000..b4291f34799 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_swapdims.rst @@ -0,0 +1,20 @@ +mindspore.ops.swapdims +======================= + +.. py:function:: mindspore.ops.swapdims(x, dim0, dim1) + + 交换Tensor的两个维度。 + 该函数和 :func:`mindspore.ops.swapaxes` 功能一致。 + + 参数: + - **x** (Tensor) - 输入Tensor。 + - **dim0** (int) - 第一个维度。 + - **dim1** (int) - 第二个维度。 + + 返回: + 转化后的Tensor,与输入具有相同的数据类型。 + + 异常: + - **TypeError** - `x` 不是Tensor类型。 + - **TypeError** - `dim0` 或 `dim1` 不是整数。 + - **ValueError** - `dim0` 或 `dim1` 不在 `[-ndim, ndim-1]` 范围内。 \ No newline at end of file diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index e4ebfc6000a..8f2f1ae5bac 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -502,6 +502,8 @@ Array Operation mindspore.ops.stack mindspore.ops.strided_slice mindspore.ops.sum + mindspore.ops.swapaxes + mindspore.ops.swapdims mindspore.ops.tensor_scatter_add mindspore.ops.tensor_scatter_min mindspore.ops.tensor_scatter_max diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index 87442b48e7b..ef7025014bd 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -848,18 +848,18 @@ class Validator: @staticmethod def check_swapaxes_axis(axes, ndim): - """Check all the axes argument for tensor.swapaxes""" + """Check all the axes argument for ops.swapaxes""" if isinstance(axes, int): Validator.check_axis_in_range(axes, ndim) return axes % ndim if isinstance(axes, (tuple, list)): for axis in axes: if not isinstance(axis, int): - raise TypeError(f"For Tensor.swapaxes, the axis argument must be integer, but got {type(axis)}.") + raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.") Validator.check_axis_in_range(axis, ndim) axes = tuple(map(lambda x: x % ndim, axes)) return axes - raise TypeError(f"For Tensor.swapaxes, the argument 'axes' must be integer, list or tuple for check, " + raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, " f"but got {type(axes)}.") @staticmethod diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 6aa2444d630..9537b87a00d 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -1819,22 +1819,7 @@ class Tensor(Tensor_): (4,3,2) """ self._init_check() - axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim) - - if axis1 == axis2: - return self - if axis1 > axis2: - axis1, axis2 = axis2, axis1 - - perm = tuple(range(0, self.ndim)) - if axis2 + 1 < self.ndim: - new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ - perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] - else: - new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ - perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] - - return tensor_operator_registry.get('transpose')()(self, new_perm) + return tensor_operator_registry.get('swapaxes')(self, axis1, axis2) def squeeze(self, axis=None): """ diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 5efc8ed0aa6..3c67f8c1f8f 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -141,6 +141,8 @@ from .array_func import ( lstsq, mvlgamma, argsort, + swapaxes, + swapdims, sequence_mask, repeat_elements, repeat_interleave, diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 9ee2d6cbafe..e0e971963bc 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -6046,6 +6046,94 @@ def count_nonzero(x, dims=None): return count_nonzero_(x) +#TODO: remove comment +@constexpr +def _check_swapaxes_axis(axes, ndim): + return validator.check_swapaxes_axis(axes, ndim) + + +def swapaxes(x, axis0, axis1): + ''' + Interchange two axes of a tensor. + + Args: + x(Tensor): Input tensor. + axis0 (int): First axis. + axis1 (int): Second axis. + + Returns: + Transposed tensor, has the same data type as `x`. + + Raises: + TypeError: If argument `x` is not Tensor. + TypeError: If `axis0` or `axis1` is not integer. + ValueError: If `axis0` or `axis1` is not in the range of :math:`[-ndim, ndim-1]`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> import mindspore.ops as ops + >>> from mindspore import Tensor + >>> x = Tensor(np.ones((2,3,4), dtype=np.float32)) + >>> output = ops.swapaxes(x, 0, 2) + >>> print(output.shape) + (4,3,2) + ''' + if not isinstance(x, Tensor): + raise TypeError(f'For ops.swapaxes, parameter `x` must be Tensor, but got {type(x)}') + + axis0, axis1 = _check_swapaxes_axis((axis0, axis1), x.ndim) + if axis0 == axis1: + return x + if axis0 > axis1: + axis0, axis1 = axis1, axis0 + + perm = tuple(builtins.range(0, x.ndim)) + if axis1 + 1 < x.ndim: + new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \ + perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:] + else: + new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \ + perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + + return _get_cache_prim(P.Transpose)()(x, new_perm) + + +def swapdims(x, dim0, dim1): + ''' + Interchange two dims of a tensor. + This function is equivalent to :func:`mindspore.ops.swapaxes` function. + + Args: + x(Tensor): Input tensor. + dim0 (int): First dim. + dim1 (int): Second dim. + + Returns: + Transposed tensor, has the same data type as `x`. + + Raises: + TypeError: If argument `x` is not Tensor. + TypeError: If `dim0` or `dim1` is not integer. + ValueError: If `dim0` or `dim1` is not in the range of :math:`[-ndim, ndim-1]`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> import mindspore.ops as ops + >>> from mindspore import Tensor + >>> x = Tensor(np.ones((2,3,4), dtype=np.float32)) + >>> output = ops.swapdims(x, 0, 2) + >>> print(output.shape) + (4,3,2) + ''' + return F.swapaxes(x, dim0, dim1) + + @constexpr def _check_is_int(arg_value, arg_name, op_name): arg_value = validator.check_is_int(arg_value, arg_name, op_name) @@ -6352,6 +6440,8 @@ __all__ = [ 'diagonal', 'lstsq', 'mvlgamma', + 'swapaxes', + 'swapdims', 'argsort', 'sequence_mask', 'repeat_elements', diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index c221d89679d..7adc845f5cd 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -398,6 +398,7 @@ tensor_operator_registry.register('deg2rad', deg2rad) tensor_operator_registry.register('copysign', copysign) tensor_operator_registry.register('roll', Roll) tensor_operator_registry.register('rot90', rot90) +tensor_operator_registry.register('swapaxes', swapaxes) tensor_operator_registry.register('repeat_elements', repeat_elements) __all__ = [name for name in dir() if name[0] != "_"] diff --git a/tests/st/ops/test_func_swapaxes.py b/tests/st/ops/test_func_swapaxes.py new file mode 100644 index 00000000000..456be3820b3 --- /dev/null +++ b/tests/st/ops/test_func_swapaxes.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore as ms +from mindspore import ops +import mindspore.nn as nn +from tests.st.numpy_native.utils import check_all_results, to_tensor + + +class NetSwapAxes(nn.Cell): + def construct(self, x, axis0, axis1): + return ops.swapaxes(x, axis0=axis0, axis1=axis1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_swapaxes(mode): + """ + Feature: swapaxes + Description: Verify the result of swapaxes + Expectation: success. + """ + ms.set_context(mode=mode) + np_array = np.random.random((3, 4, 5)).astype('float32') + + np_swap_output = [] + np_swap_output.append(np.swapaxes(np_array, 0, 1)) + np_swap_output.append(np.swapaxes(np_array, 1, 0)) + np_swap_output.append(np.swapaxes(np_array, 1, 1)) + np_swap_output.append(np.swapaxes(np_array, 2, 1)) + np_swap_output.append(np.swapaxes(np_array, 1, 2)) + np_swap_output.append(np.swapaxes(np_array, 2, 2)) + + swapaxes_op = NetSwapAxes() + op_swap_output = [] + ms_array = to_tensor(np_array) + op_swap_output.append(swapaxes_op(ms_array, 0, 1)) + op_swap_output.append(swapaxes_op(ms_array, 1, 0)) + op_swap_output.append(swapaxes_op(ms_array, 1, 1)) + op_swap_output.append(swapaxes_op(ms_array, 2, 1)) + op_swap_output.append(swapaxes_op(ms_array, 1, 2)) + op_swap_output.append(swapaxes_op(ms_array, 2, 2)) + + check_all_results(np_swap_output, op_swap_output) diff --git a/tests/st/ops/test_func_swapdims.py b/tests/st/ops/test_func_swapdims.py new file mode 100644 index 00000000000..6ce231e4deb --- /dev/null +++ b/tests/st/ops/test_func_swapdims.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore as ms +from mindspore import ops +import mindspore.nn as nn +from tests.st.numpy_native.utils import check_all_results, to_tensor + + +class NetSwapDims(nn.Cell): + def construct(self, x, dim0, dim1): + return ops.swapdims(x, dim0=dim0, dim1=dim1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_swapdims(mode): + """ + Feature: swapdims + Description: Verify the result of swapdims + Expectation: success. + """ + ms.set_context(mode=mode) + np_array = np.random.random((3, 4, 5)).astype('float32') + + np_swap_output = [] + np_swap_output.append(np.swapaxes(np_array, 0, 1)) + np_swap_output.append(np.swapaxes(np_array, 1, 0)) + np_swap_output.append(np.swapaxes(np_array, 1, 1)) + np_swap_output.append(np.swapaxes(np_array, 2, 1)) + np_swap_output.append(np.swapaxes(np_array, 1, 2)) + np_swap_output.append(np.swapaxes(np_array, 2, 2)) + + swapdims_op = NetSwapDims() + op_swap_output = [] + ms_array = to_tensor(np_array) + op_swap_output.append(swapdims_op(ms_array, 0, 1)) + op_swap_output.append(swapdims_op(ms_array, 1, 0)) + op_swap_output.append(swapdims_op(ms_array, 1, 1)) + op_swap_output.append(swapdims_op(ms_array, 2, 1)) + op_swap_output.append(swapdims_op(ms_array, 1, 2)) + op_swap_output.append(swapdims_op(ms_array, 2, 2)) + + check_all_results(np_swap_output, op_swap_output)