!26159 scipy gmres batched interface

Merge pull request !26159 from zhuzhongrui/pub_master
i-robot 2021-11-17 01:47:21 +00:00 committed by Gitee
commit 2e49770c01
7 changed files with 190 additions and 93 deletions

View File

@@ -14,12 +14,14 @@
 # ============================================================================
 """Scipy-like interfaces in mindspore."""
-from . import optimize, linalg
-from .optimize import *
-from .linalg import *
+from . import optimize, sparse, linalg
+from .optimize import minimize, line_search
+from .sparse import cg, gmres
+from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve
 __all__ = []
 __all__.extend(optimize.__all__)
+__all__.extend(sparse.__all__)
 __all__.extend(linalg.__all__)
 __all__.sort()
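
With sparse now re-exported at the package level, the solvers are reachable from mindspore.scipy.sparse directly. A minimal usage sketch (assuming a MindSpore build that ships mindspore.scipy):

import numpy as onp
from mindspore.common import Tensor
# the new re-export makes this import path available
from mindspore.scipy.sparse import cg, gmres

A = Tensor(onp.array([[4.0, 1.0], [1.0, 3.0]], dtype=onp.float32))
b = Tensor(onp.array([1.0, 2.0], dtype=onp.float32))
x, info = cg(A, b)  # info == 0 indicates successful convergence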

View File

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Sparse linear algebra submodule"""
+"""Sparse submodule"""
+from . import linalg
 from .linalg import cg, gmres
 __all__ = ["cg", "gmres"]

View File

@@ -17,44 +17,36 @@ from ... import nn, Tensor, ops, ms_function
 from ... import numpy as mnp
 from ...ops import functional as F
 from ..linalg import solve_triangular
-from ..utils import _INT_ZERO, _normalize_matvec, _INT_ONE, _safe_normalize, _SafeNormalize
+from ..linalg import cho_factor, cho_solve
+from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
-class ArnoldiIteration(nn.Cell):
-    """ do the Arnoldi iteration"""
+def gram_schmidt(Q, q):
+    """
+    Do the Gram-Schmidt process to orthogonalize the vector q against the columns of Q.
+    """
+    # transpose does not support float64 yet,
+    # so the following code is the same as h = mnp.dot(Q.T, q)
+    h = ops.MatMul(True, False)(Q, q.reshape((q.shape[0], 1))).flatten()
+    Qh = mnp.dot(Q, h)
+    q = q - Qh
+    return q, h
-    def __init__(self):
-        super(ArnoldiIteration, self).__init__()
-        self.safe_normalize = _SafeNormalize()
-        self.eps = ops.Eps()
-        self.sqrt_2 = F.pows(Tensor(2.0), 1/2.0)
-        self.matmul_t = ops.MatMul(True, False)
-    def construct(self, k, V, v, H):
-        v, h = self._gram_schmidt(V, v)
-        eps_v = self.eps(v[(0,) * v.ndim])
-        _, v_norm_0 = self.safe_normalize(v)
-        tol = eps_v * v_norm_0
-        unit_v, v_norm_1 = self.safe_normalize(v, tol)
-        V[..., k + 1] = unit_v
-        h[k + 1] = v_norm_1
-        H[k, :] = h
-        return V, H
-    def _gram_schmidt(self, Q, q):
-        """
-        do GramSchmidt process to normalize vector v
-        """
-        # transpose is not support float64 yet,
-        # so the following code is the same as h = mnp.dot(Q.T, q)
-        h = self.matmul_t(Q, q.reshape((q.shape[0], 1))).flatten()
-        Qh = mnp.dot(Q, h)
-        q = q - Qh
-        return q, h
+def arnoldi_iteration(k, A, M, V, H):
+    """Performs a single (the k'th) step of the Arnoldi process."""
+    v_ = V[..., k]
+    v = M(A(v_))
+    v, h = gram_schmidt(V, v)
+    eps_v = _eps(v)
+    _, v_norm_0 = _safe_normalize(v)
+    tol = eps_v * v_norm_0
+    unit_v, v_norm_1 = _safe_normalize(v, tol)
+    V[..., k + 1] = unit_v
+    h[k + 1] = v_norm_1
+    H[k, :] = h
+    breakdown = v_norm_1 == 0
+    return V, H, breakdown
 @ms_function
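
For reference, the refactored free function is a standard Arnoldi step; a minimal NumPy sketch of the same update (illustrative only, with matvec standing in for the preconditioned operator M(A(.))):

import numpy as np

def arnoldi_step(k, matvec, V, H):
    # V: (n, restart + 1), orthonormal basis built one column at a time;
    # H: (restart, restart + 1), row k holds column k of the Hessenberg matrix
    q = matvec(V[:, k])              # expand the Krylov subspace
    h = V.T @ q                      # coefficients against the existing basis
    q = q - V @ h                    # Gram-Schmidt orthogonalization
    h[k + 1] = np.linalg.norm(q)
    breakdown = h[k + 1] == 0
    if not breakdown:
        V[:, k + 1] = q / h[k + 1]   # next orthonormal basis vector
    H[k, :] = h
    return V, H, breakdown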
@@ -110,10 +102,55 @@ class GivensRotation(nn.Cell):
         return R_row, givens
-kth_arnoldi_iteration = ArnoldiIteration()
 givens_rotation = GivensRotation()
+class BatchedGmres(nn.Cell):
+    """
+    Implements a single restart of GMRES. The ``restart``-dimensional Krylov
+    subspace is built in full before the least-squares problem is solved,
+    i.e. this implementation solves a dense linear problem instead of building
+    a QR factorization during the Arnoldi process.
+    """
+    def __init__(self, A, M):
+        super(BatchedGmres, self).__init__()
+        self.A = A
+        self.M = M
+    def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
+        A = _normalize_matvec(self.A)
+        M = _normalize_matvec(self.M)
+        dtype = b.dtype
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
+        residual = M(b - A(x0))
+        unit_residual, residual_norm = _safe_normalize(residual)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and residual_norm > atol:
+            pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),)
+            V = mnp.pad(unit_residual[..., None], pad_width=pad_width)
+            H = mnp.eye(restart, restart + 1, dtype=dtype)
+            k_iter = _INT_ZERO
+            breakdown = _to_tensor(False)
+            while k_iter < restart and mnp.logical_not(breakdown):
+                V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H)
+                k_iter += 1
+            beta_vec = mnp.zeros((restart + 1,), dtype=dtype)
+            beta_vec[0] = residual_norm
+            a2 = mnp.dot(H, H.T)
+            b2 = mnp.dot(H, beta_vec)
+            c, lower = cho_factor(a2, lower=False)
+            factor = (c, lower)
+            y = cho_solve(factor, b2)
+            dx = mnp.dot(V[..., :-1], y)
+            x = x + dx
+            residual = b - A(x)
+            unit_residual, residual_norm = _safe_normalize(residual)
+            k += 1
+        return x
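
The dense solve above is the normal-equations form of the GMRES least-squares problem: with Htilde = H.T the (restart + 1) x restart Hessenberg matrix, minimizing ||beta * e1 - Htilde @ y|| gives (H @ H.T) y = H @ (beta * e1), whose left side is symmetric positive definite, hence the Cholesky factor/solve. A plain NumPy sketch of that step (illustrative, not the class API):

import numpy as np

def gmres_correction(V, H, beta, restart):
    # least squares min_y ||beta * e1 - H.T @ y|| via the normal equations
    e1 = np.zeros(restart + 1)
    e1[0] = beta
    a2 = H @ H.T                      # SPD when H.T has full column rank
    b2 = H @ e1
    c = np.linalg.cholesky(a2)        # stands in for cho_factor(a2)
    y = np.linalg.solve(c.T, np.linalg.solve(c, b2))  # cho_solve equivalent
    return V[:, :-1] @ y              # dx, added to the current iterate x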
 def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     """
     Single iteration for Gmres Algorithm with restart
@@ -130,8 +167,7 @@ def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     k = 0
     err = r_norm
     while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
-        v = M_mat_func(A_mat_func(V[:, k]))
-        V, H = kth_arnoldi_iteration(k, V, v, R)
+        V, H, _ = arnoldi_iteration(k, A_mat_func, M_mat_func, V, R)
         R[k, :], givens = givens_rotation(H[k, :], givens, k)
         beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
         err = mnp.absolute(beta_vec[k + 1])
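
By contrast, the incremental path keeps a QR factorization of the Hessenberg matrix current: each Givens rotation zeroes the new subdiagonal entry, and applying the same rotation to beta_vec makes abs(beta_vec[k + 1]) the norm of the present residual, which is why err can be read off without forming x. A hedged sketch of that bookkeeping (the sign convention of the actual givens_rotation cell may differ):

import numpy as np

def make_givens(a, b):
    # rotation (c, s) chosen so that c*a - s*b = r and s*a + c*b = 0
    r = np.hypot(a, b)
    return (1.0, 0.0) if r == 0 else (a / r, -b / r)

def rotate_vector(v, k, c, s):
    vk, vk1 = v[k], v[k + 1]
    v[k] = c * vk - s * vk1      # the surviving diagonal entry
    v[k + 1] = s * vk + c * vk1  # zeroed in R; in beta_vec, the residual term
    return v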
@@ -186,7 +222,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
             Krylov space starting from the solution found at the last iteration. If GMRES
             halts or is very slow, decreasing this parameter may help.
             Default is infinite.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance.
@@ -206,67 +242,66 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         >>> import numpy as onp
         >>> from mindspore.common import Tensor
+        >>> import mindspore.numpy as mnp
-        >>> from mindspore.scipy.sparse import csc_matrix
         >>> from mindspore.scipy.sparse.linalg import gmres
-        >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=np.float32)
-        >>> b = Tensor(onp.array([2, 4, -1], dtype=np.float32))
+        >>> A = Tensor(mnp.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=mnp.float32))
+        >>> b = Tensor(onp.array([2, 4, -1], dtype=onp.float32))
         >>> x, exitCode = gmres(A, b)
         >>> print(exitCode) # 0 indicates successful convergence
         0
-        >>> np.allclose(A.matvec(x).asnumpy(), b.asnumpy())
+        >>> onp.allclose(mnp.dot(A, x).asnumpy(), b.asnumpy())
         True
     """
     if x0 is None:
         x0 = mnp.zeros_like(b)
-    if M is None:
-        def M_mat_func(x):
-            return x
-    elif not callable(M):
-        def M_mat_func(x):
-            return mnp.dot(M, x)
-    else:
-        M_mat_func = M
-    if not callable(A):
-        def A_mat_func(x):
-            return mnp.dot(A, x)
-    else:
-        A_mat_func = A
     size = b.size
     if maxiter is None:
         maxiter = 10 * size  # copied from scipy
-    restart = min(restart, size)
-    _, b_norm = _safe_normalize(b)
-    atol = mnp.maximum(tol * b_norm, atol)
-    Mb = M_mat_func(b)
-    _, Mb_norm = _safe_normalize(Mb)
-    ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
+    if restart > size:
+        restart = size
+    x = x0
     if solve_method == 'incremental':
+        if M is None:
+            def M_mat_func(x):
+                return x
+        elif not callable(M):
+            def M_mat_func(x):
+                return mnp.dot(M, x)
+        else:
+            M_mat_func = M
+        if not callable(A):
+            def A_mat_func(x):
+                return mnp.dot(A, x)
+        else:
+            A_mat_func = A
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, atol)
+        Mb = M_mat_func(b)
+        _, Mb_norm = _safe_normalize(Mb)
+        ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
         # iterative gmres
         r = M_mat_func(b - A_mat_func(x0))
         r, r_norm = _safe_normalize(r)
-        x = x0
         k = 0
         while k < maxiter and r_norm > atol:
             x, r, r_norm = gmres_iter(
                 A_mat_func, b, x, r, r_norm, ptol, restart, M_mat_func)
             k += 1
-        _, x_norm = _safe_normalize(x)
-        info = mnp.where(mnp.isnan(x_norm), -1, 0)
     elif solve_method == 'batched':
-        raise NotImplementedError("batched method not implemented yet")
+        if M is None:
+            def identity(x):
+                return x
+            M = identity
+        x = BatchedGmres(A, M)(b, x, tol, atol, restart, maxiter)
     else:
         raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
                          .format(solve_method))
+    _, x_norm = _safe_normalize(x)
+    info = mnp.where(mnp.isnan(x_norm), _INT_NEG_ONE, _INT_ZERO)
     return x, info
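
Putting the two paths together, a usage sketch based on the docstring example above (the batched path assumes a backend where it is supported, e.g. the GPU targets exercised by the new tests):

import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import gmres

A = Tensor(onp.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=onp.float32))
b = Tensor(onp.array([2, 4, -1], dtype=onp.float32))
x_inc, info = gmres(A, b, solve_method='incremental')  # Givens-rotation QR path
x_bat, info = gmres(A, b, solve_method='batched')      # dense normal-equations path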
@@ -344,7 +379,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None) -> (Tensor, N
             differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
         maxiter (int): Maximum number of iterations. Iteration will stop after maxiter
             steps even if the specified tolerance has not been achieved.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance.
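
Since M may now be a callable, a preconditioner need not be materialized as a dense tensor. A minimal sketch of a Jacobi (inverse-diagonal) preconditioner, assuming A has a nonzero diagonal:

import mindspore.numpy as mnp

def jacobi_preconditioner(A):
    inv_diag = 1.0 / mnp.diag(A)   # cheap elementwise approximation of inv(A)
    def M(x):
        return inv_diag * x        # no dense solve needed
    return M

# x, info = cg(A, b, M=jacobi_preconditioner(A))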

View File

@@ -15,7 +15,7 @@
 """internal utility functions"""
 import numpy as onp
 from .. import nn, ops
-from ..numpy import where, isnan, zeros_like, dot
+from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor
 from ..common import dtype as mstype
@@ -25,6 +25,7 @@ from ..ops.primitive import constexpr
 from .._c_expression import typing
 grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
+_eps_net = ops.Eps()
 def _convert_64_to_32(tensor):
@@ -69,13 +70,16 @@ def _to_scalar(arr):
     raise ValueError("{} are not supported.".format(type(arr)))
+def _eps(x):
+    return _eps_net(x[(0,) * x.ndim])
 class _SafeNormalize(nn.Cell):
     """Normalize method that casts very small results to zero."""
     def __init__(self):
         """Initialize _SafeNormalize."""
         super(_SafeNormalize, self).__init__()
-        self.eps = ops.Eps()
     def construct(self, x, threshold=None):
         x_sum2 = F.reduce_sum(F.pows(x, 2.0))
@@ -83,11 +87,13 @@ class _SafeNormalize(nn.Cell):
         if threshold is None:
             if x.dtype in mstype.float_type:
                 # pick the first element of x to get the eps
-                threshold = self.eps(x[(0,) * x.ndim])
+                threshold = _eps(x)
             else:
                 threshold = 0
-        normalized_x = where(norm > threshold, x / norm, zeros_like(x))
-        normalized_x = where(isnan(normalized_x), 0, normalized_x)
+        use_norm = greater(norm, threshold)
+        x_norm = x / norm
+        normalized_x = where(use_norm, x_norm, zeros_like(x))
+        norm = where(use_norm, norm, zeros_like(norm))
         return normalized_x, norm
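
The rewrite gates the division on greater(norm, threshold) instead of dividing first and patching NaNs afterwards, so isnan is no longer needed and the returned norm is also zeroed for negligible vectors. The same semantics in plain NumPy (illustrative; the Cell uses where() rather than Python branches so it stays graph-compatible):

import numpy as np

def safe_normalize(x, threshold=None):
    norm = np.sqrt(np.sum(x ** 2))
    if threshold is None:
        # mirror the eps-of-dtype default for floating inputs
        threshold = np.finfo(x.dtype).eps if np.issubdtype(x.dtype, np.floating) else 0
    if norm > threshold:
        return x / norm, norm
    return np.zeros_like(x), np.zeros_like(norm)  # treat tiny x as exactly zero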
@@ -95,6 +101,9 @@ _safe_normalize = _SafeNormalize()
 _INT_ZERO = _to_tensor(0)
 _INT_ONE = _to_tensor(1)
+_INT_NEG_ONE = _to_tensor(-1)
 _FLOAT_ONE = _to_tensor(1.0)
 _FLOAT_TWO = _to_tensor(2.0, dtype=float)
+_BOOL_TRUE = _to_tensor(True)
+_BOOL_FALSE = _to_tensor(False)

View File

@@ -70,9 +70,9 @@ def test_bfgs(dtype, func_x0):
     x0 = x0.astype(dtype)
     x0_tensor = Tensor(x0)
     ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
-                                   options=dict(maxiter=None, gtol=1e-6)).x
-    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
-    match_array(ms_res.asnumpy(), scipy_res, error=5)
+                                   options=dict(maxiter=None, gtol=1e-6))
+    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
+    match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
 @pytest.mark.level0

View File

@@ -16,13 +16,13 @@
 import pytest
 import numpy as onp
 import scipy as osp
 from scipy.sparse.linalg import cg as osp_cg
 import mindspore.scipy as msp
 from mindspore import context
 from mindspore.common import Tensor
 from mindspore.scipy.sparse.linalg import cg as msp_cg
-from .utils import create_sym_pos_matrix
+from .utils import create_sym_pos_matrix, create_full_rank_matrix
 onp.random.seed(0)
@@ -109,3 +109,53 @@ def test_cg_against_numpy(dtype, shape):
     kw = {"atol": 1e-5, "rtol": 1e-5}
     onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
     onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [4, 5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [1, 2])
+def test_pynative_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the results match scipy
+    """
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    M = _fetch_preconditioner(preconditioner, a)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    M = Tensor(M) if M is not None else M
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
+                                       solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)
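
_fetch_preconditioner is defined elsewhere in this test file and is not part of this diff; a plausible reconstruction of what the parametrized modes could produce (hypothetical body, shown only to make the test self-explanatory):

def _fetch_preconditioner(preconditioner, a):
    # hypothetical sketch: map each parametrize label to a dense M ~ inv(A)
    if preconditioner == 'identity':
        return onp.eye(a.shape[0], dtype=a.dtype)
    if preconditioner == 'exact':
        return onp.linalg.inv(a)     # the ideal (but expensive) preconditioner
    if preconditioner == 'random':
        return onp.random.rand(*a.shape).astype(a.dtype)
    return None                      # the preconditioner=None case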
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+def test_graph_batched_gmres_against_scipy(n, dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the results match scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, atol=0.0)
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, atol=0.0, solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)

View File

@@ -30,7 +30,7 @@ def to_tensor(obj, dtype=None):
     return res
-def match_array(actual, expected, error=0):
+def match_array(actual, expected, error=0, err_msg=''):
     if isinstance(actual, int):
         actual = onp.asarray(actual)
@@ -38,9 +38,9 @@ def match_array(actual, expected, error=0):
         expected = onp.asarray(expected)
     if error > 0:
-        onp.testing.assert_almost_equal(actual, expected, decimal=error)
+        onp.testing.assert_almost_equal(actual, expected, decimal=error, err_msg=err_msg)
     else:
-        onp.testing.assert_equal(actual, expected)
+        onp.testing.assert_equal(actual, expected, err_msg=err_msg)
 def create_full_rank_matrix(shape, dtype):