forked from mindspore-Ecosystem/mindspore
move test cases
This commit is contained in:
parent
85b11671dd
commit
3f2479aca0
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Sparse linear algebra submodule"""
|
"""Sparse linear algebra submodule"""
|
||||||
from ... import nn, Tensor, ops, ms_function
|
from ... import nn, Tensor, ms_function
|
||||||
from ... import numpy as mnp
|
from ... import numpy as mnp
|
||||||
from ...ops import functional as F
|
from ...ops import functional as F
|
||||||
from ..linalg import solve_triangular
|
from ..linalg import solve_triangular
|
||||||
|
@ -25,9 +25,7 @@ def gram_schmidt(Q, q):
|
||||||
"""
|
"""
|
||||||
do Gram–Schmidt process to normalize vector v
|
do Gram–Schmidt process to normalize vector v
|
||||||
"""
|
"""
|
||||||
# transpose is not support float64 yet,
|
h = mnp.dot(Q.T, q)
|
||||||
# so the following code is the same as h = mnp.dot(Q.T, q)
|
|
||||||
h = ops.MatMul(True, False)(Q, q.reshape((q.shape[0], 1))).flatten()
|
|
||||||
Qh = mnp.dot(Q, h)
|
Qh = mnp.dot(Q, h)
|
||||||
q = q - Qh
|
q = q - Qh
|
||||||
return q, h
|
return q, h
|
||||||
|
@ -38,9 +36,8 @@ def arnoldi_iteration(k, A, M, V, H):
|
||||||
v_ = V[..., k]
|
v_ = V[..., k]
|
||||||
v = M(A(v_))
|
v = M(A(v_))
|
||||||
v, h = gram_schmidt(V, v)
|
v, h = gram_schmidt(V, v)
|
||||||
eps_v = _eps(v)
|
|
||||||
_, v_norm_0 = _safe_normalize(v)
|
_, v_norm_0 = _safe_normalize(v)
|
||||||
tol = eps_v * v_norm_0
|
tol = _eps(v) * v_norm_0
|
||||||
unit_v, v_norm_1 = _safe_normalize(v, tol)
|
unit_v, v_norm_1 = _safe_normalize(v, tol)
|
||||||
V[..., k + 1] = unit_v
|
V[..., k + 1] = unit_v
|
||||||
h[k + 1] = v_norm_1
|
h[k + 1] = v_norm_1
|
||||||
|
@ -102,9 +99,6 @@ class GivensRotation(nn.Cell):
|
||||||
return R_row, givens
|
return R_row, givens
|
||||||
|
|
||||||
|
|
||||||
givens_rotation = GivensRotation()
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedGmres(nn.Cell):
|
class BatchedGmres(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
|
||||||
|
@ -151,35 +145,65 @@ class BatchedGmres(nn.Cell):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
|
class IterativeGmres(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Single iteration for Gmres Algorithm with restart
|
Implements a iterative GMRES. While building the ``restart``-dimensional
|
||||||
|
Krylov subspace iteratively using Givens Rotation method, the algorithm
|
||||||
|
constructs a Triangular matrix R which could be more easily solved.
|
||||||
"""
|
"""
|
||||||
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
|
|
||||||
dtype = mnp.result_type(b)
|
|
||||||
# use eye() to avoid constructing a singular matrix in case of early
|
|
||||||
# termination
|
|
||||||
R = mnp.eye(restart, restart + 1, dtype=dtype)
|
|
||||||
givens = mnp.zeros((restart, 2), dtype=dtype)
|
|
||||||
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
|
|
||||||
beta_vec[0] = r_norm
|
|
||||||
|
|
||||||
k = 0
|
def __init__(self, A, M):
|
||||||
err = r_norm
|
super(IterativeGmres, self).__init__()
|
||||||
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
self.A = A
|
||||||
V, H, _ = arnoldi_iteration(k, A_mat_func, M_mat_func, V, R)
|
self.M = M
|
||||||
R[k, :], givens = givens_rotation(H[k, :], givens, k)
|
self.givens_rotation = GivensRotation()
|
||||||
beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
|
|
||||||
err = mnp.absolute(beta_vec[k + 1])
|
|
||||||
k = k + 1
|
|
||||||
|
|
||||||
y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
def construct(self, b, x0, tol, atol, restart, maxiter):
|
||||||
dx = mnp.dot(V[:, :-1], y)
|
A = _normalize_matvec(self.A)
|
||||||
|
M = _normalize_matvec(self.M)
|
||||||
|
|
||||||
x = x0 + dx
|
_, b_norm = _safe_normalize(b)
|
||||||
r = M_mat_func(b - A_mat_func(x))
|
atol = mnp.maximum(tol * b_norm, atol)
|
||||||
r, r_norm = _safe_normalize(r)
|
|
||||||
return x, r, r_norm
|
Mb = M(b)
|
||||||
|
_, Mb_norm = _safe_normalize(Mb)
|
||||||
|
ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
|
||||||
|
|
||||||
|
r = M(b - A(x0))
|
||||||
|
r, r_norm = _safe_normalize(r)
|
||||||
|
|
||||||
|
iters = _INT_ZERO
|
||||||
|
while iters < maxiter and r_norm > atol:
|
||||||
|
V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),))
|
||||||
|
dtype = mnp.result_type(b)
|
||||||
|
# use eye() to avoid constructing a singular matrix in case of early
|
||||||
|
# termination
|
||||||
|
R = mnp.eye(restart, restart + 1, dtype=dtype)
|
||||||
|
givens = mnp.zeros((restart, 2), dtype=dtype)
|
||||||
|
beta_vec = mnp.zeros((restart + 1), dtype=dtype)
|
||||||
|
beta_vec[0] = r_norm
|
||||||
|
|
||||||
|
k = _INT_ZERO
|
||||||
|
err = r_norm
|
||||||
|
while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
|
||||||
|
V, H, _ = arnoldi_iteration(k, A, M, V, R)
|
||||||
|
R[k, :], givens = self.givens_rotation(H[k, :], givens, k)
|
||||||
|
beta_vec = rotate_vectors(
|
||||||
|
beta_vec, k, givens[k, 0], givens[k, 1])
|
||||||
|
err = mnp.absolute(beta_vec[k + 1])
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
y = solve_triangular(
|
||||||
|
R[:, :-1], beta_vec[:-1], trans='T', lower=True)
|
||||||
|
dx = mnp.dot(V[:, :-1], y)
|
||||||
|
|
||||||
|
x = x0 + dx
|
||||||
|
r = M(b - A(x))
|
||||||
|
r, r_norm = _safe_normalize(r)
|
||||||
|
x0 = x
|
||||||
|
iters += 1
|
||||||
|
|
||||||
|
return x0
|
||||||
|
|
||||||
|
|
||||||
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||||
|
@ -259,44 +283,17 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
|
||||||
maxiter = 10 * size # copied from scipy
|
maxiter = 10 * size # copied from scipy
|
||||||
if restart > size:
|
if restart > size:
|
||||||
restart = size
|
restart = size
|
||||||
x = x0
|
|
||||||
|
if M is None:
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
M = identity
|
||||||
|
|
||||||
if solve_method == 'incremental':
|
if solve_method == 'incremental':
|
||||||
if M is None:
|
x = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter)
|
||||||
def M_mat_func(x):
|
|
||||||
return x
|
|
||||||
elif not callable(M):
|
|
||||||
def M_mat_func(x):
|
|
||||||
return mnp.dot(M, x)
|
|
||||||
else:
|
|
||||||
M_mat_func = M
|
|
||||||
|
|
||||||
if not callable(A):
|
|
||||||
def A_mat_func(x):
|
|
||||||
return mnp.dot(A, x)
|
|
||||||
else:
|
|
||||||
A_mat_func = A
|
|
||||||
|
|
||||||
_, b_norm = _safe_normalize(b)
|
|
||||||
atol = mnp.maximum(tol * b_norm, atol)
|
|
||||||
|
|
||||||
Mb = M_mat_func(b)
|
|
||||||
_, Mb_norm = _safe_normalize(Mb)
|
|
||||||
ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
|
|
||||||
# iterative gmres
|
|
||||||
r = M_mat_func(b - A_mat_func(x0))
|
|
||||||
r, r_norm = _safe_normalize(r)
|
|
||||||
k = 0
|
|
||||||
while k < maxiter and r_norm > atol:
|
|
||||||
x, r, r_norm = gmres_iter(
|
|
||||||
A_mat_func, b, x, r, r_norm, ptol, restart, M_mat_func)
|
|
||||||
k += 1
|
|
||||||
elif solve_method == 'batched':
|
elif solve_method == 'batched':
|
||||||
if M is None:
|
x = BatchedGmres(A, M)(b, x0, tol, atol, restart, maxiter)
|
||||||
def identity(x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
M = identity
|
|
||||||
x = BatchedGmres(A, M)(b, x, tol, atol, restart, maxiter)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
|
raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
|
||||||
.format(solve_method))
|
.format(solve_method))
|
||||||
|
|
|
@ -26,13 +26,6 @@ from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matr
|
||||||
onp.random.seed(0)
|
onp.random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
def gmres_compare_with_scipy(A, b, x):
|
|
||||||
gmres_x, _ = msp.sparse.linalg.gmres(Tensor(A), Tensor(b), Tensor(
|
|
||||||
x), tol=1e-07, atol=0, solve_method='incremental')
|
|
||||||
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
|
|
||||||
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=5)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_preconditioner(preconditioner, A):
|
def _fetch_preconditioner(preconditioner, A):
|
||||||
"""
|
"""
|
||||||
Returns one of various preconditioning matrices depending on the identifier
|
Returns one of various preconditioning matrices depending on the identifier
|
||||||
|
@ -117,12 +110,21 @@ def test_cg_against_numpy(dtype, shape):
|
||||||
onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
|
onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def gmres_compare_with_scipy_incremental(A, b, x, M):
|
||||||
|
gmres_x, _ = msp.sparse.linalg.gmres(Tensor(A), Tensor(b), Tensor(
|
||||||
|
x), tol=1e-07, atol=0, solve_method='incremental')
|
||||||
|
scipy_x, _ = osp.sparse.linalg.gmres(A, b, x, tol=1e-07, atol=0)
|
||||||
|
onp.testing.assert_almost_equal(scipy_x, gmres_x.asnumpy(), decimal=5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.parametrize('n', [5])
|
@pytest.mark.parametrize('n', [5])
|
||||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
@pytest.mark.parametrize('dtype', [onp.float64])
|
||||||
def test_gmres_incremental_against_scipy_cpu(n, dtype):
|
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
|
||||||
|
def test_gmres_incremental_against_scipy(n, dtype, preconditioner):
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: ALL TO ALL
|
||||||
Description: test cases for [N x N] X [N X 1]
|
Description: test cases for [N x N] X [N X 1]
|
||||||
|
@ -130,27 +132,10 @@ def test_gmres_incremental_against_scipy_cpu(n, dtype):
|
||||||
"""
|
"""
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
# add Identity matrix to make matrix A non-singular
|
# add Identity matrix to make matrix A non-singular
|
||||||
A = onp.random.rand(n, n).astype(dtype)
|
A = create_full_rank_matrix((n, n), dtype)
|
||||||
b = onp.random.rand(n).astype(dtype)
|
b = onp.random.rand(n).astype(dtype)
|
||||||
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
|
M = _fetch_preconditioner(preconditioner, A)
|
||||||
|
gmres_compare_with_scipy_incremental(A, b, onp.zeros_like(b).astype(dtype), M)
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_cpu
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
@pytest.mark.parametrize('n', [5])
|
|
||||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
|
||||||
def test_gmres_incremental_against_scipy_cpu_graph(n, dtype):
|
|
||||||
"""
|
|
||||||
Feature: ALL TO ALL
|
|
||||||
Description: test cases for [N x N] X [N X 1]
|
|
||||||
Expectation: the result match scipy
|
|
||||||
"""
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
# add Identity matrix to make matrix A non-singular
|
|
||||||
A = onp.random.rand(n, n).astype(dtype)
|
|
||||||
b = onp.random.rand(n).astype(dtype)
|
|
||||||
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -158,25 +143,8 @@ def test_gmres_incremental_against_scipy_cpu_graph(n, dtype):
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.parametrize('n', [5])
|
@pytest.mark.parametrize('n', [5])
|
||||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
@pytest.mark.parametrize('dtype', [onp.float64])
|
||||||
def test_gmres_incremental_against_scipy_gpu(n, dtype):
|
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
|
||||||
"""
|
def test_gmres_incremental_against_scipy_graph(n, dtype, preconditioner):
|
||||||
Feature: ALL TO ALL
|
|
||||||
Description: test cases for [N x N] X [N X 1]
|
|
||||||
Expectation: the result match scipy
|
|
||||||
"""
|
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
|
||||||
# add Identity matrix to make matrix A non-singular
|
|
||||||
A = onp.random.rand(n, n).astype(dtype)
|
|
||||||
b = onp.random.rand(n).astype(dtype)
|
|
||||||
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
@pytest.mark.parametrize('n', [5])
|
|
||||||
@pytest.mark.parametrize('dtype', [onp.float64])
|
|
||||||
def test_gmres_incremental_against_scipy_gpu_graph(n, dtype):
|
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: ALL TO ALL
|
||||||
Description: test cases for [N x N] X [N X 1]
|
Description: test cases for [N x N] X [N X 1]
|
||||||
|
@ -184,9 +152,10 @@ def test_gmres_incremental_against_scipy_gpu_graph(n, dtype):
|
||||||
"""
|
"""
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
# add Identity matrix to make matrix A non-singular
|
# add Identity matrix to make matrix A non-singular
|
||||||
A = onp.random.rand(n, n).astype(dtype)
|
A = create_full_rank_matrix((n, n), dtype)
|
||||||
b = onp.random.rand(n).astype(dtype)
|
b = onp.random.rand(n).astype(dtype)
|
||||||
gmres_compare_with_scipy(A, b, onp.zeros_like(b).astype(dtype))
|
M = _fetch_preconditioner(preconditioner, A)
|
||||||
|
gmres_compare_with_scipy_incremental(A, b, onp.zeros_like(b).astype(dtype), M)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
|
Loading…
Reference in New Issue