forked from mindspore-Ecosystem/mindspore
!29277 fix linalg doc description
Merge pull request !29277 from zhuzhongrui/pub_master
This commit is contained in:
commit
28e9fd9821
|
@ -215,6 +215,8 @@ def inv(a, overwrite_a=False, check_finite=True):
|
||||||
[[1.0000000e+00 0.0000000e+00]
|
[[1.0000000e+00 0.0000000e+00]
|
||||||
[8.8817842e-16 1.0000000e+00]]
|
[8.8817842e-16 1.0000000e+00]]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
if F.dtype(a) not in float_types:
|
if F.dtype(a) not in float_types:
|
||||||
a = F.cast(a, mstype.float32)
|
a = F.cast(a, mstype.float32)
|
||||||
|
|
||||||
|
@ -270,6 +272,8 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
||||||
[ 1. 5. 2.2933078 0.8559526 ]
|
[ 1. 5. 2.2933078 0.8559526 ]
|
||||||
[ 5. 1. 2. 1.5541857 ]]
|
[ 5. 1. 2. 1.5541857 ]]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
if F.dtype(a) not in float_types:
|
if F.dtype(a) not in float_types:
|
||||||
a = F.cast(a, mstype.float64)
|
a = F.cast(a, mstype.float64)
|
||||||
a_shape = a.shape
|
a_shape = a.shape
|
||||||
|
@ -320,6 +324,8 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
||||||
[[1. 0.]
|
[[1. 0.]
|
||||||
[2. 1.]]
|
[2. 1.]]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
if F.dtype(a) not in float_types:
|
if F.dtype(a) not in float_types:
|
||||||
a = F.cast(a, mstype.float64)
|
a = F.cast(a, mstype.float64)
|
||||||
|
|
||||||
|
@ -363,6 +369,8 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
||||||
>>> print(x)
|
>>> print(x)
|
||||||
[-0.01749266 0.11953348 0.01166185 0.15743434]
|
[-0.01749266 0.11953348 0.01166185 0.15743434]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_b, "overwrite_b", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
(c, lower) = c_and_lower
|
(c, lower) = c_and_lower
|
||||||
cholesky_solver_net = CholeskySolver(lower=lower)
|
cholesky_solver_net = CholeskySolver(lower=lower)
|
||||||
x = cholesky_solver_net(c, b)
|
x = cholesky_solver_net(c, b)
|
||||||
|
@ -440,6 +448,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
||||||
>>> print(mnp.sum(mnp.dot(A, v) - mnp.dot(v, mnp.diag(w))) < 1e-10)
|
>>> print(mnp.sum(mnp.dot(A, v) - mnp.dot(v, mnp.diag(w))) < 1e-10)
|
||||||
True
|
True
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(overwrite_b, "overwrite_b", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
eigh_net = EighNet(not eigvals_only, lower=True)
|
eigh_net = EighNet(not eigvals_only, lower=True)
|
||||||
return eigh_net(a)
|
return eigh_net(a)
|
||||||
|
|
||||||
|
@ -557,6 +568,8 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
||||||
>>> print(piv)
|
>>> print(piv)
|
||||||
[2 2 3 3]
|
[2 2 3 3]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
if F.dtype(a) not in float_types:
|
if F.dtype(a) not in float_types:
|
||||||
a = F.cast(a, mstype.float32)
|
a = F.cast(a, mstype.float32)
|
||||||
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
|
if len(a.shape) < 2 or (a.shape[-1] != a.shape[-2]):
|
||||||
|
@ -627,6 +640,8 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
||||||
[ 0. 0. -1.03999996 3.07999992]
|
[ 0. 0. -1.03999996 3.07999992]
|
||||||
[ 0. -0. -0. 7.46153831]]
|
[ 0. -0. -0. 7.46153831]]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
if F.dtype(a) not in float_types:
|
if F.dtype(a) not in float_types:
|
||||||
a = F.cast(a, mstype.float32)
|
a = F.cast(a, mstype.float32)
|
||||||
msp_lu = LU()
|
msp_lu = LU()
|
||||||
|
@ -680,6 +695,8 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
||||||
>>> print(lu_solve((lu, piv), b))
|
>>> print(lu_solve((lu, piv), b))
|
||||||
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
|
[ 0.05154639, -0.08247423, 0.08247423, 0.09278351]
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_b, "overwrite_b", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
m_lu, pivots = lu_and_piv
|
m_lu, pivots = lu_and_piv
|
||||||
# 1. check shape
|
# 1. check shape
|
||||||
check_lu_shape(m_lu, b)
|
check_lu_shape(m_lu, b)
|
||||||
|
@ -744,6 +761,8 @@ def det(a, overwrite_a=False, check_finite=True):
|
||||||
>>> print(det(a))
|
>>> print(det(a))
|
||||||
3.0
|
3.0
|
||||||
"""
|
"""
|
||||||
|
_type_check(overwrite_a, "overwrite_a", bool)
|
||||||
|
_type_check(check_finite, "check_finite", bool)
|
||||||
# special case
|
# special case
|
||||||
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
|
if a.ndim >= 2 and a.shape[-1] == 2 and a.shape[-2] == 2:
|
||||||
return _det_2x2(a)
|
return _det_2x2(a)
|
||||||
|
|
Loading…
Reference in New Issue