Add bicgstab method and its test cases.

hezhenhao1 2022-01-15 14:51:47 +08:00
parent 441169bcd5
commit 16a1c3c76e
5 changed files with 194 additions and 39 deletions

View File

@@ -20,7 +20,7 @@ from ... import numpy as mnp
from ...common import Tensor
from .line_search import LineSearch
from ..utils import _to_scalar, grad
from ..utils import _to_scalar, grad, _norm
from ..utils import _INT_ZERO, _INT_ONE, _BOOL_FALSE
@@ -74,13 +74,6 @@ class MinimizeBfgs(nn.Cell):
self.line_search = LineSearch(func)
def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10):
def _my_norm(x, ord_=None):
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
if maxiter is None:
maxiter = mnp.size(x0) * 200
@@ -90,7 +83,7 @@ class MinimizeBfgs(nn.Cell):
g_0 = grad(self.func)(x0)
state = {
"converged": _my_norm(g_0, ord_=mnp.inf) < gtol,
"converged": _norm(g_0, ord_=mnp.inf) < gtol,
"failed": _BOOL_FALSE,
"k": _INT_ZERO,
"nfev": _INT_ONE,
@@ -100,7 +93,7 @@ class MinimizeBfgs(nn.Cell):
"f_k": f_0,
"g_k": g_0,
"H_k": I,
"old_old_fval": f_0 + _my_norm(g_0) / 2,
"old_old_fval": f_0 + _norm(g_0) / 2,
"status": _INT_ZERO,
"line_search_status": _INT_ZERO
}
@@ -128,7 +121,7 @@ class MinimizeBfgs(nn.Cell):
y_k = g_kp1 - state["g_k"]
state["old_old_fval"] = state["f_k"]
state["converged"] = _my_norm(g_kp1, ord_=norm) < gtol
state["converged"] = _norm(g_kp1, ord_=norm) < gtol
state["x_k"] = x_kp1
state["f_k"] = f_kp1
state["g_k"] = g_kp1

View File

@@ -14,6 +14,6 @@
# ============================================================================
"""Sparse submodule"""
from . import linalg
from .linalg import cg, gmres
from .linalg import cg, gmres, bicgstab
__all__ = ["cg", "gmres"]
__all__ = ["cg", "gmres", "bicgstab"]

View File

@@ -18,7 +18,7 @@ from ... import numpy as mnp
from ...ops import functional as F
from ..linalg import solve_triangular
from ..linalg import cho_factor, cho_solve
from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types
from ..utils import _INT_ZERO, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps, float_types, _norm
from ..utils_const import _raise_value_error, _raise_type_error
@@ -289,21 +289,14 @@ class CG(nn.Cell):
A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
def _my_norm(x, ord_=None):
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
atol_ = mnp.maximum(atol, tol * _my_norm(b))
atol_ = mnp.maximum(atol, tol * _norm(b))
r = b - A(x0)
z = p = M(r)
rho = mnp.dot(r, z)
k = _INT_ZERO
x = x0
while k < maxiter and _my_norm(r) > atol_:
while k < maxiter and _norm(r) > atol_:
q = A(p)
alpha = rho / mnp.dot(p, q)
x = x + alpha * p
@@ -317,7 +310,7 @@
k += 1
return x
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
@@ -393,5 +386,132 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
_raise_type_error('Input A, x0 and b must have same float types')
x = CG(A, M)(b, x0, tol, atol, maxiter)
return x, None
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
return x, info
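A caller-side sketch of the updated `cg` contract (a minimal example, assuming a small symmetric positive-definite system; following the `F.select` logic above, `info` is 0 on convergence, otherwise the iteration count reached):

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import cg

A = Tensor(onp.array([[2., 1.], [1., 2.]], dtype=onp.float32))
b = Tensor(onp.array([1., -1.], dtype=onp.float32))
# cg now returns (solution, info) instead of (solution, None).
x, info = cg(A, b)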
class BiCGStab(nn.Cell):
"""Figure 2.10 from Barrett R, et al. 'Templates for the sulution of linear systems:
building blocks for iterative methods', 1994, pg. 24-25
"""
def __init__(self, A, M):
super(BiCGStab, self).__init__()
self.A = A
self.M = M
def construct(self, b, x0, tol, atol, maxiter):
A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
_FLOAT_ONE = _to_tensor(1., dtype=b.dtype)
atol_ = mnp.maximum(atol, tol * _norm(b))
r = r_tilde = v = p = b - A(x0)
rho = alpha = omega = _FLOAT_ONE
k = _INT_ZERO
x = x0
while k < maxiter:
rho_ = mnp.dot(r_tilde, r)
if rho_ == 0. or omega == 0.:
k = _INT_NEG_ONE
break
beta = rho_ / rho * (alpha / omega)
p = r + beta * (p - omega * v)
p_hat = M(p)
v = A(p_hat)
alpha = rho_ / mnp.dot(r_tilde, v)
s = r - alpha * v
x = x + alpha * p_hat
if _norm(s) <= atol_:
break
s_hat = M(s)
t = A(s_hat)
omega = mnp.dot(t, s) / mnp.dot(t, t)
x = x + omega * s_hat
r = s - omega * t
if _norm(r) <= atol_:
break
rho = rho_
k += 1
return x, F.select(k == _INT_NEG_ONE or k >= maxiter, k, _INT_ZERO)
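For cross-checking the recurrence above, a hedged pure-NumPy transcription of the same Barrett et al. Figure 2.10 iteration (illustrative only, not the MindSpore implementation; `M_inv` is assumed to be a dense approximation of `inv(A)`):

import numpy as np

def bicgstab_ref(A, b, x0, M_inv, tol=1e-5, maxiter=50):
    """Plain-NumPy BiCGSTAB mirroring the construct above."""
    x = x0.copy()
    r = b - A @ x
    r_tilde = r.copy()  # fixed shadow residual
    v = p = r.copy()    # same initialization as the class above
    rho = alpha = omega = 1.0
    atol_ = tol * np.linalg.norm(b)
    for _ in range(maxiter):
        rho_new = r_tilde @ r
        if rho_new == 0.0 or omega == 0.0:  # breakdown
            break
        beta = (rho_new / rho) * (alpha / omega)
        p = r + beta * (p - omega * v)
        p_hat = M_inv @ p                   # preconditioning step
        v = A @ p_hat
        alpha = rho_new / (r_tilde @ v)
        s = r - alpha * v
        x = x + alpha * p_hat
        if np.linalg.norm(s) <= atol_:      # converged on the half step
            break
        s_hat = M_inv @ s
        t = A @ s_hat
        omega = (t @ s) / (t @ t)
        x = x + omega * s_hat
        r = s - omega * t
        if np.linalg.norm(r) <= atol_:
            break
        rho = rho_new
    return x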
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
"""Use Bi-Conjugate Gradient Stable iteration to solve :math:`Ax = b`.
The numerics of MindSpore's `bicgstab` should exact match SciPy's
`bicgstab` (up to numerical precision).
As with `cg`, derivatives of `bicgstab` are implemented via implicit
differentiation with another `bicgstab` solve, rather than by
differentiating *through* the solver. They will be accurate only if
both solves converge.
Note:
- In the future, MindSpore will report the number of iterations when convergence
is not achieved, as SciPy does. Currently it is None, as a placeholder.
- `bicgstab` is not supported on the Windows platform yet.
Args:
A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`.
As a function, `A` must return a Tensor with the same structure and shape as its input.
b (Tensor): Right-hand side of the linear system, representing a single vector
stored as a Tensor.
x0 (Tensor): Starting guess for the solution. Must have the same structure as `b`. Default: None.
tol (float, optional): Tolerance for convergence, :math:`norm(residual) <= max(tol*norm(b), atol)`.
We do not implement SciPy's "legacy" behavior, so MindSpore's tolerance will
differ from SciPy's unless you explicitly pass `atol` to SciPy's `bicgstab`. Default: 1e-5.
atol (float, optional): The absolute tolerance in the same convergence criterion as `tol`. Default: 0.0.
maxiter (int): Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved. Default: None.
M (Union[Tensor, function]): Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance. Default: None.
Returns:
- Tensor, the converged solution. Has the same structure as `b`.
- None, placeholder for convergence information.
Raises:
ValueError: If `x0` and `b` don't have the same structure.
TypeError: If `A`, `x0` and `b` don't have the same float types (`mstype.float32` or `mstype.float64`).
Supported Platforms:
``CPU`` ``GPU``
Examples:
>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.sparse.linalg import bicgstab
>>> A = Tensor(onp.array([[1, 2], [2, 1]], dtype='float32'))
>>> b = Tensor(onp.array([1, -1], dtype='float32'))
>>> result, _ = bicgstab(A, b)
>>> print(result)
[-1. 1.]
"""
if x0 is None:
x0 = mnp.zeros_like(b)
if maxiter is None:
maxiter = 10 * b.shape[0]
if M is None:
M = lambda x: x
if x0.shape != b.shape:
_raise_value_error(
'Input x0 and b must have matching shapes: {} vs {}'.format(x0.shape, b.shape))
if (F.dtype(b) not in float_types) or (F.dtype(b) != F.dtype(x0)) or (F.dtype(b) != F.dtype(A)):
_raise_type_error('Input A, x0 and b must have same float types')
x, info = BiCGStab(A, M)(b, x0, tol, atol, maxiter)
return x, info
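A usage sketch with a callable preconditioner (the diagonal/Jacobi `M` below is an assumption for illustration, not part of the commit):

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import bicgstab

A = Tensor(onp.array([[4., 1.], [2., 5.]], dtype=onp.float32))
b = Tensor(onp.array([1., -1.], dtype=onp.float32))
# M should approximate inv(A); here, the inverse of A's diagonal.
d = Tensor(onp.array([1. / 4., 1. / 5.], dtype=onp.float32))
M = lambda x: d * x
x, info = bicgstab(A, b, M=M)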

View File

@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""internal utility functions"""
import numpy as onp
from .. import nn, ops
from .. import numpy as mnp
from ..numpy import where, zeros_like, dot, greater
from ..ops import functional as F
from ..common import Tensor
@@ -62,11 +62,7 @@ def _to_scalar(arr):
if isinstance(arr, Tensor):
if arr.shape:
return arr
arr = arr.asnumpy()
if isinstance(arr, onp.ndarray):
if arr.shape:
return arr
return arr.item()
return arr.asnumpy().item()
raise ValueError("{} are not supported.".format(type(arr)))
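A quick illustration of the simplified branch (assuming the internal module path `mindspore.scipy.utils`):

from mindspore.common import Tensor
from mindspore.scipy.utils import _to_scalar

print(_to_scalar(Tensor([1., 2.])))  # shaped Tensor: returned unchanged
print(_to_scalar(Tensor(3.0)))       # 0-d Tensor: converted to the Python scalar 3.0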
@@ -130,3 +126,11 @@ def _normalize_matvec(f):
_raise_value_error(
'linear operator must be either a function or Tensor: but got {}'.format(F.typeof(f)))
return f
def _norm(x, ord_=None):
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
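A sanity check of the shared helper against NumPy (illustrative; same module-path assumption as above):

import numpy as onp
from mindspore.common import Tensor
from mindspore import numpy as mnp
from mindspore.scipy.utils import _norm

v = Tensor(onp.array([3., -4.], dtype=onp.float32))
print(_norm(v))                # 2-norm by default: 5.0
print(_norm(v, ord_=mnp.inf))  # infinity norm (max-abs): 4.0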

View File

@@ -16,11 +16,11 @@
import pytest
import numpy as onp
import scipy as osp
from scipy.sparse.linalg import cg as osp_cg
import scipy.sparse.linalg
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import cg as msp_cg
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix
@@ -61,7 +61,7 @@ def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
A = create_sym_pos_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
M = _fetch_preconditioner(preconditioner, A)
osp_res = osp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
A = Tensor(A)
b = Tensor(b)
@@ -69,11 +69,11 @@ def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
# using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
msp_res_dyn = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
# using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
msp_res_sta = msp_cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
@@ -99,11 +99,11 @@ def test_cg_against_numpy(dtype, shape):
# using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
actual_dyn, _ = msp_cg(Tensor(A), Tensor(b))
actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
# using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
actual_sta, _ = msp_cg(Tensor(A), Tensor(b))
actual_sta, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
kw = {"atol": 1e-5, "rtol": 1e-5}
onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
@@ -229,3 +229,41 @@ def test_graph_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=0.0)
msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=0.0, solve_method='batched')
onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype_tol', [(onp.float64, 1e-10)])
@pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 3])
def test_bicgstab_against_scipy(dtype_tol, shape, preconditioner, maxiter):
"""
Feature: ALL TO ALL
Description: test cases for bicgstab
Expectation: the result match scipy
"""
onp.random.seed(0)
dtype, tol = dtype_tol
A = create_full_rank_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
M = _fetch_preconditioner(preconditioner, A)
osp_res = scipy.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
A = Tensor(A)
b = Tensor(b)
M = Tensor(M) if M is not None else M
# using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
msp_res_dyn = msp.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
# using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
msp_res_sta = msp.sparse.linalg.bicgstab(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)
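The test depends on a `_fetch_preconditioner` helper defined earlier in this test module; a plausible reconstruction for reference (an assumption, not the committed code):

def _fetch_preconditioner(preconditioner, A):
    # Hypothetical sketch: build a dense M approximating inv(A).
    if preconditioner == 'identity':
        M = onp.eye(A.shape[0], dtype=A.dtype)
    elif preconditioner == 'exact':
        M = onp.linalg.inv(A)
    elif preconditioner == 'random':
        M = onp.linalg.inv(create_sym_pos_matrix(A.shape, A.dtype))
    else:
        M = None
    return M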