!29277 fix linalg doc description

Merge pull request !29277 from zhuzhongrui/pub_master
This commit is contained in:
i-robot 2022-01-19 06:12:33 +00:00 committed by Gitee
commit 28e9fd9821
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 19 additions and 0 deletions

View File

@ -215,6 +215,8 @@ def inv(a, overwrite_a=False, check_finite=True):
[[1.0000000e+00 0.0000000e+00]
[8.8817842e-16 1.0000000e+00]]
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
@ -270,6 +272,8 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
[ 1. 5. 2.2933078 0.8559526 ]
[ 5. 1. 2. 1.5541857 ]]
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float64)
a_shape = a.shape
@ -320,6 +324,8 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
[[1. 0.]
[2. 1.]]
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float64)
@ -363,6 +369,8 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
>>> print(x)
[-0.01749266 0.11953348 0.01166185 0.15743434]
"""
_type_check(overwrite_b, "overwrite_b", bool)
_type_check(check_finite, "check_finite", bool)
(c, lower) = c_and_lower
cholesky_solver_net = CholeskySolver(lower=lower)
x = cholesky_solver_net(c, b)
@ -440,6 +448,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
>>> print(mnp.sum(mnp.dot(A, v) - mnp.dot(v, mnp.diag(w))) < 1e-10)
True
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(overwrite_b, "overwrite_b", bool)
_type_check(check_finite, "check_finite", bool)
eigh_net = EighNet(not eigvals_only, lower=True)
return eigh_net(a)
@ -557,6 +568,8 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
>>> print(piv)
[2 2 3 3]
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
@ -627,6 +640,8 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
[ 0. 0. -1.03999996 3.07999992]
[ 0. -0. -0. 7.46153831]]
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
msp_lu = LU()
@ -680,6 +695,8 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
>>> print(lu_solve((lu, piv), b))
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
"""
_type_check(overwrite_b, "overwrite_b", bool)
_type_check(check_finite, "check_finite", bool)
m_lu, pivots = lu_and_piv
# 1. check shape
check_lu_shape(m_lu, b)
@ -744,6 +761,8 @@ def det(a, overwrite_a=False, check_finite=True):
>>> print(det(a))
3.0
"""
_type_check(overwrite_a, "overwrite_a", bool)
_type_check(check_finite, "check_finite", bool)
# special case
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
return _det_2x2(a)