!31635 opt gmres code

Merge pull request !31635 from zhuzhongrui/pub_master
This commit is contained in:
i-robot 2022-03-22 01:30:27 +00:00 committed by Gitee
commit 6a07fffcac
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 190 additions and 127 deletions

View File

@ -72,24 +72,13 @@ def _high_precision_cho_solve(a, b, data_type=mstype.float64):
return y.astype(data_type)
class BatchedGmres(nn.Cell):
def _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M):
"""
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
This implementation solves a dense linear problem instead of building
a QR factorization during the Arnoldi process.
batched gmres: solve the least squares problem from scratch at the end of each GMRES iteration.
It does not allow for early termination, but has much less overhead on GPUs.
"""
def __init__(self, A, M):
super(BatchedGmres, self).__init__()
self.A = A
self.M = M
def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
# Constant tensor which avoids loop unrolling
_INT_ZERO = _to_tensor(0)
A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
dtype = b.dtype
_, b_norm = _safe_normalize(b)
atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
@ -114,29 +103,16 @@ class BatchedGmres(nn.Cell):
residual = M(b - A(x))
unit_residual, residual_norm = _safe_normalize(residual)
k += 1
return x, F.select(residual_norm > atol, k, _INT_ZERO)
class IterativeGmres(nn.Cell):
def _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M):
"""
Implements a iterative GMRES. While building the ``restart``-dimensional
Krylov subspace iteratively using Givens Rotation method, the algorithm
constructs a Triangular matrix R which could be more easily solved.
incremental gmres: builds a QR decomposition for the Krylov subspace incrementally during
the GMRES process using Givens rotations. This improves numerical stability and gives a free estimate of
the residual norm that allows for early termination within a single "restart".
"""
def __init__(self, A, M):
super(IterativeGmres, self).__init__()
self.A = A
self.M = M
def construct(self, b, x0, tol, atol, restart, maxiter):
# Constant tensor which avoids loop unrolling
_INT_ZERO = _to_tensor(0)
A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
_, b_norm = _safe_normalize(b)
atol = mnp.maximum(tol * b_norm, atol)
@ -151,8 +127,8 @@ class IterativeGmres(nn.Cell):
while iters < maxiter and r_norm > atol:
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
dtype = mnp.result_type(b)
# use eye() to avoid constructing a singular matrix in case of early
# termination
# Use eye() to avoid constructing a singular matrix in case of early
# Termination
R = mnp.eye(restart, restart + 1, dtype=dtype)
givens = mnp.zeros((restart, 2), dtype=dtype)
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
@ -162,7 +138,7 @@ class IterativeGmres(nn.Cell):
err = r_norm
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
V, R, _ = arnoldi_iteration(k, A, M, V, R)
# givens rotation
# Givens rotation
row_k = R[k, :].copy()
i = _INT_ZERO
while i < k:
@ -192,10 +168,63 @@ class IterativeGmres(nn.Cell):
r, r_norm = _safe_normalize(r)
x0 = x
iters += 1
return x0, F.select(r_norm > atol, iters, _INT_ZERO)
class GMRES(nn.Cell):
"""
Given given A and b, GMRES solves the linear system:
.. math::
A x = b
"""
def __init__(self, A, M, solve_method):
super(GMRES, self).__init__()
self.A = A
self.M = M
self.solve_method = solve_method
def construct(self, b, x0, tol, atol, restart, maxiter):
# Constant tensor which avoids loop unrolling
A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
x = x0
info = _to_tensor(0)
if self.solve_method == 'batched':
x, info = _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M)
elif self.solve_method == "incremental":
x, info = _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M)
else:
_raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", self.solve_method,
".")
return x, info
class GMRESV2(nn.Cell):
"""
This is a new version of GMRES, which contains all parameters in a graph.
"""
def __init__(self, solve_method):
super(GMRESV2, self).__init__()
self.solve_method = solve_method
def construct(self, A, b, x0, tol, atol, restart, maxiter, M):
A = _normalize_matvec(A)
M = _normalize_matvec(M)
x = x0
info = _to_tensor(0)
if self.solve_method == 'batched':
x, info = _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M)
elif self.solve_method == "incremental":
x, info = _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M)
else:
_raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", self.solve_method,
".")
return x, info
def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
M=None, callback=None, restrt=None, atol=0.0, callback_type=None, solve_method='batched'):
"""
@ -292,13 +321,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
_value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
if restart > size:
restart = size
if solve_method == 'incremental':
x, info = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
elif solve_method == 'batched':
x, info = BatchedGmres(A, M)(b, x0, tol, atol, restart, maxiter)
if not is_within_graph(A):
x, info = GMRES(A, M, solve_method)(b, x0, tol, atol, restart, maxiter)
else:
_raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", solve_method, ".")
x, info = GMRESV2(solve_method)(A, b, x0, tol, atol, restart, maxiter, M)
return x, info

View File

@ -43,6 +43,12 @@ def _fetch_preconditioner(preconditioner, a):
return M
def _is_valid_platform(tensor_type='Tensor'):
if tensor_type == "CSRTensor" and get_platform() != "linux":
return False
return True
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@ -58,7 +64,7 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite
Description: test cases for cg using function way in pynative/graph mode
Expectation: the result match scipy
"""
if tensor_type == "CSRTensor" and get_platform() != "linux":
if not _is_valid_platform(tensor_type):
return
onp.random.seed(0)
a = create_sym_pos_matrix(shape, dtype)
@ -70,11 +76,11 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite
b = Tensor(b)
m = to_tensor((m, tensor_type)) if m is not None else m
# using PYNATIVE MODE
# Using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
# using GRAPH MODE
# Using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
@ -101,11 +107,11 @@ def test_cg_against_numpy(dtype, shape):
b = onp.random.random(shape[:1]).astype(dtype)
expected = onp.linalg.solve(a, b)
# using PYNATIVE MODE
# Using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
# using GRAPH MODE
# Using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
@ -146,11 +152,11 @@ def test_cg_against_scipy_graph(tensor_type, dtype, tol, shape, preconditioner,
b = Tensor(b)
m = to_tensor((m, tensor_type)) if m is not None else m
# using PYNATIVE MODE
# Using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
msp_res_dyn = Net()(a, b, m, maxiter, tol)
# using GRAPH MODE
# Using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
msp_res_sta = Net()(a, b, m, maxiter, tol)
@ -339,33 +345,39 @@ def test_gmres_against_scipy_level1(n, dtype, error, preconditioner, solve_metho
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [3, 7])
@pytest.mark.parametrize('dtype,error', [(onp.float64, 1e-5), (onp.float32, 1e-4)])
@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4),
('CSRTensor', onp.float32, 1e-4)])
@pytest.mark.parametrize('restart', [1, 2])
@pytest.mark.parametrize('maxiter', [1, 2])
@pytest.mark.parametrize('preconditioner', ['identity', 'exact', 'random'])
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, solve_method):
def test_gmres_against_scipy(n, tensor_type, dtype, error, restart, maxiter, preconditioner, solve_method):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result match scipy
"""
if not _is_valid_platform(tensor_type):
return
onp.random.seed(0)
a = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype)
x0 = onp.zeros_like(b).astype(dtype)
M = _fetch_preconditioner(preconditioner, a)
m = _fetch_preconditioner(preconditioner, a)
tol = float(onp.finfo(dtype=dtype).eps)
atol = tol
if preconditioner == 'random':
restart = n
maxiter = None
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol)
# PyNative Mode
context.set_context(mode=context.PYNATIVE_MODE)
M = Tensor(M) if M is not None else M
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
M=M, atol=atol, solve_method=solve_method)
a = to_tensor((a, tensor_type))
b = Tensor(b)
x0 = Tensor(x0)
m = to_tensor((m, tensor_type)) if m is not None else m
ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart,
maxiter=maxiter, M=m, atol=atol, solve_method=solve_method)
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
@ -374,32 +386,57 @@ def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner,
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [3])
@pytest.mark.parametrize('dtype,error', [(onp.float32, 1e-4)])
@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4),
('CSRTensor', onp.float32, 1e-4)])
@pytest.mark.parametrize('preconditioner', ['random'])
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
def test_gmres_against_graph_scipy(n, dtype, error, preconditioner, solve_method):
def test_gmres_against_graph_scipy(n, tensor_type, dtype, error, preconditioner, solve_method):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result match scipy in graph
"""
if not _is_valid_platform(tensor_type):
return
# Input CSRTensor of gmres in mindspore graph mode is not supported, just ignored it.
if tensor_type == "CSRTensor":
return
class TestNet(nn.Cell):
def __init__(self, solve_method):
super(TestNet, self).__init__()
self.solve_method = solve_method
def construct(self, a, b, x0, tol, restart, maxiter, m, atol):
return msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m,
atol=atol, solve_method=self.solve_method)
onp.random.seed(0)
a = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype)
x0 = onp.zeros_like(b).astype(dtype)
M = _fetch_preconditioner(preconditioner, a)
m = _fetch_preconditioner(preconditioner, a)
tol = float(onp.finfo(dtype=dtype).eps)
atol = tol
restart = n
maxiter = None
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol)
# Graph Mode
context.set_context(mode=context.GRAPH_MODE)
M = Tensor(M) if M is not None else M
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
M=M, atol=atol, solve_method=solve_method)
a = to_tensor((a, tensor_type))
b = Tensor(b)
x0 = Tensor(x0)
m = to_tensor((m, tensor_type)) if m is not None else m
# Not in graph's construct
ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter,
M=m, atol=atol)
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
# With in graph's construct
ms_net_output, _ = TestNet(solve_method)(a, b, x0, tol, restart, maxiter, m, atol)
assert onp.allclose(scipy_output, ms_net_output.asnumpy(), rtol=error, atol=error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training