From b55e6d82c8689d555f8c2038f80078da8b014717 Mon Sep 17 00:00:00 2001 From: fengyihang Date: Wed, 22 Feb 2023 16:56:35 +0800 Subject: [PATCH] modify flip support --- .../python/mindspore/ops/function/nn_func.py | 17 ++++------------- tests/st/ops/test_flip.py | 2 ++ tests/st/ops/test_fliplr.py | 2 ++ tests/st/ops/test_flipud.py | 2 ++ tests/st/tensor/test_flip.py | 2 ++ tests/st/tensor/test_fliplr.py | 2 ++ tests/st/tensor/test_flipud.py | 2 ++ 7 files changed, 16 insertions(+), 13 deletions(-) diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index 02ab06cf61c..8b2fefa2983 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -1844,7 +1844,7 @@ def flip(x, dims): ValueError: If `dims` is not a tuple of ints. Supported Platforms: - ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> import mindspore as ms @@ -1858,16 +1858,7 @@ def flip(x, dims): [[2 1] [4 3]]] """ - _check_input_tensor("flip", x) - ndim = ops.rank(x) - shape = ops.shape(x) - dims = _check_axis_valid(dims, ndim) - if _is_shape_empty(shape): - return x - start = _get_flip_start(ndim, shape, dims) - end = _get_flip_end(ndim, shape, dims) - strides = _get_flip_strides(ndim, dims) - res = ops.strided_slice(x, start, end, strides) + res = _get_cache_prim(ops.ReverseV2)(axis=dims)(x) return res @@ -1886,7 +1877,7 @@ def flipud(x): TypeError: If the input is not a tensor. Supported Platforms: - ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> import mindspore as ms @@ -1918,7 +1909,7 @@ def fliplr(x): TypeError: If the input is not a tensor. Supported Platforms: - ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> import mindspore as ms diff --git a/tests/st/ops/test_flip.py b/tests/st/ops/test_flip.py index a0c2ccfac7a..128e47e8772 100644 --- a/tests/st/ops/test_flip.py +++ b/tests/st/ops/test_flip.py @@ -31,6 +31,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_flip_normal(mode): diff --git a/tests/st/ops/test_fliplr.py b/tests/st/ops/test_fliplr.py index 61103f37811..9a262655b03 100644 --- a/tests/st/ops/test_fliplr.py +++ b/tests/st/ops/test_fliplr.py @@ -31,6 +31,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_fliplr_normal(mode): diff --git a/tests/st/ops/test_flipud.py b/tests/st/ops/test_flipud.py index 69f62fd99be..ec9dea8319c 100644 --- a/tests/st/ops/test_flipud.py +++ b/tests/st/ops/test_flipud.py @@ -31,6 +31,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_flipud_normal(mode): diff --git a/tests/st/tensor/test_flip.py b/tests/st/tensor/test_flip.py index 6cdb0a8bed0..f332ee53c12 100644 --- a/tests/st/tensor/test_flip.py +++ b/tests/st/tensor/test_flip.py @@ -30,6 +30,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_flip_normal(mode): diff --git a/tests/st/tensor/test_fliplr.py b/tests/st/tensor/test_fliplr.py index 052a715ce13..6f6234fcd20 100644 --- a/tests/st/tensor/test_fliplr.py +++ b/tests/st/tensor/test_fliplr.py @@ -30,6 +30,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_fliplr_normal(mode): diff --git a/tests/st/tensor/test_flipud.py b/tests/st/tensor/test_flipud.py index 933bd04d104..7a1b21e1fd8 100644 --- a/tests/st/tensor/test_flipud.py +++ b/tests/st/tensor/test_flipud.py @@ -30,6 +30,8 @@ class Net(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.platform_arm_cpu @pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) def test_flipud_normal(mode):