forked from mindspore-Ecosystem/mindspore
add sparse gmres api && check && bprop
This commit is contained in:
@ -61,6 +61,17 @@ def rotate_vectors(H, i, cs, sn):
return H
def _high_precision_cho_solve(a, b, data_type=mstype.float64):
    """Solve ``(a @ a.T) y = a @ b`` through a Cholesky factorisation in float64.

    Both operands are promoted to float64 before any arithmetic so that the
    product ``a @ a.T`` and the factorisation are carried out in double
    precision; only the final solution is cast back to ``data_type``.

    Args:
        a (Tensor): Coefficient matrix. ``a @ a.T`` is handed to ``cho_factor``,
            so it is presumably expected to be positive definite - TODO confirm
            against the callers.
        b (Tensor): Right-hand side; it is multiplied by ``a`` before the solve.
        data_type (mindspore.dtype): dtype of the returned solution.
            Default: ``mstype.float64``.

    Returns:
        Tensor, the solution ``y`` cast to ``data_type``.
    """
    a = a.astype(mstype.float64)
    b = b.astype(mstype.float64)
    # Build the system (a a^T) y = a b.
    a_a = mnp.dot(a, a.T)
    a_b = mnp.dot(a, b)
    # Factor with the upper triangle (lower=False) and reuse the flag that
    # cho_factor returns when handing the factor to cho_solve.
    c, lower = cho_factor(a_a, lower=False)
    factor = (c, lower)
    y = cho_solve(factor, a_b)
    return y.astype(data_type)
class BatchedGmres(nn.Cell):
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
@ -97,14 +108,10 @@ class BatchedGmres(nn.Cell):
k_iter += 1
beta_vec = mnp.zeros((restart + 1,), dtype=dtype)
beta_vec[0] = residual_norm
a2 =, H.T)
b2 =, beta_vec)
c, lower = cho_factor(a2, lower=False)
factor = (c, lower)
y = cho_solve(factor, b2)
y = _high_precision_cho_solve(H, beta_vec, data_type=dtype)
dx =[..., :-1], y)
x = x + dx
residual = b - A(x)
residual = M(b - A(x))
unit_residual, residual_norm = _safe_normalize(residual)
k += 1
@ -189,8 +196,8 @@ class IterativeGmres(nn.Cell):
return x0, > atol, iters, _INT_ZERO)
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
M=None, solve_method='batched'):
def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
M=None, callback=None, restrt=None, atol=0.0, callback_type=None, solve_method='batched'):
Given A and b, GMRES solves the linear system:
@ -216,23 +223,32 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
:math:`norm(residual) <= max(tol*norm(b), atol)`. We do not implement SciPy's
"legacy" behavior, so MindSpore's tolerance will differ from SciPy unless you
explicitly pass `atol` to SciPy's `gmres`. Default: 1e-5.
atol (float, optional): The same as `tol`. Default: 0.0.
restart (integer, optional): Size of the Krylov subspace ("number of iterations")
built between restarts. GMRES works by approximating the true solution x as its
projection into a Krylov space of this dimension - this parameter
therefore bounds the maximum accuracy achievable from any guess
solution. Larger values increase both number of iterations and iteration
cost, but may be necessary for convergence. The algorithm terminates
early if convergence is achieved before the full subspace is built.
Default: 20.
early if convergence is achieved before the full subspace is built. Default: 20.
maxiter (int): Maximum number of times to rebuild the size-`restart`
Krylov space starting from the solution found at the last iteration. If GMRES
halts or is very slow, decreasing this parameter may help.
Default: None.
halts or is very slow, decreasing this parameter may help. Default: None.
M (Union[Tensor, function]): Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance. Default: None.
callback (function): User-supplied function to call after each iteration. It is called as callback(args),
where args are selected by callback_type. Default: None.
restrt (int, optional): Deprecated, use restart instead. Default: None.
atol (float, optional): The same as `tol`. Default: 0.0.
callback_type (str, optional): Callback function argument requested:
Default: None.
- x: current iterate (ndarray), called on every restart
- pr_norm: relative (preconditioned) residual norm (float), called on every inner iteration
- legacy (default): same as pr_norm, but also changes the meaning of ‘maxiter’ to count inner
iterations instead of restart cycles.
solve_method (str): There are two kinds of solve methods, 'incremental' or 'batched'. Default: "batched".
- incremental: builds a QR decomposition for the Krylov subspace incrementally during
@ -268,12 +284,19 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
size = b.size
if maxiter is None:
maxiter = 10 * size # copied from scipy
if restart > size:
restart = size
if M is None:
M = lambda x: x
func_name = "gmres"
_type_check(func_name, tol, float, 'tol')
_type_check(func_name, restart, int, 'restart')
_type_check(func_name, maxiter, int, 'maxiter')
_type_check(func_name, solve_method, str, 'solve_method')
_value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
_value_check(func_name, restrt, None, 'restrt', op='is', fmt='todo')
_value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
if restart > size:
restart = size
A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)
if solve_method == 'incremental':
x, info = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
elif solve_method == 'batched':
@ -25,19 +25,19 @@ from mindspore.common import Tensor
from import create_sym_pos_matrix, create_full_rank_matrix, to_tensor, to_ndarray, get_platform
def _fetch_preconditioner(preconditioner, A):
def _fetch_preconditioner(preconditioner, a):
Returns one of various preconditioning matrices depending on the identifier
`preconditioner` and the input matrix A whose inverse it supposedly approximates.
if preconditioner == 'identity':
M = onp.eye(A.shape[0], dtype=A.dtype)
M = onp.eye(a.shape[0], dtype=a.dtype)
elif preconditioner == 'random':
random_metrix = create_sym_pos_matrix(A.shape, A.dtype)
M = onp.linalg.inv(random_metrix)
random_matrix = create_sym_pos_matrix(a.shape, a.dtype)
M = onp.linalg.inv(random_matrix)
elif preconditioner == 'exact':
M = onp.linalg.inv(A)
M = onp.linalg.inv(a)
M = None
return M
@ -305,125 +305,78 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
@pytest.mark.parametrize('n', [128])
@pytest.mark.parametrize('dtype,error', [(onp.float32, 1e-4)])
@pytest.mark.parametrize('restart', [1])
@pytest.mark.parametrize('maxiter', [1])
@pytest.mark.parametrize('preconditioner', ['random'])
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
def test_gmres_against_scipy_level1(n, dtype, error, restart, maxiter, preconditioner, solve_method):
Feature: ALL TO ALL
Description: level1 test cases for [N x N] X [N X 1]
Expectation: the result matches scipy
a = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype)
x0 = onp.zeros_like(b).astype(dtype)
M = _fetch_preconditioner(preconditioner, a)
tol = float(onp.finfo(dtype=dtype).eps)
atol = tol
if preconditioner == 'random':
restart = n
maxiter = None
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
M=M, atol=atol, solve_method=solve_method)
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
@pytest.mark.parametrize('n', [3, 5, 7])
@pytest.mark.parametrize('dtype,tol', [(onp.float64, 7), (onp.float32, 3)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
def test_gmres_incremental_against_scipy(n, tol, dtype, preconditioner):
@pytest.mark.parametrize('n', [3, 7])
@pytest.mark.parametrize('dtype,error', [(onp.float64, 1e-5), (onp.float32, 1e-4)])
@pytest.mark.parametrize('restart', [1, 2])
@pytest.mark.parametrize('maxiter', [1, 2])
@pytest.mark.parametrize('preconditioner', ['identity', 'exact', 'random'])
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, solve_method):
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result matches scipy
A = create_full_rank_matrix((n, n), dtype)
a = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype)
x0 = onp.zeros_like(b).astype(dtype)
M = _fetch_preconditioner(preconditioner, A)
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, M=M)
A = Tensor(A)
b = Tensor(b)
x0 = Tensor(x0)
if M is not None:
M = Tensor(M)
gmres_x, _ = msp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, solve_method='incremental', M=M)
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=tol)
@pytest.mark.parametrize('n', [3, 5, 7])
@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
def test_gmres_incremental_against_scipy_graph(n, tol, dtype, preconditioner):
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result match scipy
A = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype)
x0 = onp.zeros_like(b).astype(dtype)
M = _fetch_preconditioner(preconditioner, A)
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, M=M)
A = Tensor(A)
b = Tensor(b)
x0 = Tensor(x0)
if M is not None:
M = Tensor(M)
gmres_x, _ = msp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, solve_method='incremental', M=M)
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=tol)
@pytest.mark.parametrize('n', [4, 5, 6])
@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 2])
def test_pynative_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
Feature: ALL TO ALL
Description: test cases for gmres
Expectation: the result match scipy
M = _fetch_preconditioner(preconditioner, a)
tol = float(onp.finfo(dtype=dtype).eps)
atol = tol
if preconditioner == 'random':
restart = n
maxiter = None
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
# PyNative Mode
shape = (n, n)
a = create_full_rank_matrix(shape, dtype)
b = onp.random.rand(n).astype(dtype=dtype)
M = _fetch_preconditioner(preconditioner, a)
tensor_a = Tensor(a)
tensor_b = Tensor(b)
M = Tensor(M) if M is not None else M
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
M=M, atol=atol, solve_method=solve_method)
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
@pytest.mark.parametrize('n', [5, 6])
@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 2])
def test_graph_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
Feature: ALL TO ALL
Description: test cases for gmres
Expectation: the result match scipy
# Graph Mode
shape = (n, n)
a = create_full_rank_matrix(shape, dtype)
b = onp.random.rand(n).astype(dtype=dtype)
tensor_a = Tensor(a)
tensor_b = Tensor(b)
M = _fetch_preconditioner(preconditioner, a)
M = Tensor(M) if M is not None else M
osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=0.0)
msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=0.0, solve_method='batched')
onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
M=M, atol=atol, solve_method=solve_method)
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
@ -19,6 +19,7 @@ from functools import cmp_to_key
import numpy as onp
import scipy.sparse.linalg
from scipy.linalg import eigvals
from mindspore import Tensor, CSRTensor
import mindspore.ops as ops
import mindspore.numpy as mnp
@ -127,8 +128,14 @@ def create_sym_pos_matrix(shape, dtype):
'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)
n = shape[-1]
x = onp.random.random(shape)
return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
count = 0
while count < 100:
x = onp.random.random(shape).astype(dtype)
a = (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
count += 1
if onp.min(eigvals(a)) > 0:
return a
raise ValueError('Symmetric positive definite matrix create failed')
def gradient_check(x, net, epsilon=1e-3, symmetric=False, enumerate_fn=onp.ndenumerate):
Reference in New Issue