add ops.swapaxes and ops.swapdims

This commit is contained in:
guozhibin 2023-01-12 15:11:50 +08:00
parent 2ff013c3cc
commit 8031beb260
11 changed files with 266 additions and 19 deletions

View File

@ -502,6 +502,8 @@ Array操作
mindspore.ops.stack
mindspore.ops.strided_slice
mindspore.ops.sum
mindspore.ops.swapaxes
mindspore.ops.swapdims
mindspore.ops.tensor_scatter_add
mindspore.ops.tensor_scatter_div
mindspore.ops.tensor_scatter_max

View File

@ -0,0 +1,19 @@
mindspore.ops.swapaxes
=======================
.. py:function:: mindspore.ops.swapaxes(x, axis0, axis1)
交换Tensor的两个维度。
参数:
- **x** (Tensor) - 输入Tensor。
- **axis0** (int) - 第一个维度。
- **axis1** (int) - 第二个维度。
返回:
转化后的Tensor与输入具有相同的数据类型。
异常:
- **TypeError** - `x` 不是Tensor类型。
- **TypeError** - `axis0``axis1` 不是整数。
- **ValueError** - `axis0``axis1` 不在 `[-ndim, ndim-1]` 范围内。

View File

@ -0,0 +1,20 @@
mindspore.ops.swapdims
=======================
.. py:function:: mindspore.ops.swapdims(x, dim0, dim1)
交换Tensor的两个维度。
该函数和 :func:`mindspore.ops.swapaxes` 功能一致。
参数:
- **x** (Tensor) - 输入Tensor。
- **dim0** (int) - 第一个维度。
- **dim1** (int) - 第二个维度。
返回:
转化后的Tensor与输入具有相同的数据类型。
异常:
- **TypeError** - `x` 不是Tensor类型。
- **TypeError** - `dim0``dim1` 不是整数。
- **ValueError** - `dim0``dim1` 不在 `[-ndim, ndim-1]` 范围内。

View File

@ -502,6 +502,8 @@ Array Operation
mindspore.ops.stack
mindspore.ops.strided_slice
mindspore.ops.sum
mindspore.ops.swapaxes
mindspore.ops.swapdims
mindspore.ops.tensor_scatter_add
mindspore.ops.tensor_scatter_min
mindspore.ops.tensor_scatter_max

View File

@ -848,18 +848,18 @@ class Validator:
@staticmethod
def check_swapaxes_axis(axes, ndim):
"""Check all the axes argument for tensor.swapaxes"""
"""Check all the axes argument for ops.swapaxes"""
if isinstance(axes, int):
Validator.check_axis_in_range(axes, ndim)
return axes % ndim
if isinstance(axes, (tuple, list)):
for axis in axes:
if not isinstance(axis, int):
raise TypeError(f"For Tensor.swapaxes, the axis argument must be integer, but got {type(axis)}.")
raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
Validator.check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes))
return axes
raise TypeError(f"For Tensor.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
f"but got {type(axes)}.")
@staticmethod

View File

@ -1819,22 +1819,7 @@ class Tensor(Tensor_):
(4,3,2)
"""
self._init_check()
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
if axis1 == axis2:
return self
if axis1 > axis2:
axis1, axis2 = axis2, axis1
perm = tuple(range(0, self.ndim))
if axis2 + 1 < self.ndim:
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
else:
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
return tensor_operator_registry.get('transpose')()(self, new_perm)
return tensor_operator_registry.get('swapaxes')(self, axis1, axis2)
def squeeze(self, axis=None):
"""

View File

@ -141,6 +141,8 @@ from .array_func import (
lstsq,
mvlgamma,
argsort,
swapaxes,
swapdims,
sequence_mask,
repeat_elements,
repeat_interleave,

View File

@ -6046,6 +6046,94 @@ def count_nonzero(x, dims=None):
return count_nonzero_(x)
#TODO: remove comment
@constexpr
def _check_swapaxes_axis(axes, ndim):
return validator.check_swapaxes_axis(axes, ndim)
def swapaxes(x, axis0, axis1):
'''
Interchange two axes of a tensor.
Args:
x(Tensor): Input tensor.
axis0 (int): First axis.
axis1 (int): Second axis.
Returns:
Transposed tensor, has the same data type as `x`.
Raises:
TypeError: If argument `x` is not Tensor.
TypeError: If `axis0` or `axis1` is not integer.
ValueError: If `axis0` or `axis1` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = ops.swapaxes(x, 0, 2)
>>> print(output.shape)
(4,3,2)
'''
if not isinstance(x, Tensor):
raise TypeError(f'For ops.swapaxes, parameter `x` must be Tensor, but got {type(x)}')
axis0, axis1 = _check_swapaxes_axis((axis0, axis1), x.ndim)
if axis0 == axis1:
return x
if axis0 > axis1:
axis0, axis1 = axis1, axis0
perm = tuple(builtins.range(0, x.ndim))
if axis1 + 1 < x.ndim:
new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
else:
new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
return _get_cache_prim(P.Transpose)()(x, new_perm)
def swapdims(x, dim0, dim1):
'''
Interchange two dims of a tensor.
This function is equivalent to :func:`mindspore.ops.swapaxes` function.
Args:
x(Tensor): Input tensor.
dim0 (int): First dim.
dim1 (int): Second dim.
Returns:
Transposed tensor, has the same data type as `x`.
Raises:
TypeError: If argument `x` is not Tensor.
TypeError: If `dim0` or `dim1` is not integer.
ValueError: If `dim0` or `dim1` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = ops.swapdims(x, 0, 2)
>>> print(output.shape)
(4,3,2)
'''
return F.swapaxes(x, dim0, dim1)
@constexpr
def _check_is_int(arg_value, arg_name, op_name):
arg_value = validator.check_is_int(arg_value, arg_name, op_name)
@ -6352,6 +6440,8 @@ __all__ = [
'diagonal',
'lstsq',
'mvlgamma',
'swapaxes',
'swapdims',
'argsort',
'sequence_mask',
'repeat_elements',

View File

@ -398,6 +398,7 @@ tensor_operator_registry.register('deg2rad', deg2rad)
tensor_operator_registry.register('copysign', copysign)
tensor_operator_registry.register('roll', Roll)
tensor_operator_registry.register('rot90', rot90)
tensor_operator_registry.register('swapaxes', swapaxes)
tensor_operator_registry.register('repeat_elements', repeat_elements)
__all__ = [name for name in dir() if name[0] != "_"]

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import ops
import mindspore.nn as nn
from tests.st.numpy_native.utils import check_all_results, to_tensor
class NetSwapAxes(nn.Cell):
def construct(self, x, axis0, axis1):
return ops.swapaxes(x, axis0=axis0, axis1=axis1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_swapaxes(mode):
"""
Feature: swapaxes
Description: Verify the result of swapaxes
Expectation: success.
"""
ms.set_context(mode=mode)
np_array = np.random.random((3, 4, 5)).astype('float32')
np_swap_output = []
np_swap_output.append(np.swapaxes(np_array, 0, 1))
np_swap_output.append(np.swapaxes(np_array, 1, 0))
np_swap_output.append(np.swapaxes(np_array, 1, 1))
np_swap_output.append(np.swapaxes(np_array, 2, 1))
np_swap_output.append(np.swapaxes(np_array, 1, 2))
np_swap_output.append(np.swapaxes(np_array, 2, 2))
swapaxes_op = NetSwapAxes()
op_swap_output = []
ms_array = to_tensor(np_array)
op_swap_output.append(swapaxes_op(ms_array, 0, 1))
op_swap_output.append(swapaxes_op(ms_array, 1, 0))
op_swap_output.append(swapaxes_op(ms_array, 1, 1))
op_swap_output.append(swapaxes_op(ms_array, 2, 1))
op_swap_output.append(swapaxes_op(ms_array, 1, 2))
op_swap_output.append(swapaxes_op(ms_array, 2, 2))
check_all_results(np_swap_output, op_swap_output)

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import ops
import mindspore.nn as nn
from tests.st.numpy_native.utils import check_all_results, to_tensor
class NetSwapDims(nn.Cell):
def construct(self, x, dim0, dim1):
return ops.swapdims(x, dim0=dim0, dim1=dim1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_swapdims(mode):
"""
Feature: swapdims
Description: Verify the result of swapdims
Expectation: success.
"""
ms.set_context(mode=mode)
np_array = np.random.random((3, 4, 5)).astype('float32')
np_swap_output = []
np_swap_output.append(np.swapaxes(np_array, 0, 1))
np_swap_output.append(np.swapaxes(np_array, 1, 0))
np_swap_output.append(np.swapaxes(np_array, 1, 1))
np_swap_output.append(np.swapaxes(np_array, 2, 1))
np_swap_output.append(np.swapaxes(np_array, 1, 2))
np_swap_output.append(np.swapaxes(np_array, 2, 2))
swapdims_op = NetSwapDims()
op_swap_output = []
ms_array = to_tensor(np_array)
op_swap_output.append(swapdims_op(ms_array, 0, 1))
op_swap_output.append(swapdims_op(ms_array, 1, 0))
op_swap_output.append(swapdims_op(ms_array, 1, 1))
op_swap_output.append(swapdims_op(ms_array, 2, 1))
op_swap_output.append(swapdims_op(ms_array, 1, 2))
op_swap_output.append(swapdims_op(ms_array, 2, 2))
check_all_results(np_swap_output, op_swap_output)