forked from mindspore-Ecosystem/mindspore
!27985 reduce the size of shift test case
Merge pull request !27985 from zhujingxuan/reduce_test_case
This commit is contained in:
commit
60a36bbf8f
|
@ -113,7 +113,7 @@ def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan):
|
|||
(np.bool_, True), (np.bool_, False)])
|
||||
@pytest.mark.parametrize('axis', [0, 1, 2, 3])
|
||||
def test_no_shift(fill_value, dtype, axis):
|
||||
arr = np.random.random((40, 60, 50, 30)).astype(dtype)
|
||||
arr = np.random.random((4, 6, 5, 3)).astype(dtype)
|
||||
compare(arr, axis=axis, periods=0, fill_value=fill_value)
|
||||
|
||||
|
||||
|
@ -126,15 +126,15 @@ def test_no_shift(fill_value, dtype, axis):
|
|||
(np.int32, 0), (np.int32, 1), (np.int32, 5), (np.int32, -4),
|
||||
(np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4),
|
||||
(np.bool_, True), (np.bool_, False)])
|
||||
@pytest.mark.parametrize('periods', [-35, 28, 90])
|
||||
@pytest.mark.parametrize('periods', [-35, 18, 25])
|
||||
def test_fancy_1d(fill_value, dtype, periods):
|
||||
arr = np.random.random((1, 1, 50, 1)).astype(dtype)
|
||||
arr = np.random.random((1, 1, 20, 1)).astype(dtype)
|
||||
compare(arr, axis=2, periods=periods, fill_value=fill_value)
|
||||
|
||||
arr = np.random.random((70, 1, 1, 1)).astype(dtype)
|
||||
arr = np.random.random((30, 1, 1, 1)).astype(dtype)
|
||||
compare(arr, axis=0, periods=periods, fill_value=fill_value)
|
||||
|
||||
arr = np.random.random((1, 1, 1, 80)).astype(dtype)
|
||||
arr = np.random.random((1, 1, 1, 30)).astype(dtype)
|
||||
compare(arr, axis=3, periods=periods, fill_value=fill_value)
|
||||
|
||||
|
||||
|
@ -148,9 +148,9 @@ def test_fancy_1d(fill_value, dtype, periods):
|
|||
(np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4),
|
||||
(np.bool_, True), (np.bool_, False)])
|
||||
@pytest.mark.parametrize('axis', [0, 1])
|
||||
@pytest.mark.parametrize('periods', [-24, 27, -35, 28, 100])
|
||||
@pytest.mark.parametrize('periods', [-3, 7, -5, 8, 9])
|
||||
def test_2d(fill_value, dtype, axis, periods):
|
||||
arr = np.random.random((30, 40)).astype(dtype)
|
||||
arr = np.random.random((10, 10)).astype(dtype)
|
||||
compare(arr, axis=axis, periods=periods, fill_value=fill_value)
|
||||
|
||||
|
||||
|
@ -166,5 +166,5 @@ def test_2d(fill_value, dtype, axis, periods):
|
|||
@pytest.mark.parametrize('axis', [0, 1, 2, 3])
|
||||
@pytest.mark.parametrize('periods', [-30, 30, -45, 55])
|
||||
def test_4d(fill_value, dtype, axis, periods):
|
||||
arr = np.random.random((30, 40, 50, 60)).astype(dtype)
|
||||
arr = np.random.random((30, 40, 10, 20)).astype(dtype)
|
||||
compare(arr, axis=axis, periods=periods, fill_value=fill_value)
|
||||
|
|
Loading…
Reference in New Issue