forked from mindspore-Ecosystem/mindspore
!31798 Fix python code check for scipy module.
Merge pull request !31798 from hezhenhao1/fix_code_check
This commit is contained in:
commit
12f6fb3d7f
|
@ -155,57 +155,6 @@ class Cholesky(PrimitiveWithInfer):
|
|||
return output
|
||||
|
||||
|
||||
class CholeskySolve(PrimitiveWithInfer):
|
||||
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower : bool, optional
|
||||
Whether the supplied Cholesky factor is lower-triangular (True) or upper-triangular (False)
|
||||
(Default: False, i.e. upper-triangular)
|
||||
b : array
|
||||
Right-hand side
|
||||
|
||||
Inputs:
|
||||
- **A** (Tensor) - The Cholesky factor of the coefficient matrix, of shape :math:`(M, M)`, e.g. as returned by `cho_factor`.
|
||||
- **b** (Tensor) - A Tensor of shape :math:`(M,)` or :math:`(..., M)`.
|
||||
Right-hand side matrix in :math:`A x = b`.
|
||||
Returns
|
||||
-------
|
||||
x : array
|
||||
The solution to the system A x = b
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
Examples:
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore.common import Tensor
|
||||
>>> from mindspore.scipy.ops import CholeskySolve
|
||||
>>> from mindspore.scipy.linalg import cho_factor
|
||||
>>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=onp.float32))
|
||||
>>> b = Tensor(onp.array([1.0, 1.0, 1.0, 1.0], dtype=onp.float32))
|
||||
>>> c, lower = cho_factor(A)
|
||||
>>> cholesky_solver = CholeskySolve(lower=lower)
|
||||
>>> x = cholesky_solver(c, b)
|
||||
>>> print(x)
|
||||
[-0.01749266 0.11953348 0.01166185 0.15743434]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, lower=False):
|
||||
super().__init__(name="CholeskySolve")
|
||||
self.lower = validator.check_value_type("lower", lower, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])
|
||||
|
||||
def __infer__(self, A, b):
|
||||
b_shape = b['shape']
|
||||
a_dtype = A['dtype']
|
||||
return {
|
||||
'shape': tuple(b_shape),
|
||||
'dtype': a_dtype,
|
||||
'value': None
|
||||
}
|
||||
|
||||
|
||||
class Eigh(PrimitiveWithInfer):
|
||||
"""
|
||||
Eigh decomposition (symmetric matrix)
|
||||
|
@ -309,28 +258,6 @@ class LU(PrimitiveWithInfer):
|
|||
return output
|
||||
|
||||
|
||||
class LUSolver(PrimitiveWithInfer):
|
||||
"""
|
||||
LUSolver for Ax = b
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, trans: str):
|
||||
super().__init__(name="LUSolver")
|
||||
self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
|
||||
self.trans = validator.check_value_type("trans", trans, [str], self.name)
|
||||
|
||||
def __infer__(self, a, b):
|
||||
b_shape = list(b['shape'])
|
||||
a_dtype = a['dtype']
|
||||
output = {
|
||||
'shape': tuple(b_shape),
|
||||
'dtype': a_dtype,
|
||||
'value': None
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class MatrixSetDiag(PrimitiveWithInfer):
|
||||
"""
|
||||
Inner API to set a [..., M, N] matrix's diagonals by range[k[0], k[1]].
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
# ============================================================================
|
||||
"""BFGS"""
|
||||
from typing import NamedTuple
|
||||
|
||||
from ... import nn
|
||||
from ... import numpy as mnp
|
||||
from ...common import Tensor
|
||||
|
||||
from .line_search import LineSearch
|
||||
from ..utils import _to_scalar, _to_tensor, grad, _norm
|
||||
|
||||
|
@ -82,7 +80,7 @@ class MinimizeBfgs(nn.Cell):
|
|||
maxiter = mnp.size(x0) * 200
|
||||
|
||||
d = x0.shape[0]
|
||||
I = mnp.eye(d, dtype=x0.dtype)
|
||||
identity = mnp.eye(d, dtype=x0.dtype)
|
||||
f_0 = self.func(x0)
|
||||
g_0 = grad(self.func)(x0)
|
||||
|
||||
|
@ -96,7 +94,7 @@ class MinimizeBfgs(nn.Cell):
|
|||
"x_k": x0,
|
||||
"f_k": f_0,
|
||||
"g_k": g_0,
|
||||
"H_k": I,
|
||||
"H_k": identity,
|
||||
"old_old_fval": f_0 + _norm(g_0) / 2,
|
||||
"status": _INT_ZERO,
|
||||
"line_search_status": _INT_ZERO
|
||||
|
@ -136,13 +134,13 @@ class MinimizeBfgs(nn.Cell):
|
|||
rho_k = mnp.reciprocal(mnp.dot(y_k, s_k))
|
||||
sy_k = mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(y_k, axis=0)
|
||||
term1 = rho_k * sy_k
|
||||
sy_k_2 = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
|
||||
term2 = rho_k * sy_k_2
|
||||
term3 = mnp.matmul(I - term1, state["H_k"])
|
||||
term4 = mnp.matmul(term3, I - term2)
|
||||
ys_k = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
|
||||
term2 = rho_k * ys_k
|
||||
term3 = mnp.matmul(identity - term1, state["H_k"])
|
||||
term4 = mnp.matmul(term3, identity - term2)
|
||||
term5 = rho_k * (mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(s_k, axis=0))
|
||||
H_kp1 = term4 + term5
|
||||
state["H_k"] = H_kp1
|
||||
hess_kp1 = term4 + term5
|
||||
state["H_k"] = hess_kp1
|
||||
|
||||
# next iteration
|
||||
state["k"] = state["k"] + 1
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
# ============================================================================
|
||||
"""line search"""
|
||||
from typing import NamedTuple
|
||||
|
||||
from ... import nn
|
||||
from ... import numpy as mnp
|
||||
from ...common import dtype as mstype
|
||||
from ...common import Tensor
|
||||
|
||||
from ..utils import _to_scalar, _to_tensor, grad
|
||||
|
||||
|
||||
|
@ -52,7 +50,6 @@ def _cubicmin(a, fa, fpa, b, fb, c, fc):
|
|||
"""Finds the minimizer for a cubic polynomial that goes through the
|
||||
points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa.
|
||||
"""
|
||||
C = fpa
|
||||
db = b - a
|
||||
dc = c - a
|
||||
denom = (db * dc) ** 2 * (db - dc)
|
||||
|
@ -64,13 +61,12 @@ def _cubicmin(a, fa, fpa, b, fb, c, fc):
|
|||
d1[1, 1] = db ** 3
|
||||
|
||||
d2 = mnp.zeros((2,))
|
||||
d2[0] = fb - fa - C * db
|
||||
d2[1] = fc - fa - C * dc
|
||||
d2[0] = fb - fa - fpa * db
|
||||
d2[1] = fc - fa - fpa * dc
|
||||
|
||||
A, B = mnp.dot(d1, d2.flatten()) / denom
|
||||
|
||||
radical = B * B - 3. * A * C
|
||||
xmin = a + (-B + mnp.sqrt(radical)) / (3. * A)
|
||||
a2, b2 = mnp.dot(d1, d2) / denom
|
||||
radical = b2 * b2 - 3. * a2 * fpa
|
||||
xmin = a + (-b2 + mnp.sqrt(radical)) / (3. * a2)
|
||||
return xmin
|
||||
|
||||
|
||||
|
@ -78,11 +74,9 @@ def _quadmin(a, fa, fpa, b, fb):
|
|||
"""Finds the minimizer for a quadratic polynomial that goes through
|
||||
the points (a,fa), (b,fb) with derivative at a of fpa.
|
||||
"""
|
||||
D = fa
|
||||
C = fpa
|
||||
db = b - a
|
||||
B = (fb - D - C * db) / (db ** 2)
|
||||
xmin = a - C / (2. * B)
|
||||
b2 = (fb - fa - fpa * db) / (db ** 2)
|
||||
xmin = a - fpa / (2. * b2)
|
||||
return xmin
|
||||
|
||||
|
||||
|
@ -179,8 +173,7 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
|
|||
state["phi_low"] = mnp.where(j_to_low, phi_j, state["phi_low"])
|
||||
state["dphi_low"] = mnp.where(j_to_low, dphi_j, state["dphi_low"])
|
||||
|
||||
# next iteration
|
||||
state["j"] = state["j"] + 1
|
||||
state["j"] += 1
|
||||
|
||||
state["failed"] = state["j"] == maxiter
|
||||
return state
|
||||
|
@ -194,8 +187,7 @@ class LineSearch(nn.Cell):
|
|||
super(LineSearch, self).__init__()
|
||||
self.func = func
|
||||
|
||||
def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None,
|
||||
c1=1e-4, c2=0.9, maxiter=3):
|
||||
def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4, c2=0.9, maxiter=20):
|
||||
def fval_and_grad(alpha):
|
||||
xkk = xk + alpha * pk
|
||||
fkk = self.func(xkk)
|
||||
|
@ -284,7 +276,6 @@ class LineSearch(nn.Cell):
|
|||
state["g_star"] = mnp.where(cond3, zoom2["g_star"], state["g_star"])
|
||||
state["dphi_star"] = mnp.where(cond3, zoom2["dphi_star"], state["dphi_star"])
|
||||
|
||||
# next iteration
|
||||
state["i"] += 1
|
||||
state["a_i"] = a_i
|
||||
state["phi_i"] = phi_i
|
||||
|
@ -308,8 +299,7 @@ class LineSearch(nn.Cell):
|
|||
return state
|
||||
|
||||
|
||||
def line_search(f, xk, pk, gfk=None, old_fval=None, old_old_fval=None, c1=1e-4,
|
||||
c2=0.9, maxiter=20):
|
||||
def line_search(f, xk, pk, gfk=None, old_fval=None, old_old_fval=None, c1=1e-4, c2=0.9, maxiter=20):
|
||||
"""Inexact line search that satisfies strong Wolfe conditions.
|
||||
|
||||
Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
"""minimize"""
|
||||
from typing import Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
from ...common import Tensor
|
||||
|
||||
from ._bfgs import minimize_bfgs
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Sparse linear algebra submodule"""
|
||||
from ... import nn, ms_function
|
||||
from ... import nn
|
||||
from ... import numpy as mnp
|
||||
from ...ops import functional as F
|
||||
from ...common import Tensor, CSRTensor, dtype as mstype
|
||||
|
@ -26,9 +26,7 @@ from ..utils_const import _raise_value_error, _raise_type_error, is_within_graph
|
|||
|
||||
|
||||
def gram_schmidt(Q, q):
|
||||
"""
|
||||
do Gram–Schmidt process to normalize vector v
|
||||
"""
|
||||
"""Do Gram–Schmidt process to normalize vector q"""
|
||||
h = mnp.dot(Q.T, q)
|
||||
Qh = mnp.dot(Q, h)
|
||||
q = q - Qh
|
||||
|
@ -36,7 +34,7 @@ def gram_schmidt(Q, q):
|
|||
|
||||
|
||||
def arnoldi_iteration(k, A, M, V, H):
|
||||
""" Performs a single (the k'th) step of the Arnoldi process."""
|
||||
"""Performs a single (the k'th) step of the Arnoldi process."""
|
||||
v_ = V[..., k]
|
||||
v = M(A(v_))
|
||||
v, h = gram_schmidt(V, v)
|
||||
|
@ -50,8 +48,8 @@ def arnoldi_iteration(k, A, M, V, H):
|
|||
return V, H, breakdown
|
||||
|
||||
|
||||
@ms_function
|
||||
def rotate_vectors(H, i, cs, sn):
|
||||
"""Rotate vectors."""
|
||||
x1 = H[i]
|
||||
y1 = H[i + 1]
|
||||
x2 = cs * x1 - sn * y1
|
||||
|
@ -62,6 +60,7 @@ def rotate_vectors(H, i, cs, sn):
|
|||
|
||||
|
||||
def _high_precision_cho_solve(a, b, data_type=mstype.float64):
|
||||
"""As a core computing module of gmres, cholesky solver must explicitly cast to double precision."""
|
||||
a = a.astype(mstype.float64)
|
||||
b = b.astype(mstype.float64)
|
||||
a_a = mnp.dot(a, a.T)
|
||||
|
@ -378,8 +377,10 @@ def _cg(A, b, x0, tol, atol, maxiter, M):
|
|||
|
||||
|
||||
class CG(nn.Cell):
|
||||
"""Figure 2.5 from Barrett R, et al. 'Templates for the solution of linear systems:
|
||||
building blocks for iterative methods', 1994, pg. 12-14
|
||||
"""Use Conjugate Gradient iteration to solve the linear system:
|
||||
|
||||
.. math::
|
||||
A x = b
|
||||
"""
|
||||
|
||||
def __init__(self, A, M):
|
||||
|
|
Loading…
Reference in New Issue