forked from mindspore-Ecosystem/mindspore
!30352 move error test to daily build
Merge pull request !30352 from zhujingxuan/master
This commit is contained in:
commit
3bbf0f9dcd
|
@ -97,7 +97,7 @@ def block_diag(*arrs):
|
|||
|
||||
|
||||
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
|
||||
overwrite_b=False, debug=None, check_finite=False):
|
||||
overwrite_b=False, debug=None, check_finite=True):
|
||||
"""
|
||||
Assuming a is a batched triangular matrix, solve the equation
|
||||
|
||||
|
@ -128,7 +128,7 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
|
|||
debug (None): Not implemented now. Default: None.
|
||||
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: False.
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor of shape :math:`(M,)` or :math:`(M, N)`,
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore import context, Tensor
|
|||
import mindspore.numpy as mnp
|
||||
from mindspore.scipy.linalg import det, solve_triangular
|
||||
from tests.st.scipy_st.utils import match_array, create_full_rank_matrix, create_sym_pos_matrix, \
|
||||
create_random_rank_matrix, match_exception_info
|
||||
create_random_rank_matrix
|
||||
|
||||
onp.random.seed(0)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
@ -81,7 +81,7 @@ def test_solve_triangular(n: int, dtype, lower: bool, unit_diagonal: bool, trans
|
|||
assert onp.allclose(expect, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -95,38 +95,26 @@ def test_solve_triangular_error_dims(n: int, dtype):
|
|||
"""
|
||||
a = create_random_rank_matrix((10,) * n, dtype)
|
||||
b = create_random_rank_matrix(10, dtype)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the dimension of 'a' should be 2, but got {n}."
|
||||
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((n, n + 1), dtype)
|
||||
b = create_random_rank_matrix((10,), dtype)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the matrix 'a' should be a square matrix like (N, N), " \
|
||||
f"but got ({n}, {n + 1})."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((10, 10), dtype)
|
||||
b = create_random_rank_matrix((11,) * n, dtype)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the dimension of 'b' should be one of (1, 2), but got {n}."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((10, 10), dtype)
|
||||
b = create_random_rank_matrix((n,), dtype)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the last two dimensions of 'a' and 'b' should be matched, " \
|
||||
f"but got shape of {(10, 10)} and {(n,)}. Please make sure that the shape of 'a' and 'b' be like " \
|
||||
f"(N, N) X (N, M) or (N, N) X (N)."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -138,37 +126,28 @@ def test_solve_triangular_error_tensor_dtype():
|
|||
"""
|
||||
a = create_random_rank_matrix((10, 10), onp.float16)
|
||||
b = create_random_rank_matrix((10,), onp.float16)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the data type of 'a' should be one of " \
|
||||
f"[mindspore.int32, mindspore.int64, mindspore.float32, mindspore.float64], but got Float16."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((10, 10), onp.float32)
|
||||
b = create_random_rank_matrix((10,), onp.float16)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = f"For 'solve_triangular', the data type of 'b' should be one of " \
|
||||
f"[mindspore.int32, mindspore.int64, mindspore.float32, mindspore.float64], but got Float16."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((10, 10), onp.float32)
|
||||
b = create_random_rank_matrix((10,), onp.float64)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(Tensor(a), Tensor(b))
|
||||
msg = "For 'solve_triangular', the data type of 'a' and 'b' should be the same, " \
|
||||
"but got the data type of 'a' is Float32 and the data type of 'b' is Float64."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
|
||||
@pytest.mark.parametrize('argname, argtype', [('lower', 'bool'), ('overwrite_b', 'bool'), ('check_finite', 'bool')])
|
||||
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType'), ('test', 'str')])
|
||||
def test_solve_triangular_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype):
|
||||
@pytest.mark.parametrize('argname', ['lower', 'overwrite_b', 'check_finite'])
|
||||
@pytest.mark.parametrize('wrong_argvalue', [5.0, None, 'test'])
|
||||
def test_solve_triangular_error_type(dtype, argname, wrong_argvalue):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for solve_triangular for batched triangular matrix solver [..., N, N]
|
||||
|
@ -178,20 +157,17 @@ def test_solve_triangular_error_type(dtype, argname, argtype, wrong_argvalue, wr
|
|||
b = create_random_rank_matrix((10,), dtype)
|
||||
|
||||
kwargs = {argname: wrong_argvalue}
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(Tensor(a), Tensor(b), **kwargs)
|
||||
msg = f"For 'solve_triangular', the type of '{argname}' should be {argtype}, " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
|
||||
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType')])
|
||||
def test_solve_triangular_error_type_trans(dtype, wrong_argvalue, wrong_argtype):
|
||||
@pytest.mark.parametrize('wrong_argvalue', [5.0, None])
|
||||
def test_solve_triangular_error_type_trans(dtype, wrong_argvalue):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for solve_triangular for batched triangular matrix solver [..., N, N]
|
||||
|
@ -200,14 +176,11 @@ def test_solve_triangular_error_type_trans(dtype, wrong_argvalue, wrong_argtype)
|
|||
a = create_random_rank_matrix((10, 10), dtype)
|
||||
b = create_random_rank_matrix((10,), dtype)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(Tensor(a), Tensor(b), trans=wrong_argvalue)
|
||||
msg = f"For 'solve_triangular', the type of 'trans' should be one of ['int', 'str'], " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -222,14 +195,11 @@ def test_solve_triangular_error_value_trans(dtype, wrong_argvalue):
|
|||
a = create_random_rank_matrix((10, 10), dtype)
|
||||
b = create_random_rank_matrix((10,), dtype)
|
||||
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
solve_triangular(Tensor(a), Tensor(b), trans=wrong_argvalue)
|
||||
msg = f"For 'solve_triangular', the value of 'trans' should be one of (0, 1, 2, 'N', 'T', 'C'), " \
|
||||
f"but got {wrong_argvalue}."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -241,24 +211,18 @@ def test_solve_triangular_error_tensor_type():
|
|||
"""
|
||||
a = 'test'
|
||||
b = create_random_rank_matrix((10,), onp.float32)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(a, Tensor(b))
|
||||
msg = "For 'solve_triangular', the type of 'a' should be Tensor, but got 'test' with type str."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = [1, 2, 3]
|
||||
b = create_random_rank_matrix((10,), onp.float32)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(a, Tensor(b))
|
||||
msg = "For 'solve_triangular', the type of 'a' should be Tensor, but got '[1, 2, 3]' with type list."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = (1, 2, 3)
|
||||
b = create_random_rank_matrix((10,), onp.float32)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
solve_triangular(a, Tensor(b))
|
||||
msg = "For 'solve_triangular', the type of 'a' should be Tensor, but got '(1, 2, 3)' with type tuple."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -428,16 +392,14 @@ def test_eigh_complex(n: int, data_type):
|
|||
assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64, onp.int32, onp.int64])
|
||||
@pytest.mark.parametrize('argname, argtype',
|
||||
[('lower', 'bool'), ('eigvals_only', 'bool'), ('overwrite_a', 'bool'), ('overwrite_b', 'bool'),
|
||||
('turbo', 'bool'), ('check_finite', 'bool')])
|
||||
@pytest.mark.parametrize('wrong_argvalue, wrong_argtype', [(5.0, 'float'), (None, 'NoneType')])
|
||||
def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype):
|
||||
@pytest.mark.parametrize('argname', ['lower', 'eigvals_only', 'overwrite_a', 'overwrite_b', 'turbo', 'check_finite'])
|
||||
@pytest.mark.parametrize('wrong_argvalue', [5.0, None])
|
||||
def test_eigh_error_type(dtype, argname, wrong_argvalue):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for solve_triangular for batched triangular matrix solver [..., N, N]
|
||||
|
@ -447,32 +409,27 @@ def test_eigh_error_type(dtype, argname, argtype, wrong_argvalue, wrong_argtype)
|
|||
b = create_random_rank_matrix((10,), dtype)
|
||||
|
||||
kwargs = {argname: wrong_argvalue}
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
msp.linalg.eigh(Tensor(a), Tensor(b), **kwargs)
|
||||
assert str(err.value) == f"For 'eigh', the type of `{argname}` should be {argtype}, " \
|
||||
f"but got '{wrong_argvalue}' with type {wrong_argtype}."
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype, dtype_name', [(onp.float16, 'Float16'), (onp.int8, 'Int8'), (onp.int16, 'Int16')])
|
||||
def test_eigh_error_tensor_dtype(dtype, dtype_name):
|
||||
@pytest.mark.parametrize('dtype', [onp.float16, onp.int8, onp.int16])
|
||||
def test_eigh_error_tensor_dtype(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for solve_triangular for batched triangular matrix solver [..., N, N]
|
||||
Expectation: eigh raises expectated Exception
|
||||
"""
|
||||
a = create_random_rank_matrix((10, 10), dtype)
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(TypeError):
|
||||
msp.linalg.eigh(Tensor(a))
|
||||
msg = f"For 'Eigh', the type of `A_dtype` should be in " \
|
||||
f"[mindspore.float32, mindspore.float64, mindspore.complex64, mindspore.complex128], but got {dtype_name}."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -485,20 +442,15 @@ def test_eigh_error_dims(n: int, dtype):
|
|||
Expectation: eigh raises expectated Exception
|
||||
"""
|
||||
a = create_random_rank_matrix((10,) * n, dtype)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
with pytest.raises(RuntimeError):
|
||||
msp.linalg.eigh(Tensor(a))
|
||||
msg = f"Wrong array shape. For 'Eigh', a should be 2D, but got [{n}] dimensions."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((n, n + 1), dtype)
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
with pytest.raises(RuntimeError):
|
||||
msp.linalg.eigh(Tensor(a))
|
||||
msg = f"Wrong array shape. For 'Eigh', a should be a squre matrix like [N X N], " \
|
||||
f"but got [{n} X {n + 1}]."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -510,20 +462,14 @@ def test_eigh_error_not_implemented():
|
|||
"""
|
||||
a = create_random_rank_matrix((10, 10), onp.float32)
|
||||
b = create_random_rank_matrix((10, 10), onp.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
msp.linalg.eigh(Tensor(a), Tensor(b))
|
||||
msg = "Currently only case b=None of eigh is implemented. Which means that b must be identity matrix."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
msp.linalg.eigh(Tensor(a), 42)
|
||||
msg = "Currently only case b=None of eigh is implemented. Which means that b must be identity matrix."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
msp.linalg.eigh(Tensor(a), eigvals=42)
|
||||
msg = "Currently only case eigvals=None of eighis implemented."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -23,8 +23,7 @@ from scipy.linalg import solve_triangular, eig, eigvals
|
|||
from mindspore import Tensor, context
|
||||
from mindspore.scipy.ops import EighNet, Eig, Cholesky, SolveTriangular
|
||||
from mindspore.scipy.utils import _nd_transpose
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition, \
|
||||
match_exception_info
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_random_rank_matrix, compare_eigen_decomposition
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -360,7 +359,7 @@ def test_solve_triangular_batched(n: int, batch, dtype, lower: bool, unit_diagon
|
|||
assert np.allclose(expect, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -373,30 +372,22 @@ def test_solve_triangular_error_dims():
|
|||
# matrix a is 1D
|
||||
a = create_random_rank_matrix((10,), dtype=np.float32)
|
||||
b = create_random_rank_matrix((10,), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the dimension of `a` should be at least 2, but got 1 dimensions."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
# matrix a is not square matrix
|
||||
a = create_random_rank_matrix((4, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((10,), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the last two dimensions of `a` should be the same, " \
|
||||
"but got shape of [4, 5]. Please make sure that the shape of `a` be like [..., N, N]"
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((3, 5, 4, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((3, 5, 10,), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the last two dimensions of `a` should be the same," \
|
||||
" but got shape of [3, 5, 4, 5]. Please make sure that the shape of `a` be like [..., N, N]"
|
||||
match_exception_info(err, msg)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -409,49 +400,27 @@ def test_solve_triangular_error_dims_mismatched():
|
|||
# dimension of a and b is not matched
|
||||
a = create_random_rank_matrix((3, 4, 5, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((5, 10,), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the dimension of `b` should be 'a.dim' or 'a.dim' - 1, " \
|
||||
"which is 4 or 3, but got 2 dimensions."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
# last two dimensions not matched
|
||||
a = create_random_rank_matrix((3, 4, 5, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((5, 10, 4), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the last two dimensions of `a` and `b` should be matched, " \
|
||||
"but got shape of [3, 4, 5, 5] and [5, 10, 4]. Please make sure that the shape of `a` " \
|
||||
"and `b` be like [..., N, N] X [..., N, M] or [..., N, N] X [..., N]."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((3, 4, 5, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((5, 10, 4, 1), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the last two dimensions of `a` and `b` should be matched, " \
|
||||
"but got shape of [3, 4, 5, 5] and [5, 10, 4, 1]. Please make sure that the shape of `a` " \
|
||||
"and `b` be like [..., N, N] X [..., N, M] or [..., N, N] X [..., N]."
|
||||
print(err.value)
|
||||
match_exception_info(err, msg)
|
||||
|
||||
# batch dimensions not matched
|
||||
a = create_random_rank_matrix((3, 4, 5, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((5, 10, 5), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the batch dimensions of `a` and `b` should all be the same, " \
|
||||
"but got shape of [3, 4, 5, 5] and [5, 10, 5]. Please make sure that " \
|
||||
"the shape of `a` and `b` be like [a, b, c, ..., N, N] X [a, b, c, ..., N, M] " \
|
||||
"or [a, b, c, ..., N, N] X [a, b, c, ..., N]."
|
||||
match_exception_info(err, msg)
|
||||
|
||||
a = create_random_rank_matrix((3, 4, 5, 5), dtype=np.float32)
|
||||
b = create_random_rank_matrix((5, 10, 5, 1), dtype=np.float32)
|
||||
with pytest.raises(ValueError) as err:
|
||||
with pytest.raises(ValueError):
|
||||
SolveTriangular()(Tensor(a), Tensor(b))
|
||||
msg = "For 'SolveTriangular', the batch dimensions of `a` and `b` should all be the same, " \
|
||||
"but got shape of [3, 4, 5, 5] and [5, 10, 5, 1]. Please make sure that " \
|
||||
"the shape of `a` and `b` be like [a, b, c, ..., N, N] X [a, b, c, ..., N, M] " \
|
||||
"or [a, b, c, ..., N, N] X [a, b, c, ..., N]."
|
||||
match_exception_info(err, msg)
|
||||
|
|
|
@ -145,11 +145,6 @@ def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
|
|||
return difference
|
||||
|
||||
|
||||
def match_exception_info(err, expected_str):
|
||||
err_str = str(err.value)
|
||||
assert expected_str in err_str
|
||||
|
||||
|
||||
def compare_eigen_decomposition(src_res, tgt_res, compute_v, rtol, atol):
|
||||
def my_argsort(w):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue