forked from mindspore-Ecosystem/mindspore
Add bicgstab method and its test cases.
This commit is contained in:
parent
441169bcd5
commit
16a1c3c76e
@@ -20,7 +20,7 @@ from ... import numpy as mnp
 from ...common import Tensor
 from .line_search import LineSearch
-from ..utils import _to_scalar, grad
+from ..utils import _to_scalar, grad, _norm
 from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE

@@ -74,13 +74,6 @@ class MinimizeBfgs(nn.Cell):
         self.line_search = LineSearch(func)

     def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
-        def _my_norm(x, ord_=None):
-            if ord_ == mnp.inf:
-                res = mnp.max(mnp.abs(x))
-            else:
-                res = mnp.sqrt(mnp.sum(x ** 2))
-            return res
-
         if maxiter is None:
             maxiter = mnp.size(x0) * 200

@@ -90,7 +83,7 @@ class MinimizeBfgs(nn.Cell):
         g_0 = grad(self.func)(x0)

         state = {
-            "converged": _my_norm(g_0, ord_=mnp.inf) < gtol,
+            "converged": _norm(g_0, ord_=mnp.inf) < gtol,
             "failed": _BOOL_FALSE,
             "k": _INT_ZERO,
             "nfev": _INT_ONE,
@@ -100,7 +93,7 @@ class MinimizeBfgs(nn.Cell):
             "f_k": f_0,
             "g_k": g_0,
             "H_k": I,
-            "old_old_fval": f_0 + _my_norm(g_0) / 2,
+            "old_old_fval": f_0 + _norm(g_0) / 2,
             "status": _INT_ZERO,
             "line_search_status": _INT_ZERO
         }
@@ -128,7 +121,7 @@ class MinimizeBfgs(nn.Cell):
         y_k = g_kp1 - state["g_k"]

         state["old_old_fval"] = state["f_k"]
-        state["converged"] = _my_norm(g_kp1, ord_=norm) < gtol
+        state["converged"] = _norm(g_kp1, ord_=norm) < gtol
         state["x_k"] = x_kp1
         state["f_k"] = f_kp1
         state["g_k"] = g_kp1

@@ -14,6 +14,6 @@
 # ============================================================================
 """Sparse submodule"""
 from . import linalg
-from .linalg import cg, gmres
+from .linalg import cg, gmres, bicgstab

-__all__ = ["cg", "gmres"]
+__all__ = ["cg", "gmres", "bicgstab"]

@@ -18,7 +18,7 @@ from ... import numpy as mnp
 from ...ops import functional as F
 from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
-from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types
+from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types, _norm
 from ..utils_const import _raise_value_error, _raise_type_error

@@ -289,21 +289,14 @@ class CG(nn.Cell):
         A = _normalize_matvec(self.A)
         M = _normalize_matvec(self.M)

-        def _my_norm(x, ord_=None):
-            if ord_ == mnp.inf:
-                res = mnp.max(mnp.abs(x))
-            else:
-                res = mnp.sqrt(mnp.sum(x ** 2))
-            return res
-
-        atol_ = mnp.maximum(atol, tol * _my_norm(b))
+        atol_ = mnp.maximum(atol, tol * _norm(b))

         r = b - A(x0)
         z = p = M(r)
         rho = mnp.dot(r, z)
         k = _INT_ZERO
         x = x0
-        while k < maxiter and _my_norm(r) > atol_:
+        while k < maxiter and _norm(r) > atol_:
             q = A(p)
             alpha = rho / mnp.dot(p, q)
             x = x + alpha * p
@@ -317,7 +310,7 @@ class CG(nn.Cell):

             k += 1

-        return x
+        return x, F.select(_norm(r) > atol_, k, _INT_ZERO)


 def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
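Review note on the return change above: `CG.construct` now reports convergence information instead of a bare solution. A minimal Python sketch of the convention, assuming `F.select(cond, x, y)` picks `x` when `cond` holds (the helper name `_info` is illustrative, not part of the diff):

    # Report the iteration count k while the residual is still above
    # tolerance, and 0 on success (matching SciPy, where info == 0
    # means the solver converged).
    def _info(residual_norm, atol_, k):
        return k if residual_norm > atol_ else 0
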
@@ -393,5 +386,132 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
     if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
         _raise_type_error('Input A, x0 and b must have same float types')

-    x = CG(A, M)(b, x0, tol, atol, maxiter)
-    return x, None
+    x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    return x, info
+
+
+class BiCGStab(nn.Cell):
+    """Figure 2.10 from Barrett R, et al. 'Templates for the solution of linear systems:
+    building blocks for iterative methods', 1994, pp. 24-25.
+    """
+
+    def __init__(self, A, M):
+        super(BiCGStab, self).__init__()
+        self.A = A
+        self.M = M
+
+    def construct(self, b, x0, tol, atol, maxiter):
+        A = _normalize_matvec(self.A)
+        M = _normalize_matvec(self.M)
+
+        _FLOAT_ONE = _to_tensor(1., dtype=b.dtype)
+        atol_ = mnp.maximum(atol, tol * _norm(b))
+
+        r = r_tilde = v = p = b - A(x0)
+        rho = alpha = omega = _FLOAT_ONE
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter:
+            rho_ = mnp.dot(r_tilde, r)
+            if rho_ == 0. or omega == 0.:
+                k = _INT_NEG_ONE
+                break
+
+            beta = rho_ / rho * (alpha / omega)
+            p = r + beta * (p - omega * v)
+            p_hat = M(p)
+            v = A(p_hat)
+            alpha = rho_ / mnp.dot(r_tilde, v)
+            s = r - alpha * v
+            x = x + alpha * p_hat
+            if _norm(s) <= atol_:
+                break
+
+            s_hat = M(s)
+            t = A(s_hat)
+            omega = mnp.dot(t, s) / mnp.dot(t, t)
+            x = x + omega * s_hat
+            r = s - omega * t
+            if _norm(r) <= atol_:
+                break
+
+            rho = rho_
+            k += 1
+
+        return x, F.select(k == _INT_NEG_ONE or k >= maxiter, k, _INT_ZERO)
+
+
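Review note: the Cell above is a direct transcription of the Templates Fig. 2.10 recurrence. For checking it outside of MindSpore, here is a hedged, self-contained NumPy sketch of the same recurrence with the preconditioner fixed to the identity (the name `bicgstab_ref` and the identity-M simplification are assumptions for illustration, not part of the diff):

    import numpy as np

    def bicgstab_ref(A, b, x0=None, tol=1e-5, atol=0.0, maxiter=None):
        x = np.zeros_like(b) if x0 is None else x0.copy()
        maxiter = 10 * b.shape[0] if maxiter is None else maxiter
        atol_ = max(atol, tol * np.linalg.norm(b))
        # Same initialization as the Cell: r, r_tilde, v and p all start at b - A @ x.
        r = b - A @ x
        r_tilde, v, p = r, r, r
        rho = alpha = omega = 1.0
        for _ in range(maxiter):
            rho_ = r_tilde @ r
            if rho_ == 0.0 or omega == 0.0:    # breakdown, as in the Cell
                break
            beta = (rho_ / rho) * (alpha / omega)
            p = r + beta * (p - omega * v)
            v = A @ p                          # p_hat == p since M is the identity
            alpha = rho_ / (r_tilde @ v)
            s = r - alpha * v
            x = x + alpha * p
            if np.linalg.norm(s) <= atol_:     # early exit on small half-step residual
                break
            t = A @ s                          # s_hat == s since M is the identity
            omega = (t @ s) / (t @ t)
            x = x + omega * s
            r = s - omega * t
            if np.linalg.norm(r) <= atol_:
                break
            rho = rho_
        return x

Run with the same `maxiter` and no preconditioner, its iterates should agree with scipy.sparse.linalg.bicgstab to numerical precision, which is what the new test at the bottom of this diff exercises.
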
+def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
+    """Use Bi-Conjugate Gradient Stable iteration to solve :math:`Ax = b`.
+
+    The numerics of MindSpore's `bicgstab` should exactly match SciPy's
+    `bicgstab` (up to numerical precision).
+
+    As with `cg`, derivatives of `bicgstab` are implemented via implicit
+    differentiation with another `bicgstab` solve, rather than by
+    differentiating *through* the solver. They will be accurate only if
+    both solves converge.
+
+    Note:
+        - In the future, MindSpore will report the number of iterations when convergence
+          is not achieved, like SciPy. Currently it is None, as a placeholder.
+        - `bicgstab` is not supported on the Windows platform yet.
+
+    Args:
+        A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
+            map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`.
+            As a function, `A` must return a Tensor with the same structure and shape as its input.
+        b (Tensor): Right hand side of the linear system representing a single vector. Can be
+            stored as a Tensor.
+        x0 (Tensor): Starting guess for the solution. Must have the same structure as `b`. Default: None.
+        tol (float, optional): Tolerance for convergence: :math:`norm(residual) <= max(tol*norm(b), atol)`.
+            We do not implement SciPy's "legacy" behavior, so MindSpore's tolerance will
+            differ from SciPy's unless you explicitly pass `atol` to SciPy's `bicgstab`. Default: 1e-5.
+        atol (float, optional): The same as `tol`. Default: 0.0.
+        maxiter (int): Maximum number of iterations. Iteration will stop after maxiter
+            steps even if the specified tolerance has not been achieved. Default: None.
+        M (Union[Tensor, function]): Preconditioner for `A`. The preconditioner should approximate the
+            inverse of `A`. Effective preconditioning dramatically improves the
+            rate of convergence, which implies that fewer iterations are needed
+            to reach a given error tolerance. Default: None.
+
+    Returns:
+        - Tensor, the converged solution. Has the same structure as `b`.
+        - None, placeholder for convergence information.
+
+    Raises:
+        ValueError: If `x0` and `b` don't have the same structure.
+        TypeError: If `A`, `x0` and `b` don't have the same float types (`mstype.float32` or `mstype.float64`).
+
+    Supported Platforms:
+        ``CPU`` ``GPU``
+
+    Examples:
+        >>> import numpy as onp
+        >>> from mindspore.common import Tensor
+        >>> from mindspore.scipy.sparse.linalg import bicgstab
+        >>> A = Tensor(onp.array([[1, 2], [2, 1]], dtype='float32'))
+        >>> b = Tensor(onp.array([1, -1], dtype='float32'))
+        >>> result, _ = bicgstab(A, b)
+        >>> print(result)
+        [-1.  1.]
+    """
+    if x0 is None:
+        x0 = mnp.zeros_like(b)
+
+    if maxiter is None:
+        maxiter = 10 * b.shape[0]
+
+    if M is None:
+        M = lambda x: x
+
+    if x0.shape != b.shape:
+        _raise_value_error(
+            'Input x0 and b must have matching shapes: {} vs {}'.format(x0.shape, b.shape))
+
+    if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
+        _raise_type_error('Input A, x0 and b must have same float types')
+
+    x, info = BiCGStab(A, M)(b, x0, tol, atol, maxiter)
+    return x, info

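A hedged usage sketch for the preconditioned path of the API above (the matrix values and the Jacobi-style choice of `M` are illustrative assumptions, not part of the diff):

    import numpy as onp
    from mindspore.common import Tensor
    from mindspore.scipy.sparse.linalg import bicgstab

    a_np = onp.array([[4., 1.], [1., 3.]], dtype=onp.float32)
    b_np = onp.array([1., 2.], dtype=onp.float32)
    # Jacobi preconditioner: the inverse of A's diagonal roughly approximates A^-1.
    m_np = onp.diag(1.0 / onp.diag(a_np)).astype(onp.float32)
    x, _ = bicgstab(Tensor(a_np), Tensor(b_np), M=Tensor(m_np), tol=1e-6)

All three arguments share one float type, as the dtype check in the function requires.
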
@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """internal utility functions"""
 import numpy as onp
 from .. import nn, ops
+from .. import numpy as mnp
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor

@@ -62,11 +62,7 @@ def _to_scalar(arr):
     if isinstance(arr, Tensor):
         if arr.shape:
             return arr
-        arr = arr.asnumpy()
-    if isinstance(arr, onp.ndarray):
-        if arr.shape:
-            return arr
-        return arr.item()
+        return arr.asnumpy().item()
     raise ValueError("{} are not supported.".format(type(arr)))

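Review note on the `_to_scalar` simplification above: the intermediate `onp.ndarray` branch is folded into the Tensor branch, so plain `onp.ndarray` inputs, previously accepted, now fall through to the ValueError. Expected behavior, sketched with illustrative values:

    # _to_scalar(Tensor([1., 2.]))  ->  Tensor([1., 2.])  # non-empty shape: passed through
    # _to_scalar(Tensor(3.0))       ->  3.0               # 0-d Tensor: .asnumpy().item()
    # _to_scalar(onp.float64(3.0))  ->  ValueError        # anything that is not a Tensor
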
@@ -130,3 +126,11 @@ def _normalize_matvec(f):
         _raise_value_error(
             'linear operator must be either a function or Tensor: but got {}'.format(F.typeof(f)))
     return f
+
+
+def _norm(x, ord_=None):
+    if ord_ == mnp.inf:
+        res = mnp.max(mnp.abs(x))
+    else:
+        res = mnp.sqrt(mnp.sum(x ** 2))
+    return res

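Review note: this shared `_norm` replaces the identical `_my_norm` closures deleted from `MinimizeBfgs.construct` and `CG.construct` earlier in the diff, so every solver now uses one helper. A quick NumPy check of the two supported cases (assuming the `mnp` ops mirror NumPy semantics):

    import numpy as np

    def _norm(x, ord_=None):
        if ord_ == np.inf:
            return np.max(np.abs(x))        # infinity norm: largest absolute entry
        return np.sqrt(np.sum(x ** 2))      # default: Euclidean (L2) norm

    x = np.array([3.0, -4.0])
    assert _norm(x) == np.linalg.norm(x) == 5.0                        # L2
    assert _norm(x, ord_=np.inf) == np.linalg.norm(x, np.inf) == 4.0   # inf-norm
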
@@ -16,11 +16,11 @@
 import pytest
 import numpy as onp
 import scipy as osp
-from scipy.sparse.linalg import cg as osp_cg
+import scipy.sparse.linalg

 import mindspore.scipy as msp
 from mindspore import context
 from mindspore.common import Tensor
-from mindspore.scipy.sparse.linalg import cg as msp_cg
 from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix

@@ -61,7 +61,7 @@ def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
     A = create_sym_pos_matrix(shape, dtype)
     b = onp.random.random(shape[:1]).astype(dtype)
     M = _fetch_preconditioner(preconditioner, A)
-    osp_res = osp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

     A = Tensor(A)
     b = Tensor(b)
@@ -69,11 +69,11 @@ def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):

     # using PYNATIVE MODE
     context.set_context(mode=context.PYNATIVE_MODE)
-    msp_res_dyn = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

     # using GRAPH MODE
     context.set_context(mode=context.GRAPH_MODE)
-    msp_res_sta = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]

     kw = {"atol": tol, "rtol": tol}
     onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
@@ -99,11 +99,11 @@ def test_cg_against_numpy(dtype, shape):

     # using PYNATIVE MODE
     context.set_context(mode=context.PYNATIVE_MODE)
-    actual_dyn, _ = msp_cg(Tensor(A), Tensor(b))
+    actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))

     # using GRAPH MODE
     context.set_context(mode=context.GRAPH_MODE)
-    actual_sta, _ = msp_cg(Tensor(A), Tensor(b))
+    actual_sta, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))

     kw = {"atol": 1e-5, "rtol": 1e-5}
     onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
@@ -229,3 +229,41 @@ def test_graph_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
     osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=0.0)
     msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=0.0, solve_method='batched')
     onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype_tol', [(onp.float64, 1e-10)])
+@pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [1, 3])
+def test_bicgstab_against_scipy(dtype_tol, shape, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for bicgstab
+    Expectation: the result matches scipy
+    """
+    onp.random.seed(0)
+    dtype, tol = dtype_tol
+    A = create_full_rank_matrix(shape, dtype)
+    b = onp.random.random(shape[:1]).astype(dtype)
+    M = _fetch_preconditioner(preconditioner, A)
+    osp_res = scipy.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+
+    A = Tensor(A)
+    b = Tensor(b)
+    M = Tensor(M) if M is not None else M
+
+    # using PYNATIVE MODE
+    context.set_context(mode=context.PYNATIVE_MODE)
+    msp_res_dyn = msp.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+
+    # using GRAPH MODE
+    context.set_context(mode=context.GRAPH_MODE)
+    msp_res_sta = msp.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+
+    kw = {"atol": tol, "rtol": tol}
+    onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
+    onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)
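Review note: `_fetch_preconditioner` is an existing helper in this test module and is not part of this diff. A hedged sketch of what such a helper plausibly returns for each parametrized mode (the 'random' construction in particular is an assumption):

    import numpy as onp

    def _fetch_preconditioner(preconditioner, A):
        if preconditioner == 'identity':
            return onp.eye(A.shape[0], dtype=A.dtype)     # no-op preconditioning
        if preconditioner == 'exact':
            return onp.linalg.inv(A)                      # ideal case: M = inverse of A
        if preconditioner == 'random':
            # a perturbed inverse: still a useful approximation of the inverse of A
            return onp.linalg.inv(A + 0.1 * onp.random.random(A.shape).astype(A.dtype))
        return None                                       # preconditioner is None
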