fix scipy st utils to creat_sys_pos_martix

This commit is contained in:
z00512249 2022-02-24 19:52:19 +08:00
parent ac1463a192
commit 8bba3a063a
1 changed files with 2 additions and 2 deletions

View File

@ -94,8 +94,8 @@ def create_sym_pos_matrix(shape, dtype):
'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)
n = shape[-1]
a = (onp.random.random(shape) + onp.eye(n)).astype(dtype)
return onp.dot(a, a.T)
x = onp.random.random(shape)
return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):