diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py index 71166aca23d..73e2f935cc5 100644 --- a/tests/st/scipy_st/utils.py +++ b/tests/st/scipy_st/utils.py @@ -94,8 +94,8 @@ def create_sym_pos_matrix(shape, dtype): 'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape) n = shape[-1] - a = (onp.random.random(shape) + onp.eye(n)).astype(dtype) - return onp.dot(a, a.T) + x = onp.random.random(shape) + return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype) def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):