add sparse gmres api && check && bprop
parent 22240df6c0
commit 95a34dcbfa
@@ -61,6 +61,17 @@ def rotate_vectors(H, i, cs, sn):
     return H


+def _high_precision_cho_solve(a, b, data_type=mstype.float64):
+    a = a.astype(mstype.float64)
+    b = b.astype(mstype.float64)
+    a_a = mnp.dot(a, a.T)
+    a_b = mnp.dot(a, b)
+    c, lower = cho_factor(a_a, lower=False)
+    factor = (c, lower)
+    y = cho_solve(factor, a_b)
+    return y.astype(data_type)
+
+
 class BatchedGmres(nn.Cell):
     """
     Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
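Review note: the new helper is a float64 normal-equations solve. Given `a` of shape (m, k) and `b` of shape (k,), it Cholesky-factors `a @ a.T` and solves `(a @ a.T) y = a @ b`, which is the least-squares problem `min_y ||a.T @ y - b||_2` carried out in double precision before casting back. A minimal NumPy/SciPy sketch for intuition (function name, shapes, and output dtype here are illustrative, not the MindSpore code):

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def high_precision_cho_solve(a, b, out_dtype=np.float32):
    # Solve the normal equations (a @ a.T) y = a @ b in float64,
    # i.e. the least-squares problem  min_y || a.T @ y - b ||_2.
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    c, lower = cho_factor(a @ a.T, lower=False)
    y = cho_solve((c, lower), a @ b)
    return y.astype(out_dtype)

rng = np.random.default_rng(0)
h = rng.standard_normal((5, 6))    # e.g. a transposed Hessenberg block from GMRES
beta = np.zeros(6)
beta[0] = 1.0                      # rhs with the residual norm in slot 0
y1 = high_precision_cho_solve(h, beta)
y2, *_ = np.linalg.lstsq(h.T, beta, rcond=None)
# Both routes solve the same least-squares problem, up to float32 rounding.
assert np.allclose(y1, y2, rtol=1e-4, atol=1e-4)
```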
@@ -97,14 +108,10 @@ class BatchedGmres(nn.Cell):
                 k_iter += 1
             beta_vec = mnp.zeros((restart + 1,), dtype=dtype)
             beta_vec[0] = residual_norm
-            a2 = mnp.dot(H, H.T)
-            b2 = mnp.dot(H, beta_vec)
-            c, lower = cho_factor(a2, lower=False)
-            factor = (c, lower)
-            y = cho_solve(factor, b2)
+            y = _high_precision_cho_solve(H, beta_vec, data_type=dtype)
             dx = mnp.dot(V[..., :-1], y)
             x = x + dx
-            residual = b - A(x)
+            residual = M(b - A(x))
             unit_residual, residual_norm = _safe_normalize(residual)
             k += 1
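Review note: this loop body is one restart cycle of GMRES — build a Krylov basis V with Hessenberg data H, solve the small least-squares problem for y (now via the float64 helper), update x = x + V·y, then recompute the residual, which this commit also runs through the preconditioner M. A compact plain-NumPy reference of that scheme, in the usual textbook left-preconditioned formulation (a sketch for orientation, not the batched MindSpore cell):

```python
import numpy as np

def gmres_restarted(A, b, x0, M=None, restart=20, maxiter=100, tol=1e-8):
    """Plain-NumPy restarted GMRES, for reference only (dense, no batching)."""
    M = (lambda v: v) if M is None else M
    x = x0.astype(np.float64)
    n = b.size
    for _ in range(maxiter):
        r = M(b - A @ x)                        # preconditioned residual, cf. the diff
        beta = np.linalg.norm(r)
        if beta < tol:
            break
        V = np.zeros((n, restart + 1))
        H = np.zeros((restart + 1, restart))
        V[:, 0] = r / beta
        for j in range(restart):                # Arnoldi with modified Gram-Schmidt
            w = M(A @ V[:, j])
            for i in range(j + 1):
                H[i, j] = w @ V[:, i]
                w = w - H[i, j] * V[:, i]
            H[j + 1, j] = np.linalg.norm(w)
            if H[j + 1, j] > 1e-12:
                V[:, j + 1] = w / H[j + 1, j]
        e1 = np.zeros(restart + 1)
        e1[0] = beta
        y, *_ = np.linalg.lstsq(H, e1, rcond=None)   # small least-squares step
        x = x + V[:, :-1] @ y                        # dx = V y, cf. the diff
    return x

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = gmres_restarted(A, b, np.zeros(2), restart=2)
assert np.allclose(A @ x, b, atol=1e-6)
```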
@@ -189,8 +196,8 @@ class IterativeGmres(nn.Cell):
         return x0, F.select(r_norm > atol, iters, _INT_ZERO)


-def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
-          M=None, solve_method='batched'):
+def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
+          M=None, callback=None, restrt=None, atol=0.0, callback_type=None, solve_method='batched'):
     """
     Given A and b, GMRES solves the linear system:
@@ -216,23 +223,32 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
             :math:`norm(residual) <= max(tol*norm(b), atol)`. We do not implement SciPy's
             "legacy" behavior, so MindSpore's tolerance will differ from SciPy unless you
             explicitly pass `atol` to SciPy's `gmres`. Default: 1e-5.
-        atol (float, optional): The same as `tol`. Default: 0.0.
         restart (integer, optional): Size of the Krylov subspace ("number of iterations")
             built between restarts. GMRES works by approximating the true solution x as its
             projection into a Krylov space of this dimension - this parameter
             therefore bounds the maximum accuracy achievable from any guess
             solution. Larger values increase both number of iterations and iteration
             cost, but may be necessary for convergence. The algorithm terminates
-            early if convergence is achieved before the full subspace is built.
-            Default: 20.
+            early if convergence is achieved before the full subspace is built. Default: 20.
         maxiter (int): Maximum number of times to rebuild the size-`restart`
             Krylov space starting from the solution found at the last iteration. If GMRES
-            halts or is very slow, decreasing this parameter may help.
-            Default: None.
+            halts or is very slow, decreasing this parameter may help. Default: None.
         M (Union[Tensor, function]): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance. Default: None.
+        callback (function): User-supplied function to call after each iteration. It is called as callback(args),
+            where args are selected by callback_type. Default: None.
+        restrt (int, optional): Deprecated, use restart instead. Default: None.
+        atol (float, optional): The same as `tol`. Default: 0.0.
+        callback_type (str, optional): Callback function argument requested:
+            Default: None.
+
+            - x: current iterate (ndarray), called on every restart
+            - pr_norm: relative (preconditioned) residual norm (float), called on every inner iteration
+            - legacy (default): same as pr_norm, but also changes the meaning of 'maxiter' to count inner
+              iterations instead of restart cycles.
+
         solve_method (str): There are two kinds of solve methods: 'incremental' or 'batched'. Default: "batched".

             - incremental: builds a QR decomposition for the Krylov subspace incrementally during
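Review note: taken together, a typical call under the new signature looks like the following, mirroring the tests later in this commit (matrix, sizes, and tolerances here are illustrative; the SciPy call is shown for comparison):

```python
import numpy as onp
import scipy.sparse.linalg as osp_linalg
from mindspore import Tensor, context
import mindspore.scipy as msp

onp.random.seed(0)
a = onp.random.rand(8, 8).astype(onp.float32) + 8 * onp.eye(8, dtype=onp.float32)
b = onp.random.rand(8).astype(onp.float32)

# SciPy reference solution with matching tolerances.
scipy_x, scipy_info = osp_linalg.gmres(a, b, tol=1e-5, restart=8, maxiter=100, atol=0.0)

# MindSpore call; the SciPy-compatibility arguments (callback, restrt,
# callback_type) must stay at their None defaults for now.
context.set_context(mode=context.GRAPH_MODE)
ms_x, ms_info = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), tol=1e-5, restart=8,
                                        maxiter=100, atol=0.0, solve_method='batched')
assert onp.allclose(scipy_x, ms_x.asnumpy(), rtol=1e-4, atol=1e-4)
```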
@@ -268,12 +284,19 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
     size = b.size
     if maxiter is None:
         maxiter = 10 * size  # copied from scipy
-    if restart > size:
-        restart = size
-
-    if M is None:
-        M = lambda x: x
-
+    func_name = "gmres"
+    _type_check(func_name, tol, float, 'tol')
+    _type_check(func_name, restart, int, 'restart')
+    _type_check(func_name, maxiter, int, 'maxiter')
+    _type_check(func_name, solve_method, str, 'solve_method')
+    _value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
+    _value_check(func_name, restrt, None, 'restrt', op='is', fmt='todo')
+    _value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
+    if restart > size:
+        restart = size
+    A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)
     if solve_method == 'incremental':
         x, info = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
     elif solve_method == 'batched':
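Review note: `_type_check`, `_value_check`, and `_sparse_check` are internal MindSpore helpers; judging from the `fmt='todo'` markers, the intent is that the SciPy-compatibility arguments are accepted in the signature but rejected until implemented. A hypothetical stand-in for what these checks enforce (not the MindSpore internals; names and messages are illustrative):

```python
def _check_gmres_args(tol, restart, maxiter, solve_method,
                      callback, restrt, callback_type):
    # Types are enforced up front, and the SciPy-compatibility arguments
    # must stay None until they are supported.
    if not isinstance(tol, float):
        raise TypeError("For 'gmres', the input 'tol' must be a float.")
    if not isinstance(restart, int):
        raise TypeError("For 'gmres', the input 'restart' must be an int.")
    if not isinstance(maxiter, int):
        raise TypeError("For 'gmres', the input 'maxiter' must be an int.")
    if not isinstance(solve_method, str):
        raise TypeError("For 'gmres', the input 'solve_method' must be a str.")
    for name, value in (('callback', callback), ('restrt', restrt),
                        ('callback_type', callback_type)):
        if value is not None:
            raise ValueError(f"For 'gmres', '{name}' is not supported yet.")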
@@ -25,19 +25,19 @@ from mindspore.common import Tensor
 from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor, to_ndarray, get_platform


-def _fetch_preconditioner(preconditioner, A):
+def _fetch_preconditioner(preconditioner, a):
     """
     Returns one of various preconditioning matrices depending on the identifier
     `preconditioner' and the input matrix A whose inverse it supposedly
     approximates.
     """
     if preconditioner == 'identity':
-        M = onp.eye(A.shape[0], dtype=A.dtype)
+        M = onp.eye(a.shape[0], dtype=a.dtype)
     elif preconditioner == 'random':
-        random_metrix = create_sym_pos_matrix(A.shape, A.dtype)
-        M = onp.linalg.inv(random_metrix)
+        random_matrix = create_sym_pos_matrix(a.shape, a.dtype)
+        M = onp.linalg.inv(random_matrix)
     elif preconditioner == 'exact':
-        M = onp.linalg.inv(A)
+        M = onp.linalg.inv(a)
     else:
         M = None
     return M
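Review note: the 'exact' branch returns M = inv(a), with which the preconditioned operator is numerically the identity, so GMRES should converge almost immediately; 'random' supplies a well-conditioned but unrelated SPD inverse as a stress case. A quick SciPy check of the 'exact' behaviour (sizes and tolerances illustrative):

```python
import numpy as onp
import scipy.sparse.linalg as osp_linalg

onp.random.seed(0)
n = 5
a = onp.random.rand(n, n) + n * onp.eye(n)   # comfortably non-singular
b = onp.random.rand(n)
m = onp.linalg.inv(a)                        # the 'exact' preconditioner above

# With M = inv(a), M @ a is (numerically) the identity, so the preconditioned
# residual collapses essentially in one inner iteration.
x, info = osp_linalg.gmres(a, b, M=m, restart=n, maxiter=10, tol=1e-8, atol=1e-12)
assert info == 0
assert onp.allclose(a @ x, b, atol=1e-6)
```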
@@ -305,125 +305,78 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
     onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)


+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [128])
+@pytest.mark.parametrize('dtype,error', [(onp.float32, 1e-4)])
+@pytest.mark.parametrize('restart', [1])
+@pytest.mark.parametrize('maxiter', [1])
+@pytest.mark.parametrize('preconditioner', ['random'])
+@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
+def test_gmres_against_scipy_level1(n, dtype, error, restart, maxiter, preconditioner, solve_method):
+    """
+    Feature: ALL TO ALL
+    Description: level1 test cases for [N x N] X [N X 1]
+    Expectation: the result match scipy
+    """
+    onp.random.seed(0)
+    a = create_full_rank_matrix((n, n), dtype)
+    b = onp.random.rand(n).astype(dtype)
+    x0 = onp.zeros_like(b).astype(dtype)
+    M = _fetch_preconditioner(preconditioner, a)
+    tol = float(onp.finfo(dtype=dtype).eps)
+    atol = tol
+    if preconditioner == 'random':
+        restart = n
+        maxiter = None
+    scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
+    context.set_context(mode=context.GRAPH_MODE)
+    ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
+                                           M=M, atol=atol, solve_method=solve_method)
+    assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('n', [3, 5, 7])
-@pytest.mark.parametrize('dtype,tol', [(onp.float64, 7), (onp.float32, 3)])
-@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
-def test_gmres_incremental_against_scipy(n, tol, dtype, preconditioner):
+@pytest.mark.parametrize('n', [3, 7])
+@pytest.mark.parametrize('dtype,error', [(onp.float64, 1e-5), (onp.float32, 1e-4)])
+@pytest.mark.parametrize('restart', [1, 2])
+@pytest.mark.parametrize('maxiter', [1, 2])
+@pytest.mark.parametrize('preconditioner', ['identity', 'exact', 'random'])
+@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
+def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, solve_method):
     """
     Feature: ALL TO ALL
     Description: test cases for [N x N] X [N X 1]
     Expectation: the result match scipy
     """
     onp.random.seed(0)
-    context.set_context(mode=context.PYNATIVE_MODE)
-    A = create_full_rank_matrix((n, n), dtype)
+    a = create_full_rank_matrix((n, n), dtype)
     b = onp.random.rand(n).astype(dtype)
     x0 = onp.zeros_like(b).astype(dtype)
-    M = _fetch_preconditioner(preconditioner, A)
-
-    scipy_x, _ = osp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, M=M)
-    A = Tensor(A)
-    b = Tensor(b)
-    x0 = Tensor(x0)
-    if M is not None:
-        M = Tensor(M)
-
-    gmres_x, _ = msp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, solve_method='incremental', M=M)
-    onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=tol)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('n', [3, 5, 7])
-@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
-@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
-def test_gmres_incremental_against_scipy_graph(n, tol, dtype, preconditioner):
-    """
-    Feature: ALL TO ALL
-    Description: test cases for [N x N] X [N X 1]
-    Expectation: the result match scipy
-    """
-    onp.random.seed(0)
-    context.set_context(mode=context.GRAPH_MODE)
-    A = create_full_rank_matrix((n, n), dtype)
-    b = onp.random.rand(n).astype(dtype)
-    x0 = onp.zeros_like(b).astype(dtype)
-    M = _fetch_preconditioner(preconditioner, A)
-
-    scipy_x, _ = osp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, M=M)
-    A = Tensor(A)
-    b = Tensor(b)
-    x0 = Tensor(x0)
-    if M is not None:
-        M = Tensor(M)
-
-    gmres_x, _ = msp.sparse.linalg.gmres(A, b, x0, tol=1e-07, atol=0, solve_method='incremental', M=M)
-    onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=tol)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('n', [4, 5, 6])
-@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
-@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
-@pytest.mark.parametrize('maxiter', [1, 2])
-def test_pynative_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
-    """
-    Feature: ALL TO ALL
-    Description: test cases for gmres
-    Expectation: the result match scipy
-    """
-    onp.random.seed(0)
+    M = _fetch_preconditioner(preconditioner, a)
+    tol = float(onp.finfo(dtype=dtype).eps)
+    atol = tol
+    if preconditioner == 'random':
+        restart = n
+        maxiter = None
+    scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
+    # PyNative Mode
+    context.set_context(mode=context.PYNATIVE_MODE)
-    shape = (n, n)
-    a = create_full_rank_matrix(shape, dtype)
-    b = onp.random.rand(n).astype(dtype=dtype)
-    M = _fetch_preconditioner(preconditioner, a)
-    tensor_a = Tensor(a)
-    tensor_b = Tensor(b)
-    M = Tensor(M) if M is not None else M
+    ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
+                                           M=M, atol=atol, solve_method=solve_method)
+    assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
-
-    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
-
-    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
-                                       solve_method='batched')
-    onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('n', [5, 6])
-@pytest.mark.parametrize('dtype, tol', [(onp.float64, 7), (onp.float32, 3)])
-@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
-@pytest.mark.parametrize('maxiter', [1, 2])
-def test_graph_batched_gmres_against_scipy(n, dtype, tol, preconditioner, maxiter):
-    """
-    Feature: ALL TO ALL
-    Description: test cases for gmres
-    Expectation: the result match scipy
-    """
-    onp.random.seed(0)
     # Graph Mode
     context.set_context(mode=context.GRAPH_MODE)
-    shape = (n, n)
-    a = create_full_rank_matrix(shape, dtype)
-    b = onp.random.rand(n).astype(dtype=dtype)
-    tensor_a = Tensor(a)
-    tensor_b = Tensor(b)
-    M = _fetch_preconditioner(preconditioner, a)
-    M = Tensor(M) if M is not None else M
-    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=0.0)
-    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=0.0, solve_method='batched')
-    onp.testing.assert_almost_equal(msp_x.asnumpy(), osp_x, decimal=tol)
+    ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
+                                           M=M, atol=atol, solve_method=solve_method)
+    assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)


 @pytest.mark.level0
@@ -19,6 +19,7 @@ from functools import cmp_to_key

 import numpy as onp
 import scipy.sparse.linalg
+from scipy.linalg import eigvals
 from mindspore import Tensor, CSRTensor
 import mindspore.ops as ops
 import mindspore.numpy as mnp
@@ -127,8 +128,14 @@ def create_sym_pos_matrix(shape, dtype):
             'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)

     n = shape[-1]
-    x = onp.random.random(shape)
-    return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
+    count = 0
+    while count < 100:
+        x = onp.random.random(shape).astype(dtype)
+        a = (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
+        count += 1
+        if onp.min(eigvals(a)) > 0:
+            return a
+    raise ValueError('Symmetric positive definite matrix create failed')


 def gradient_check(x, net, epsilon=1e-3, symmetric=False, enumerate_fn=onp.ndenumerate):
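Review note: the retry loop makes the SPD guarantee explicit instead of assuming x @ x.T + I stays positive definite after the low-precision cast; it verifies the spectrum with `eigvals` and regenerates on failure. The core check in isolation (a small sanity sketch; the `is_spd` helper is illustrative):

```python
import numpy as onp
from scipy.linalg import eigvals

def is_spd(a):
    # x @ x.T is symmetric PSD, and adding I pushes eigenvalues to >= 1 in exact
    # arithmetic; after casting we re-check the spectrum explicitly.
    return bool(onp.min(eigvals(a).real) > 0)

x = onp.random.random((4, 4)).astype(onp.float32)
a = (x @ x.T + onp.eye(4)).astype(onp.float32)
assert is_spd(a)
```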