diff --git a/tests/st/ops/cpu/test_shift_op.py b/tests/st/ops/cpu/test_shift_op.py index b4e7eea7c49..daebbc91417 100644 --- a/tests/st/ops/cpu/test_shift_op.py +++ b/tests/st/ops/cpu/test_shift_op.py @@ -113,7 +113,7 @@ def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan): (np.bool_, True), (np.bool_, False)]) @pytest.mark.parametrize('axis', [0, 1, 2, 3]) def test_no_shift(fill_value, dtype, axis): - arr = np.random.random((40, 60, 50, 30)).astype(dtype) + arr = np.random.random((4, 6, 5, 3)).astype(dtype) compare(arr, axis=axis, periods=0, fill_value=fill_value) @@ -126,15 +126,15 @@ def test_no_shift(fill_value, dtype, axis): (np.int32, 0), (np.int32, 1), (np.int32, 5), (np.int32, -4), (np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4), (np.bool_, True), (np.bool_, False)]) -@pytest.mark.parametrize('periods', [-35, 28, 90]) +@pytest.mark.parametrize('periods', [-35, 18, 25]) def test_fancy_1d(fill_value, dtype, periods): - arr = np.random.random((1, 1, 50, 1)).astype(dtype) + arr = np.random.random((1, 1, 20, 1)).astype(dtype) compare(arr, axis=2, periods=periods, fill_value=fill_value) - arr = np.random.random((70, 1, 1, 1)).astype(dtype) + arr = np.random.random((30, 1, 1, 1)).astype(dtype) compare(arr, axis=0, periods=periods, fill_value=fill_value) - arr = np.random.random((1, 1, 1, 80)).astype(dtype) + arr = np.random.random((1, 1, 1, 30)).astype(dtype) compare(arr, axis=3, periods=periods, fill_value=fill_value) @@ -148,9 +148,9 @@ def test_fancy_1d(fill_value, dtype, periods): (np.int64, 0), (np.int64, 1), (np.int64, 5), (np.int64, -4), (np.bool_, True), (np.bool_, False)]) @pytest.mark.parametrize('axis', [0, 1]) -@pytest.mark.parametrize('periods', [-24, 27, -35, 28, 100]) +@pytest.mark.parametrize('periods', [-3, 7, -5, 8, 9]) def test_2d(fill_value, dtype, axis, periods): - arr = np.random.random((30, 40)).astype(dtype) + arr = np.random.random((10, 10)).astype(dtype) compare(arr, axis=axis, periods=periods, fill_value=fill_value) @@ -166,5 +166,5 @@ def test_2d(fill_value, dtype, axis, periods): @pytest.mark.parametrize('axis', [0, 1, 2, 3]) @pytest.mark.parametrize('periods', [-30, 30, -45, 55]) def test_4d(fill_value, dtype, axis, periods): - arr = np.random.random((30, 40, 50, 60)).astype(dtype) + arr = np.random.random((30, 40, 10, 20)).astype(dtype) compare(arr, axis=axis, periods=periods, fill_value=fill_value)