!31798 Fix python code check for scipy module.

Merge pull request !31798 from hezhenhao1/fix_code_check
i-robot 2022-03-24 07:44:01 +00:00 committed by Gitee
commit 12f6fb3d7f
5 changed files with 27 additions and 113 deletions


@@ -155,57 +155,6 @@ class Cholesky(PrimitiveWithInfer):
return output
class CholeskySolve(PrimitiveWithInfer):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
----------
lower : bool, optional
Whether to compute the upper or lower triangular Cholesky factorization
(Default: upper-triangular)
b : array
Right-hand side
Inputs:
- **A** (Tensor) - A matrix of shape :math:`(M, M)` to be decomposed.
- **b** (Tensor) - A Tensor of shape :math:`(M,)` or :math:`(..., M)`.
Right-hand side matrix in :math:`A x = b`.
Returns
-------
x : array
The solution to the system A x = b
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.ops import CholeskySolve
>>> from mindspore.scipy.linalg import cho_factor
>>> A = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=onp.float32))
>>> b = Tensor(onp.array([1.0, 1.0, 1.0, 1.0], dtype=onp.float32))
>>> c, lower = cho_factor(A)
>>> cholesky_solver = CholeskySolve(lower=lower)
>>> x = cholesky_solver(c, b)
>>> print(x)
[-0.01749266 0.11953348 0.01166185 0.15743434]
"""
@prim_attr_register
def __init__(self, lower=False):
super().__init__(name="CholeskySolve")
self.lower = validator.check_value_type("lower", lower, [bool], self.name)
self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])
def __infer__(self, A, b):
b_shape = b['shape']
a_dtype = A['dtype']
return {
'shape': tuple(b_shape),
'dtype': a_dtype,
'value': None
}
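The removed CholeskySolve primitive corresponds to what SciPy already exposes through cho_factor and cho_solve. A minimal host-side sketch of the equivalent call, assuming NumPy and SciPy are installed (illustrative only, not part of this diff):

>>> import numpy as onp
>>> from scipy.linalg import cho_factor, cho_solve
>>> A = onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=onp.float32)
>>> b = onp.array([1.0, 1.0, 1.0, 1.0], dtype=onp.float32)
>>> c, lower = cho_factor(A)
>>> x = cho_solve((c, lower), b)
>>> onp.allclose(onp.dot(A, x), b, atol=1e-4)
True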
class Eigh(PrimitiveWithInfer):
"""
Eigh decomposition(Symmetric matrix)
@@ -309,28 +258,6 @@ class LU(PrimitiveWithInfer):
return output
class LUSolver(PrimitiveWithInfer):
"""
LUSolver for Ax = b
"""
@prim_attr_register
def __init__(self, trans: str):
super().__init__(name="LUSolver")
self.init_prim_io_names(inputs=['a', 'b'], outputs=['output'])
self.trans = validator.check_value_type("trans", trans, [str], self.name)
def __infer__(self, a, b):
b_shape = list(b['shape'])
a_dtype = a['dtype']
output = {
'shape': tuple(b_shape),
'dtype': a_dtype,
'value': None
}
return output
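The removed LUSolver likewise solved A x = b from an LU factorization. The closest host-side equivalent is SciPy's lu_factor/lu_solve pair (the trans argument of lu_solve corresponds roughly to the primitive's trans attribute); a minimal sketch assuming NumPy and SciPy are installed:

>>> import numpy as onp
>>> from scipy.linalg import lu_factor, lu_solve
>>> a = onp.array([[3., 1.], [1., 2.]])
>>> b = onp.array([9., 8.])
>>> lu, piv = lu_factor(a)
>>> x = lu_solve((lu, piv), b)
>>> onp.allclose(onp.dot(a, x), b)
True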
class MatrixSetDiag(PrimitiveWithInfer):
"""
Inner API to set a [..., M, N] matrix's diagonals by range[k[0], k[1]].


@@ -14,11 +14,9 @@
# ============================================================================
"""BFGS"""
from typing import NamedTuple
from ... import nn
from ... import numpy as mnp
from ...common import Tensor
from .line_search import LineSearch
from ..utils import _to_scalar, _to_tensor, grad, _norm
@@ -82,7 +80,7 @@ class MinimizeBfgs(nn.Cell):
maxiter = mnp.size(x0) * 200
d = x0.shape[0]
I = mnp.eye(d, dtype=x0.dtype)
identity = mnp.eye(d, dtype=x0.dtype)
f_0 = self.func(x0)
g_0 = grad(self.func)(x0)
@@ -96,7 +94,7 @@ class MinimizeBfgs(nn.Cell):
"x_k": x0,
"f_k": f_0,
"g_k": g_0,
"H_k": I,
"H_k": identity,
"old_old_fval": f_0 + _norm(g_0) / 2,
"status": _INT_ZERO,
"line_search_status": _INT_ZERO
@@ -136,13 +134,13 @@ class MinimizeBfgs(nn.Cell):
rho_k = mnp.reciprocal(mnp.dot(y_k, s_k))
sy_k = mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(y_k, axis=0)
term1 = rho_k * sy_k
sy_k_2 = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
term2 = rho_k * sy_k_2
term3 = mnp.matmul(I - term1, state["H_k"])
term4 = mnp.matmul(term3, I - term2)
ys_k = mnp.expand_dims(y_k, axis=1) * mnp.expand_dims(s_k, axis=0)
term2 = rho_k * ys_k
term3 = mnp.matmul(identity - term1, state["H_k"])
term4 = mnp.matmul(term3, identity - term2)
term5 = rho_k * (mnp.expand_dims(s_k, axis=1) * mnp.expand_dims(s_k, axis=0))
H_kp1 = term4 + term5
state["H_k"] = H_kp1
hess_kp1 = term4 + term5
state["H_k"] = hess_kp1
# next iteration
state["k"] = state["k"] + 1


@@ -14,12 +14,10 @@
# ============================================================================
"""line search"""
from typing import NamedTuple
from ... import nn
from ... import numpy as mnp
from ...common import dtype as mstype
from ...common import Tensor
from ..utils import _to_scalar, _to_tensor, grad
@@ -52,7 +50,6 @@ def _cubicmin(a, fa, fpa, b, fb, c, fc):
"""Finds the minimizer for a cubic polynomial that goes through the
points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa.
"""
C = fpa
db = b - a
dc = c - a
denom = (db * dc) ** 2 * (db - dc)
@@ -64,13 +61,12 @@ def _cubicmin(a, fa, fpa, b, fb, c, fc):
d1[1, 1] = db ** 3
d2 = mnp.zeros((2,))
d2[0] = fb - fa - C * db
d2[1] = fc - fa - C * dc
d2[0] = fb - fa - fpa * db
d2[1] = fc - fa - fpa * dc
A, B = mnp.dot(d1, d2.flatten()) / denom
radical = B * B - 3. * A * C
xmin = a + (-B + mnp.sqrt(radical)) / (3. * A)
a2, b2 = mnp.dot(d1, d2) / denom
radical = b2 * b2 - 3. * a2 * fpa
xmin = a + (-b2 + mnp.sqrt(radical)) / (3. * a2)
return xmin
@@ -78,11 +74,9 @@ def _quadmin(a, fa, fpa, b, fb):
"""Finds the minimizer for a quadratic polynomial that goes through
the points (a,fa), (b,fb) with derivative at a of fpa.
"""
D = fa
C = fpa
db = b - a
B = (fb - D - C * db) / (db ** 2)
xmin = a - C / (2. * B)
b2 = (fb - fa - fpa * db) / (db ** 2)
xmin = a - fpa / (2. * b2)
return xmin
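Replacing the single-letter constants with fpa, b2, etc. does not alter the interpolation algebra: _quadmin still fits the parabola q(x) = fa + fpa*(x - a) + b2*(x - a)^2 through (a, fa) and (b, fb) with slope fpa at a, and returns its stationary point a - fpa/(2*b2). A quick NumPy check with illustrative values (not part of this diff):

import numpy as np

def quadmin(a, fa, fpa, b, fb):
    # Same algebra as the refactored _quadmin, in plain NumPy.
    db = b - a
    b2 = (fb - fa - fpa * db) / db ** 2
    return a - fpa / (2.0 * b2)

# A parabola with known minimizer 1.5 is recovered exactly.
f = lambda x: (x - 1.5) ** 2
fprime = lambda x: 2.0 * (x - 1.5)
assert np.isclose(quadmin(0.0, f(0.0), fprime(0.0), 1.0, f(1.0)), 1.5)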
@@ -179,8 +173,7 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0,
state["phi_low"] = mnp.where(j_to_low, phi_j, state["phi_low"])
state["dphi_low"] = mnp.where(j_to_low, dphi_j, state["dphi_low"])
# next iteration
state["j"] = state["j"] + 1
state["j"] += 1
state["failed"] = state["j"] == maxiter
return state
@@ -194,8 +187,7 @@ class LineSearch(nn.Cell):
super(LineSearch, self).__init__()
self.func = func
def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None,
c1=1e-4, c2=0.9, maxiter=3):
def construct(self, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4, c2=0.9, maxiter=20):
def fval_and_grad(alpha):
xkk = xk + alpha * pk
fkk = self.func(xkk)
@@ -284,7 +276,6 @@ class LineSearch(nn.Cell):
state["g_star"] = mnp.where(cond3, zoom2["g_star"], state["g_star"])
state["dphi_star"] = mnp.where(cond3, zoom2["dphi_star"], state["dphi_star"])
# next iteration
state["i"] += 1
state["a_i"] = a_i
state["phi_i"] = phi_i
@@ -308,8 +299,7 @@ class LineSearch(nn.Cell):
return state
def line_search(f, xk, pk, gfk=None, old_fval=None, old_old_fval=None, c1=1e-4,
c2=0.9, maxiter=20):
def line_search(f, xk, pk, gfk=None, old_fval=None, old_old_fval=None, c1=1e-4, c2=0.9, maxiter=20):
"""Inexact line search that satisfies strong Wolfe conditions.
Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61
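For context on the c1 and c2 defaults in the signature above: the search accepts a step length alpha once it satisfies the strong Wolfe conditions (sufficient decrease plus a curvature bound), which in the notation of the surrounding code read

.. math::
    f(x_k + \alpha p_k) \le f(x_k) + c_1 \alpha \nabla f_k^T p_k

    |\nabla f(x_k + \alpha p_k)^T p_k| \le c_2 |\nabla f_k^T p_k|

with 0 < c_1 < c_2 < 1 (here c1=1e-4 and c2=0.9).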


@@ -15,9 +15,7 @@
"""minimize"""
from typing import Optional
from typing import NamedTuple
from ...common import Tensor
from ._bfgs import minimize_bfgs


@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Sparse linear algebra submodule"""
from ... import nn, ms_function
from ... import nn
from ... import numpy as mnp
from ...ops import functional as F
from ...common import Tensor, CSRTensor, dtype as mstype
@@ -26,9 +26,7 @@ from ..utils_const import _raise_value_error, _raise_type_error, is_within_graph
def gram_schmidt(Q, q):
"""
do GramSchmidt process to normalize vector v
"""
"""Do GramSchmidt process to normalize vector v"""
h = mnp.dot(Q.T, q)
Qh = mnp.dot(Q, h)
q = q - Qh
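The projection shown above (h = Q^T q, then q - Q h) removes from q every component lying in the span of the columns of Q, assuming Q has orthonormal columns as it does in the Arnoldi iteration below. A small NumPy check on synthetic data that nothing of the result remains along Q (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
Q = np.linalg.qr(rng.standard_normal((6, 3)))[0]    # orthonormal columns
q = rng.standard_normal(6)
h = Q.T @ q
q_orth = q - Q @ h                                  # same projection as gram_schmidt
assert np.allclose(Q.T @ q_orth, 0.0, atol=1e-12)   # orthogonal to every column of Q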
@@ -36,7 +34,7 @@ def gram_schmidt(Q, q):
def arnoldi_iteration(k, A, M, V, H):
""" Performs a single (the k'th) step of the Arnoldi process."""
"""Performs a single (the k'th) step of the Arnoldi process."""
v_ = V[..., k]
v = M(A(v_))
v, h = gram_schmidt(V, v)
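Context for the docstring fix above: one Arnoldi step expands the Krylov basis by a single vector and records the projection coefficients in a row of H. A simplified dense NumPy sketch without the preconditioner M (array shapes are illustrative assumptions: V is (n, m+1) with its first k+1 columns filled and the rest zero, H is (m, m+1)):

import numpy as np

def arnoldi_step(k, A, V, H):
    # Expand the basis with A @ v_k, orthogonalize, and store the coefficients.
    v = A @ V[:, k]
    h = V.T @ v               # projections onto the current (orthonormal) basis
    v = v - V @ h             # remove those components
    v_norm = np.linalg.norm(v)
    breakdown = v_norm == 0.0 # invariant subspace reached ("happy breakdown")
    if not breakdown:
        V[:, k + 1] = v / v_norm
    h[k + 1] = v_norm
    H[k, :] = h
    return V, H, breakdown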
@@ -50,8 +48,8 @@ def arnoldi_iteration(k, A, M, V, H):
return V, H, breakdown
@ms_function
def rotate_vectors(H, i, cs, sn):
"""Rotate vectors."""
x1 = H[i]
y1 = H[i + 1]
x2 = cs * x1 - sn * y1
@@ -62,6 +60,7 @@ def rotate_vectors(H, i, cs, sn):
def _high_precision_cho_solve(a, b, data_type=mstype.float64):
"""As a core computing module of gmres, cholesky solver must explicitly cast to double precision."""
a = a.astype(mstype.float64)
b = b.astype(mstype.float64)
a_a = mnp.dot(a, a.T)
@@ -378,8 +377,10 @@ def _cg(A, b, x0, tol, atol, maxiter, M):
class CG(nn.Cell):
"""Figure 2.5 from Barrett R, et al. 'Templates for the sulution of linear systems:
building blocks for iterative methods', 1994, pg. 12-14
"""Use Conjugate Gradient iteration to solve the linear system:
.. math::
A x = b
"""
def __init__(self, A, M):
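The rewritten docstring states the system this Cell iterates on; the public cg wrapper built on it is modeled on SciPy's interface. For comparison, the equivalent host-side call with SciPy itself, assuming SciPy is installed (illustrative 2x2 system, not part of this diff):

>>> import numpy as onp
>>> from scipy.sparse.linalg import cg
>>> A = onp.array([[4., 1.], [1., 3.]])
>>> b = onp.array([1., 2.])
>>> x, info = cg(A, b)
>>> info  # 0 means the iteration converged
0
>>> onp.allclose(onp.dot(A, x), b)
True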