!27985 reduce the size of shift test case

Merge pull request !27985 from zhujingxuan/reduce_test_case
This commit is contained in:
i-robot 2021-12-21 12:16:02 +00:00 committed by Gitee
commit 60a36bbf8f
1 changed files with 8 additions and 8 deletions

View File

@ -113,7 +113,7 @@ def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan):
(np.bool_, True), (np.bool_, False)])
@pytest.mark.parametrize('axis', [0, 1, 2, 3])
def test_no_shift(fill_value, dtype, axis):
arr = np.random.random((40, 60, 50, 30)).astype(dtype)
arr = np.random.random((4, 6, 5, 3)).astype(dtype)
compare(arr, axis=axis, periods=0, fill_value=fill_value)
@ -126,15 +126,15 @@ def test_no_shift(fill_value, dtype, axis):
(np.int32, 0), (np.int32, 1), (np.int32, 5), (np.int32, -4),
(np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4),
(np.bool_, True), (np.bool_, False)])
@pytest.mark.parametrize('periods', [-35, 28, 90])
@pytest.mark.parametrize('periods', [-35, 18, 25])
def test_fancy_1d(fill_value, dtype, periods):
arr = np.random.random((1, 1, 50, 1)).astype(dtype)
arr = np.random.random((1, 1, 20, 1)).astype(dtype)
compare(arr, axis=2, periods=periods, fill_value=fill_value)
arr = np.random.random((70, 1, 1, 1)).astype(dtype)
arr = np.random.random((30, 1, 1, 1)).astype(dtype)
compare(arr, axis=0, periods=periods, fill_value=fill_value)
arr = np.random.random((1, 1, 1, 80)).astype(dtype)
arr = np.random.random((1, 1, 1, 30)).astype(dtype)
compare(arr, axis=3, periods=periods, fill_value=fill_value)
@ -148,9 +148,9 @@ def test_fancy_1d(fill_value, dtype, periods):
(np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4),
(np.bool_, True), (np.bool_, False)])
@pytest.mark.parametrize('axis', [0, 1])
@pytest.mark.parametrize('periods', [-24, 27, -35, 28, 100])
@pytest.mark.parametrize('periods', [-3, 7, -5, 8, 9])
def test_2d(fill_value, dtype, axis, periods):
arr = np.random.random((30, 40)).astype(dtype)
arr = np.random.random((10, 10)).astype(dtype)
compare(arr, axis=axis, periods=periods, fill_value=fill_value)
@ -166,5 +166,5 @@ def test_2d(fill_value, dtype, axis, periods):
@pytest.mark.parametrize('axis', [0, 1, 2, 3])
@pytest.mark.parametrize('periods', [-30, 30, -45, 55])
def test_4d(fill_value, dtype, axis, periods):
arr = np.random.random((30, 40, 50, 60)).astype(dtype)
arr = np.random.random((30, 40, 10, 20)).astype(dtype)
compare(arr, axis=axis, periods=periods, fill_value=fill_value)