move test cases

This commit is contained in:
zhujingxuan 2021-11-19 15:10:21 +08:00
parent 85b11671dd
commit 3f2479aca0
2 changed files with 84 additions and 118 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Sparse linear algebra submodule""" """Sparse linear algebra submodule"""
from ... import nn, Tensor, ops, ms_function from ... import nn, Tensor, ms_function
from ... import numpy as mnp from ... import numpy as mnp
from ...ops import functional as F from ...ops import functional as F
from ..linalg import solve_triangular from ..linalg import solve_triangular
@ -25,9 +25,7 @@ def gram_schmidt(Q, q):
""" """
do GramSchmidt process to normalize vector v do GramSchmidt process to normalize vector v
""" """
# transpose is not support float64 yet, h = mnp.dot(Q.T, q)
# so the following code is the same as h = mnp.dot(Q.T, q)
h = ops.MatMul(True, False)(Q, q.reshape((q.shape[0], 1))).flatten()
Qh = mnp.dot(Q, h) Qh = mnp.dot(Q, h)
q = q - Qh q = q - Qh
return q, h return q, h
@ -38,9 +36,8 @@ def arnoldi_iteration(k, A, M, V, H):
v_ = V[..., k] v_ = V[..., k]
v = M(A(v_)) v = M(A(v_))
v, h = gram_schmidt(V, v) v, h = gram_schmidt(V, v)
eps_v = _eps(v)
_, v_norm_0 = _safe_normalize(v) _, v_norm_0 = _safe_normalize(v)
tol = eps_v * v_norm_0 tol = _eps(v) * v_norm_0
unit_v, v_norm_1 = _safe_normalize(v, tol) unit_v, v_norm_1 = _safe_normalize(v, tol)
V[..., k + 1] = unit_v V[..., k + 1] = unit_v
h[k + 1] = v_norm_1 h[k + 1] = v_norm_1
@ -102,9 +99,6 @@ class GivensRotation(nn.Cell):
return R_row, givens return R_row, givens
givens_rotation = GivensRotation()
class BatchedGmres(nn.Cell): class BatchedGmres(nn.Cell):
""" """
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
@ -151,35 +145,65 @@ class BatchedGmres(nn.Cell):
return x return x
def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func): class IterativeGmres(nn.Cell):
""" """
Single iteration for Gmres Algorithm with restart Implements a iterative GMRES. While building the ``restart``-dimensional
Krylov subspace iteratively using Givens Rotation method, the algorithm
constructs a Triangular matrix R which could be more easily solved.
""" """
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
dtype = mnp.result_type(b)
# use eye() to avoid constructing a singular matrix in case of early
# termination
R = mnp.eye(restart, restart + 1, dtype=dtype)
givens = mnp.zeros((restart, 2), dtype=dtype)
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
beta_vec[0] = r_norm
k = 0 def __init__(self, A, M):
err = r_norm super(IterativeGmres, self).__init__()
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)): self.A = A
V, H, _ = arnoldi_iteration(k, A_mat_func, M_mat_func, V, R) self.M = M
R[k, :], givens = givens_rotation(H[k, :], givens, k) self.givens_rotation = GivensRotation()
beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
err = mnp.absolute(beta_vec[k + 1])
k = k + 1
y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True) def construct(self, b, x0, tol, atol, restart, maxiter):
dx = mnp.dot(V[:, :-1], y) A = _normalize_matvec(self.A)
M = _normalize_matvec(self.M)
x = x0 + dx _, b_norm = _safe_normalize(b)
r = M_mat_func(b - A_mat_func(x)) atol = mnp.maximum(tol * b_norm, atol)
r, r_norm = _safe_normalize(r)
return x, r, r_norm Mb = M(b)
_, Mb_norm = _safe_normalize(Mb)
ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
r = M(b - A(x0))
r, r_norm = _safe_normalize(r)
iters = _INT_ZERO
while iters < maxiter and r_norm > atol:
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
dtype = mnp.result_type(b)
# use eye() to avoid constructing a singular matrix in case of early
# termination
R = mnp.eye(restart, restart + 1, dtype=dtype)
givens = mnp.zeros((restart, 2), dtype=dtype)
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
beta_vec[0] = r_norm
k = _INT_ZERO
err = r_norm
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
V, H, _ = arnoldi_iteration(k, A, M, V, R)
R[k, :], givens = self.givens_rotation(H[k, :], givens, k)
beta_vec = rotate_vectors(
beta_vec, k, givens[k, 0], givens[k, 1])
err = mnp.absolute(beta_vec[k + 1])
k += 1
y = solve_triangular(
R[:, :-1], beta_vec[:-1], trans='T', lower=True)
dx = mnp.dot(V[:, :-1], y)
x = x0 + dx
r = M(b - A(x))
r, r_norm = _safe_normalize(r)
x0 = x
iters += 1
return x0
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None, def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
@ -259,44 +283,17 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
maxiter = 10 * size # copied from scipy maxiter = 10 * size # copied from scipy
if restart > size: if restart > size:
restart = size restart = size
x = x0
if M is None:
def identity(x):
return x
M = identity
if solve_method == 'incremental': if solve_method == 'incremental':
if M is None: x = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
def M_mat_func(x):
return x
elif not callable(M):
def M_mat_func(x):
return mnp.dot(M, x)
else:
M_mat_func = M
if not callable(A):
def A_mat_func(x):
return mnp.dot(A, x)
else:
A_mat_func = A
_, b_norm = _safe_normalize(b)
atol = mnp.maximum(tol * b_norm, atol)
Mb = M_mat_func(b)
_, Mb_norm = _safe_normalize(Mb)
ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
# iterative gmres
r = M_mat_func(b - A_mat_func(x0))
r, r_norm = _safe_normalize(r)
k = 0
while k < maxiter and r_norm > atol:
x, r, r_norm = gmres_iter(
A_mat_func, b, x, r, r_norm, ptol, restart, M_mat_func)
k += 1
elif solve_method == 'batched': elif solve_method == 'batched':
if M is None: x = BatchedGmres(A, M)(b, x0, tol, atol, restart, maxiter)
def identity(x):
return x
M = identity
x = BatchedGmres(A, M)(b, x, tol, atol, restart, maxiter)
else: else:
raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}." raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
.format(solve_method)) .format(solve_method))

View File

@ -26,13 +26,6 @@ from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matr
onp.random.seed(0) onp.random.seed(0)
def gmres_compare_with_scipy(A, b, x):
gmres_x, _ = msp.sparse.linalg.gmres(Tensor(A), Tensor(b), Tensor(
x), tol=1e-07, atol=0, solve_method='incremental')
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=5)
def _fetch_preconditioner(preconditioner, A): def _fetch_preconditioner(preconditioner, A):
""" """
Returns one of various preconditioning matrices depending on the identifier Returns one of various preconditioning matrices depending on the identifier
@ -117,12 +110,21 @@ def test_cg_against_numpy(dtype, shape):
onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw) onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
def gmres_compare_with_scipy_incremental(A, b, x, M):
gmres_x, _ = msp.sparse.linalg.gmres(Tensor(A), Tensor(b), Tensor(
x), tol=1e-07, atol=0, solve_method='incremental')
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=5)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5]) @pytest.mark.parametrize('n', [5])
@pytest.mark.parametrize('dtype', [onp.float64]) @pytest.mark.parametrize('dtype', [onp.float64])
def test_gmres_incremental_against_scipy_cpu(n, dtype): @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
def test_gmres_incremental_against_scipy(n, dtype, preconditioner):
""" """
Feature: ALL TO ALL Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1] Description: test cases for [N x N] X [N X 1]
@ -130,27 +132,10 @@ def test_gmres_incremental_against_scipy_cpu(n, dtype):
""" """
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
# add Identity matrix to make matrix A non-singular # add Identity matrix to make matrix A non-singular
A = onp.random.rand(n, n).astype(dtype) A = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype) b = onp.random.rand(n).astype(dtype)
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype)) M = _fetch_preconditioner(preconditioner, A)
gmres_compare_with_scipy_incremental(A, b, onp.zeros_like(b).astype(dtype), M)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_gmres_incremental_against_scipy_cpu_graph(n, dtype):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result match scipy
"""
context.set_context(mode=context.GRAPH_MODE)
# add Identity matrix to make matrix A non-singular
A = onp.random.rand(n, n).astype(dtype)
b = onp.random.rand(n).astype(dtype)
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
@pytest.mark.level0 @pytest.mark.level0
@ -158,25 +143,8 @@ def test_gmres_incremental_against_scipy_cpu_graph(n, dtype):
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5]) @pytest.mark.parametrize('n', [5])
@pytest.mark.parametrize('dtype', [onp.float64]) @pytest.mark.parametrize('dtype', [onp.float64])
def test_gmres_incremental_against_scipy_gpu(n, dtype): @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
""" def test_gmres_incremental_against_scipy_graph(n, dtype, preconditioner):
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
Expectation: the result match scipy
"""
context.set_context(mode=context.PYNATIVE_MODE)
# add Identity matrix to make matrix A non-singular
A = onp.random.rand(n, n).astype(dtype)
b = onp.random.rand(n).astype(dtype)
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [5])
@pytest.mark.parametrize('dtype', [onp.float64])
def test_gmres_incremental_against_scipy_gpu_graph(n, dtype):
""" """
Feature: ALL TO ALL Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1] Description: test cases for [N x N] X [N X 1]
@ -184,9 +152,10 @@ def test_gmres_incremental_against_scipy_gpu_graph(n, dtype):
""" """
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
# add Identity matrix to make matrix A non-singular # add Identity matrix to make matrix A non-singular
A = onp.random.rand(n, n).astype(dtype) A = create_full_rank_matrix((n, n), dtype)
b = onp.random.rand(n).astype(dtype) b = onp.random.rand(n).astype(dtype)
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype)) M = _fetch_preconditioner(preconditioner, A)
gmres_compare_with_scipy_incremental(A, b, onp.zeros_like(b).astype(dtype), M)
@pytest.mark.level0 @pytest.mark.level0