!47794 add api ops.swapaxes and ops.swapdims
Merge pull request !47794 from GuoZhibin/add_ops_swapaxes
This commit is contained in:
commit
665aae0b42
|
@ -505,6 +505,8 @@ Array操作
|
|||
mindspore.ops.stack
|
||||
mindspore.ops.strided_slice
|
||||
mindspore.ops.sum
|
||||
mindspore.ops.swapaxes
|
||||
mindspore.ops.swapdims
|
||||
mindspore.ops.tensor_scatter_add
|
||||
mindspore.ops.tensor_scatter_div
|
||||
mindspore.ops.tensor_scatter_max
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.swapaxes
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.swapaxes(x, axis0, axis1)
|
||||
|
||||
交换Tensor的两个维度。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **axis0** (int) - 第一个维度。
|
||||
- **axis1** (int) - 第二个维度。
|
||||
|
||||
返回:
|
||||
转化后的Tensor,与输入具有相同的数据类型。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是Tensor类型。
|
||||
- **TypeError** - `axis0` 或 `axis1` 不是整数。
|
||||
- **ValueError** - `axis0` 或 `axis1` 不在 `[-ndim, ndim-1]` 范围内。
|
|
@ -0,0 +1,20 @@
|
|||
mindspore.ops.swapdims
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.ops.swapdims(x, dim0, dim1)
|
||||
|
||||
交换Tensor的两个维度。
|
||||
该函数和 :func:`mindspore.ops.swapaxes` 功能一致。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor。
|
||||
- **dim0** (int) - 第一个维度。
|
||||
- **dim1** (int) - 第二个维度。
|
||||
|
||||
返回:
|
||||
转化后的Tensor,与输入具有相同的数据类型。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是Tensor类型。
|
||||
- **TypeError** - `dim0` 或 `dim1` 不是整数。
|
||||
- **ValueError** - `dim0` 或 `dim1` 不在 `[-ndim, ndim-1]` 范围内。
|
|
@ -505,6 +505,8 @@ Array Operation
|
|||
mindspore.ops.stack
|
||||
mindspore.ops.strided_slice
|
||||
mindspore.ops.sum
|
||||
mindspore.ops.swapaxes
|
||||
mindspore.ops.swapdims
|
||||
mindspore.ops.tensor_scatter_add
|
||||
mindspore.ops.tensor_scatter_min
|
||||
mindspore.ops.tensor_scatter_max
|
||||
|
|
|
@ -848,18 +848,18 @@ class Validator:
|
|||
|
||||
@staticmethod
|
||||
def check_swapaxes_axis(axes, ndim):
|
||||
"""Check all the axes argument for tensor.swapaxes"""
|
||||
"""Check all the axes argument for ops.swapaxes"""
|
||||
if isinstance(axes, int):
|
||||
Validator.check_axis_in_range(axes, ndim)
|
||||
return axes % ndim
|
||||
if isinstance(axes, (tuple, list)):
|
||||
for axis in axes:
|
||||
if not isinstance(axis, int):
|
||||
raise TypeError(f"For Tensor.swapaxes, the axis argument must be integer, but got {type(axis)}.")
|
||||
raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
|
||||
Validator.check_axis_in_range(axis, ndim)
|
||||
axes = tuple(map(lambda x: x % ndim, axes))
|
||||
return axes
|
||||
raise TypeError(f"For Tensor.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
|
||||
raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
|
||||
f"but got {type(axes)}.")
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1819,22 +1819,7 @@ class Tensor(Tensor_):
|
|||
(4,3,2)
|
||||
"""
|
||||
self._init_check()
|
||||
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
|
||||
|
||||
if axis1 == axis2:
|
||||
return self
|
||||
if axis1 > axis2:
|
||||
axis1, axis2 = axis2, axis1
|
||||
|
||||
perm = tuple(range(0, self.ndim))
|
||||
if axis2 + 1 < self.ndim:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
|
||||
else:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
|
||||
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
|
||||
|
||||
return tensor_operator_registry.get('transpose')()(self, new_perm)
|
||||
return tensor_operator_registry.get('swapaxes')(self, axis1, axis2)
|
||||
|
||||
def squeeze(self, axis=None):
|
||||
"""
|
||||
|
|
|
@ -141,6 +141,8 @@ from .array_func import (
|
|||
lstsq,
|
||||
mvlgamma,
|
||||
argsort,
|
||||
swapaxes,
|
||||
swapdims,
|
||||
sequence_mask,
|
||||
repeat_elements,
|
||||
repeat_interleave,
|
||||
|
|
|
@ -6069,6 +6069,94 @@ def count_nonzero(x, dims=None):
|
|||
return count_nonzero_(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_swapaxes_axis(axes, ndim):
|
||||
return validator.check_swapaxes_axis(axes, ndim)
|
||||
|
||||
|
||||
def swapaxes(x, axis0, axis1):
|
||||
'''
|
||||
Interchange two axes of a tensor.
|
||||
|
||||
Args:
|
||||
x(Tensor): Input tensor.
|
||||
axis0 (int): First axis.
|
||||
axis1 (int): Second axis.
|
||||
|
||||
Returns:
|
||||
Transposed tensor, has the same data type as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument `x` is not Tensor.
|
||||
TypeError: If `axis0` or `axis1` is not integer.
|
||||
ValueError: If `axis0` or `axis1` is not in the range of :math:`[-ndim, ndim-1]`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
|
||||
>>> output = ops.swapaxes(x, 0, 2)
|
||||
>>> print(output.shape)
|
||||
(4,3,2)
|
||||
'''
|
||||
if not isinstance(x, Tensor):
|
||||
raise TypeError(f'For ops.swapaxes, parameter `x` must be Tensor, but got {type(x)}')
|
||||
|
||||
axis0, axis1 = _check_swapaxes_axis((axis0, axis1), x.ndim)
|
||||
if axis0 == axis1:
|
||||
return x
|
||||
if axis0 > axis1:
|
||||
axis0, axis1 = axis1, axis0
|
||||
|
||||
perm = tuple(builtins.range(0, x.ndim))
|
||||
if axis1 + 1 < x.ndim:
|
||||
new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
|
||||
perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1] + perm[axis1 + 1:]
|
||||
else:
|
||||
new_perm = perm[0:axis0] + perm[axis1:axis1 + 1] + \
|
||||
perm[axis0 + 1:axis1] + perm[axis0:axis0 + 1]
|
||||
|
||||
return _get_cache_prim(P.Transpose)()(x, new_perm)
|
||||
|
||||
|
||||
def swapdims(x, dim0, dim1):
|
||||
'''
|
||||
Interchange two dims of a tensor.
|
||||
This function is equivalent to :func:`mindspore.ops.swapaxes` function.
|
||||
|
||||
Args:
|
||||
x(Tensor): Input tensor.
|
||||
dim0 (int): First dim.
|
||||
dim1 (int): Second dim.
|
||||
|
||||
Returns:
|
||||
Transposed tensor, has the same data type as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument `x` is not Tensor.
|
||||
TypeError: If `dim0` or `dim1` is not integer.
|
||||
ValueError: If `dim0` or `dim1` is not in the range of :math:`[-ndim, ndim-1]`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
|
||||
>>> output = ops.swapdims(x, 0, 2)
|
||||
>>> print(output.shape)
|
||||
(4,3,2)
|
||||
'''
|
||||
return F.swapaxes(x, dim0, dim1)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_int(arg_value, arg_name, op_name):
|
||||
arg_value = validator.check_is_int(arg_value, arg_name, op_name)
|
||||
|
@ -6375,6 +6463,8 @@ __all__ = [
|
|||
'diagonal',
|
||||
'lstsq',
|
||||
'mvlgamma',
|
||||
'swapaxes',
|
||||
'swapdims',
|
||||
'argsort',
|
||||
'sequence_mask',
|
||||
'repeat_elements',
|
||||
|
|
|
@ -400,6 +400,7 @@ tensor_operator_registry.register('deg2rad', deg2rad)
|
|||
tensor_operator_registry.register('copysign', copysign)
|
||||
tensor_operator_registry.register('roll', Roll)
|
||||
tensor_operator_registry.register('rot90', rot90)
|
||||
tensor_operator_registry.register('swapaxes', swapaxes)
|
||||
tensor_operator_registry.register('repeat_elements', repeat_elements)
|
||||
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
from tests.st.numpy_native.utils import check_all_results, to_tensor
|
||||
|
||||
|
||||
class NetSwapAxes(nn.Cell):
|
||||
def construct(self, x, axis0, axis1):
|
||||
return ops.swapaxes(x, axis0=axis0, axis1=axis1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_swapaxes(mode):
|
||||
"""
|
||||
Feature: swapaxes
|
||||
Description: Verify the result of swapaxes
|
||||
Expectation: success.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
np_array = np.random.random((3, 4, 5)).astype('float32')
|
||||
|
||||
np_swap_output = []
|
||||
np_swap_output.append(np.swapaxes(np_array, 0, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 0))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 2, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 2))
|
||||
np_swap_output.append(np.swapaxes(np_array, 2, 2))
|
||||
|
||||
swapaxes_op = NetSwapAxes()
|
||||
op_swap_output = []
|
||||
ms_array = to_tensor(np_array)
|
||||
op_swap_output.append(swapaxes_op(ms_array, 0, 1))
|
||||
op_swap_output.append(swapaxes_op(ms_array, 1, 0))
|
||||
op_swap_output.append(swapaxes_op(ms_array, 1, 1))
|
||||
op_swap_output.append(swapaxes_op(ms_array, 2, 1))
|
||||
op_swap_output.append(swapaxes_op(ms_array, 1, 2))
|
||||
op_swap_output.append(swapaxes_op(ms_array, 2, 2))
|
||||
|
||||
check_all_results(np_swap_output, op_swap_output)
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import ops
|
||||
import mindspore.nn as nn
|
||||
from tests.st.numpy_native.utils import check_all_results, to_tensor
|
||||
|
||||
|
||||
class NetSwapDims(nn.Cell):
|
||||
def construct(self, x, dim0, dim1):
|
||||
return ops.swapdims(x, dim0=dim0, dim1=dim1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_swapdims(mode):
|
||||
"""
|
||||
Feature: swapdims
|
||||
Description: Verify the result of swapdims
|
||||
Expectation: success.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
np_array = np.random.random((3, 4, 5)).astype('float32')
|
||||
|
||||
np_swap_output = []
|
||||
np_swap_output.append(np.swapaxes(np_array, 0, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 0))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 2, 1))
|
||||
np_swap_output.append(np.swapaxes(np_array, 1, 2))
|
||||
np_swap_output.append(np.swapaxes(np_array, 2, 2))
|
||||
|
||||
swapdims_op = NetSwapDims()
|
||||
op_swap_output = []
|
||||
ms_array = to_tensor(np_array)
|
||||
op_swap_output.append(swapdims_op(ms_array, 0, 1))
|
||||
op_swap_output.append(swapdims_op(ms_array, 1, 0))
|
||||
op_swap_output.append(swapdims_op(ms_array, 1, 1))
|
||||
op_swap_output.append(swapdims_op(ms_array, 2, 1))
|
||||
op_swap_output.append(swapdims_op(ms_array, 1, 2))
|
||||
op_swap_output.append(swapdims_op(ms_array, 2, 2))
|
||||
|
||||
check_all_results(np_swap_output, op_swap_output)
|
Loading…
Reference in New Issue