forked from mindspore-Ecosystem/mindspore
!31635 opt gmres code
Merge pull request !31635 from zhuzhongrui/pub_master
This commit is contained in:
commit
6a07fffcac
|
@ -72,24 +72,13 @@ def _high_precision_cho_solve(a, b, data_type=mstype.float64):
|
|||
return y.astype(data_type)
|
||||
|
||||
|
||||
class BatchedGmres(nn.Cell):
|
||||
def _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M):
|
||||
"""
|
||||
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
||||
This implementation solves a dense linear problem instead of building
|
||||
a QR factorization during the Arnoldi process.
|
||||
batched gmres: solve the least squares problem from scratch at the end of each GMRES iteration.
|
||||
It does not allow for early termination, but has much less overhead on GPUs.
|
||||
"""
|
||||
|
||||
def __init__(self, A, M):
|
||||
super(BatchedGmres, self).__init__()
|
||||
self.A = A
|
||||
self.M = M
|
||||
|
||||
def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
|
||||
# Constant tensor which avoids loop unrolling
|
||||
_INT_ZERO = _to_tensor(0)
|
||||
|
||||
A = _normalize_matvec(self.A)
|
||||
M = _normalize_matvec(self.M)
|
||||
dtype = b.dtype
|
||||
_, b_norm = _safe_normalize(b)
|
||||
atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
|
||||
|
@ -114,29 +103,16 @@ class BatchedGmres(nn.Cell):
|
|||
residual = M(b - A(x))
|
||||
unit_residual, residual_norm = _safe_normalize(residual)
|
||||
k += 1
|
||||
|
||||
return x, F.select(residual_norm > atol, k, _INT_ZERO)
|
||||
|
||||
|
||||
class IterativeGmres(nn.Cell):
|
||||
def _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M):
|
||||
"""
|
||||
Implements a iterative GMRES. While building the ``restart``-dimensional
|
||||
Krylov subspace iteratively using Givens Rotation method, the algorithm
|
||||
constructs a Triangular matrix R which could be more easily solved.
|
||||
incremental gmres: builds a QR decomposition for the Krylov subspace incrementally during
|
||||
the GMRES process using Givens rotations. This improves numerical stability and gives a free estimate of
|
||||
the residual norm that allows for early termination within a single "restart".
|
||||
"""
|
||||
|
||||
def __init__(self, A, M):
|
||||
super(IterativeGmres, self).__init__()
|
||||
self.A = A
|
||||
self.M = M
|
||||
|
||||
def construct(self, b, x0, tol, atol, restart, maxiter):
|
||||
# Constant tensor which avoids loop unrolling
|
||||
_INT_ZERO = _to_tensor(0)
|
||||
|
||||
A = _normalize_matvec(self.A)
|
||||
M = _normalize_matvec(self.M)
|
||||
|
||||
_, b_norm = _safe_normalize(b)
|
||||
atol = mnp.maximum(tol * b_norm, atol)
|
||||
|
||||
|
@ -151,8 +127,8 @@ class IterativeGmres(nn.Cell):
|
|||
while iters < maxiter and r_norm > atol:
|
||||
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
|
||||
dtype = mnp.result_type(b)
|
||||
# use eye() to avoid constructing a singular matrix in case of early
|
||||
# termination
|
||||
# Use eye() to avoid constructing a singular matrix in case of early
|
||||
# termination
|
||||
R = mnp.eye(restart, restart + 1, dtype=dtype)
|
||||
givens = mnp.zeros((restart, 2), dtype=dtype)
|
||||
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
|
||||
|
@ -162,7 +138,7 @@ class IterativeGmres(nn.Cell):
|
|||
err = r_norm
|
||||
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
||||
V, R, _ = arnoldi_iteration(k, A, M, V, R)
|
||||
# givens rotation
|
||||
# Givens rotation
|
||||
row_k = R[k, :].copy()
|
||||
i = _INT_ZERO
|
||||
while i < k:
|
||||
|
@ -192,10 +168,63 @@ class IterativeGmres(nn.Cell):
|
|||
r, r_norm = _safe_normalize(r)
|
||||
x0 = x
|
||||
iters += 1
|
||||
|
||||
return x0, F.select(r_norm > atol, iters, _INT_ZERO)
|
||||
|
||||
|
||||
class GMRES(nn.Cell):
    """
    Cell that solves the linear system below with GMRES, keeping the operator ``A``
    and the preconditioner ``M`` as attributes of the cell.

    .. math::
        A x = b

    Args:
        A: Operator of the linear system; it is passed through `_normalize_matvec` before use.
        M: Preconditioner; it is passed through `_normalize_matvec` before use.
        solve_method (str): Which solver to dispatch to, 'batched' or 'incremental'.
    """

    def __init__(self, A, M, solve_method):
        super().__init__()
        self.solve_method = solve_method
        self.A = A
        self.M = M

    def construct(self, b, x0, tol, atol, restart, maxiter):
        """
        Dispatch to the solver selected by ``solve_method``.

        Args:
            b: Right-hand side of the linear system.
            x0: Starting guess; also returned as the solution if no solver runs.
            tol, atol, restart, maxiter: Forwarded unchanged to the selected solver.

        Returns:
            Tuple ``(x, info)`` as produced by `_batch_gmres` or `_incremental_gmres`.
            For any other ``solve_method``, `_raise_value_error` is called with a
            message naming the offending value.
        """
        matvec = _normalize_matvec(self.A)
        precond = _normalize_matvec(self.M)
        method = self.solve_method
        # Both outputs are bound before branching, so they exist on every path,
        # including the final branch that only reports the error.
        # NOTE(review): presumably needed so graph mode sees them defined on all branches - confirm.
        x = x0
        info = _to_tensor(0)
        if method == "incremental":
            x, info = _incremental_gmres(matvec, x0, b, tol, atol, restart, maxiter, precond)
        elif method == 'batched':
            x, info = _batch_gmres(matvec, x0, b, tol, atol, restart, maxiter, precond)
        else:
            _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", method,
                               ".")
        return x, info
|
||||
|
||||
|
||||
class GMRESV2(nn.Cell):
    """
    Variant of the GMRES cell in which every solver input, including the operator
    ``A`` and the preconditioner ``M``, is an argument of ``construct`` instead of
    an attribute of the cell.

    Args:
        solve_method (str): Which solver to dispatch to, 'batched' or 'incremental'.
    """

    def __init__(self, solve_method):
        super().__init__()
        self.solve_method = solve_method

    def construct(self, A, b, x0, tol, atol, restart, maxiter, M):
        """
        Dispatch to the solver selected by ``solve_method``.

        Args:
            A: Operator of the linear system; it is passed through `_normalize_matvec` before use.
            b: Right-hand side of the linear system.
            x0: Starting guess; also returned as the solution if no solver runs.
            tol, atol, restart, maxiter: Forwarded unchanged to the selected solver.
            M: Preconditioner; it is passed through `_normalize_matvec` before use.

        Returns:
            Tuple ``(x, info)`` as produced by `_batch_gmres` or `_incremental_gmres`.
            For any other ``solve_method``, `_raise_value_error` is called with a
            message naming the offending value.
        """
        matvec = _normalize_matvec(A)
        precond = _normalize_matvec(M)
        method = self.solve_method
        # Both outputs are bound before branching, so they exist on every path,
        # including the final branch that only reports the error.
        # NOTE(review): presumably needed so graph mode sees them defined on all branches - confirm.
        x = x0
        info = _to_tensor(0)
        if method == "incremental":
            x, info = _incremental_gmres(matvec, x0, b, tol, atol, restart, maxiter, precond)
        elif method == 'batched':
            x, info = _batch_gmres(matvec, x0, b, tol, atol, restart, maxiter, precond)
        else:
            _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", method,
                               ".")
        return x, info
|
||||
|
||||
|
||||
def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
|
||||
M=None, callback=None, restrt=None, atol=0.0, callback_type=None, solve_method='batched'):
|
||||
"""
|
||||
|
@ -292,13 +321,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None,
|
|||
_value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo')
|
||||
if restart > size:
|
||||
restart = size
|
||||
|
||||
if solve_method == 'incremental':
|
||||
x, info = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
|
||||
elif solve_method == 'batched':
|
||||
x, info = BatchedGmres(A, M)(b, x0, tol, atol, restart, maxiter)
|
||||
if not is_within_graph(A):
|
||||
x, info = GMRES(A, M, solve_method)(b, x0, tol, atol, restart, maxiter)
|
||||
else:
|
||||
_raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", solve_method, ".")
|
||||
x, info = GMRESV2(solve_method)(A, b, x0, tol, atol, restart, maxiter, M)
|
||||
return x, info
|
||||
|
||||
|
||||
|
|
|
@ -43,6 +43,12 @@ def _fetch_preconditioner(preconditioner, a):
|
|||
return M
|
||||
|
||||
|
||||
def _is_valid_platform(tensor_type='Tensor'):
    """
    Tell whether a test case for ``tensor_type`` may run on the current platform.

    Args:
        tensor_type (str): Name of the tensor kind under test. Default: 'Tensor'.

    Returns:
        bool, False only when ``tensor_type`` is "CSRTensor" and `get_platform()`
        returns something other than "linux"; True in every other case.
    """
    # Single boolean expression replacing the if/return pair; `get_platform()` is
    # still evaluated only when ``tensor_type`` is "CSRTensor".
    return tensor_type != "CSRTensor" or get_platform() == "linux"
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
@ -58,7 +64,7 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite
|
|||
Description: test cases for cg using function way in pynative/graph mode
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
if tensor_type == "CSRTensor" and get_platform() != "linux":
|
||||
if not _is_valid_platform(tensor_type):
|
||||
return
|
||||
onp.random.seed(0)
|
||||
a = create_sym_pos_matrix(shape, dtype)
|
||||
|
@ -70,11 +76,11 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite
|
|||
b = Tensor(b)
|
||||
m = to_tensor((m, tensor_type)) if m is not None else m
|
||||
|
||||
# using PYNATIVE MODE
|
||||
# Using PYNATIVE MODE
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
|
||||
|
||||
# using GRAPH MODE
|
||||
# Using GRAPH MODE
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
|
||||
|
||||
|
@ -101,11 +107,11 @@ def test_cg_against_numpy(dtype, shape):
|
|||
b = onp.random.random(shape[:1]).astype(dtype)
|
||||
expected = onp.linalg.solve(a, b)
|
||||
|
||||
# using PYNATIVE MODE
|
||||
# Using PYNATIVE MODE
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
|
||||
|
||||
# using GRAPH MODE
|
||||
# Using GRAPH MODE
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
|
||||
|
||||
|
@ -146,11 +152,11 @@ def test_cg_against_scipy_graph(tensor_type, dtype, tol, shape, preconditioner,
|
|||
b = Tensor(b)
|
||||
m = to_tensor((m, tensor_type)) if m is not None else m
|
||||
|
||||
# using PYNATIVE MODE
|
||||
# Using PYNATIVE MODE
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
msp_res_dyn = Net()(a, b, m, maxiter, tol)
|
||||
|
||||
# using GRAPH MODE
|
||||
# Using GRAPH MODE
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
msp_res_sta = Net()(a, b, m, maxiter, tol)
|
||||
|
||||
|
@ -339,33 +345,39 @@ def test_gmres_against_scipy_level1(n, dtype, error, preconditioner, solve_metho
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [3, 7])
|
||||
@pytest.mark.parametrize('dtype,error', [(onp.float64, 1e-5), (onp.float32, 1e-4)])
|
||||
@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4),
|
||||
('CSRTensor', onp.float32, 1e-4)])
|
||||
@pytest.mark.parametrize('restart', [1, 2])
|
||||
@pytest.mark.parametrize('maxiter', [1, 2])
|
||||
@pytest.mark.parametrize('preconditioner', ['identity', 'exact', 'random'])
|
||||
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
|
||||
def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, solve_method):
|
||||
def test_gmres_against_scipy(n, tensor_type, dtype, error, restart, maxiter, preconditioner, solve_method):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for [N x N] X [N X 1]
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
if not _is_valid_platform(tensor_type):
|
||||
return
|
||||
onp.random.seed(0)
|
||||
a = create_full_rank_matrix((n, n), dtype)
|
||||
b = onp.random.rand(n).astype(dtype)
|
||||
x0 = onp.zeros_like(b).astype(dtype)
|
||||
M = _fetch_preconditioner(preconditioner, a)
|
||||
m = _fetch_preconditioner(preconditioner, a)
|
||||
tol = float(onp.finfo(dtype=dtype).eps)
|
||||
atol = tol
|
||||
if preconditioner == 'random':
|
||||
restart = n
|
||||
maxiter = None
|
||||
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
|
||||
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol)
|
||||
# PyNative Mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
M = Tensor(M) if M is not None else M
|
||||
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
|
||||
M=M, atol=atol, solve_method=solve_method)
|
||||
a = to_tensor((a, tensor_type))
|
||||
b = Tensor(b)
|
||||
x0 = Tensor(x0)
|
||||
m = to_tensor((m, tensor_type)) if m is not None else m
|
||||
ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart,
|
||||
maxiter=maxiter, M=m, atol=atol, solve_method=solve_method)
|
||||
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
|
||||
|
||||
|
||||
|
@ -374,32 +386,57 @@ def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner,
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('n', [3])
|
||||
@pytest.mark.parametrize('dtype,error', [(onp.float32, 1e-4)])
|
||||
@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4),
|
||||
('CSRTensor', onp.float32, 1e-4)])
|
||||
@pytest.mark.parametrize('preconditioner', ['random'])
|
||||
@pytest.mark.parametrize('solve_method', ['incremental', 'batched'])
|
||||
def test_gmres_against_graph_scipy(n, dtype, error, preconditioner, solve_method):
|
||||
def test_gmres_against_graph_scipy(n, tensor_type, dtype, error, preconditioner, solve_method):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for [N x N] X [N X 1]
|
||||
Expectation: the result match scipy in graph
|
||||
"""
|
||||
if not _is_valid_platform(tensor_type):
|
||||
return
|
||||
|
||||
# CSRTensor input to gmres is not supported in MindSpore graph mode, so skip this case.
|
||||
if tensor_type == "CSRTensor":
|
||||
return
|
||||
|
||||
class TestNet(nn.Cell):
|
||||
def __init__(self, solve_method):
|
||||
super(TestNet, self).__init__()
|
||||
self.solve_method = solve_method
|
||||
|
||||
def construct(self, a, b, x0, tol, restart, maxiter, m, atol):
|
||||
return msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m,
|
||||
atol=atol, solve_method=self.solve_method)
|
||||
|
||||
onp.random.seed(0)
|
||||
a = create_full_rank_matrix((n, n), dtype)
|
||||
b = onp.random.rand(n).astype(dtype)
|
||||
x0 = onp.zeros_like(b).astype(dtype)
|
||||
M = _fetch_preconditioner(preconditioner, a)
|
||||
m = _fetch_preconditioner(preconditioner, a)
|
||||
tol = float(onp.finfo(dtype=dtype).eps)
|
||||
atol = tol
|
||||
restart = n
|
||||
maxiter = None
|
||||
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol)
|
||||
scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol)
|
||||
# Graph Mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
M = Tensor(M) if M is not None else M
|
||||
ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter,
|
||||
M=M, atol=atol, solve_method=solve_method)
|
||||
a = to_tensor((a, tensor_type))
|
||||
b = Tensor(b)
|
||||
x0 = Tensor(x0)
|
||||
m = to_tensor((m, tensor_type)) if m is not None else m
|
||||
# Not in graph's construct
|
||||
ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter,
|
||||
M=m, atol=atol)
|
||||
assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error)
|
||||
|
||||
# Within graph's construct
|
||||
ms_net_output, _ = TestNet(solve_method)(a, b, x0, tol, restart, maxiter, m, atol)
|
||||
assert onp.allclose(scipy_output, ms_net_output.asnumpy(), rtol=error, atol=error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
Loading…
Reference in New Issue