!49249 Modify flip support

Merge pull request !49249 from 冯一航/modify_flip_support
This commit is contained in:
i-robot 2023-02-23 03:00:41 +00:00 committed by Gitee
commit 9a81a0e969
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 16 additions and 13 deletions

View File

@ -1844,7 +1844,7 @@ def flip(x, dims):
ValueError: If `dims` is not a tuple of ints.
Supported Platforms:
``GPU`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
@ -1858,16 +1858,7 @@ def flip(x, dims):
[[2 1]
[4 3]]]
"""
_check_input_tensor("flip", x)
ndim = ops.rank(x)
shape = ops.shape(x)
dims = _check_axis_valid(dims, ndim)
if _is_shape_empty(shape):
return x
start = _get_flip_start(ndim, shape, dims)
end = _get_flip_end(ndim, shape, dims)
strides = _get_flip_strides(ndim, dims)
res = ops.strided_slice(x, start, end, strides)
res = _get_cache_prim(ops.ReverseV2)(axis=dims)(x)
return res
@ -1886,7 +1877,7 @@ def flipud(x):
TypeError: If the input is not a tensor.
Supported Platforms:
``GPU`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
@ -1918,7 +1909,7 @@ def fliplr(x):
TypeError: If the input is not a tensor.
Supported Platforms:
``GPU`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms

View File

@ -31,6 +31,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flip_normal(mode):

View File

@ -31,6 +31,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fliplr_normal(mode):

View File

@ -31,6 +31,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flipud_normal(mode):

View File

@ -30,6 +30,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flip_normal(mode):

View File

@ -30,6 +30,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_fliplr_normal(mode):

View File

@ -30,6 +30,8 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flipud_normal(mode):