forked from mindspore-Ecosystem/mindspore
!26159 scipy gmres batched interface
Merge pull request !26159 from zhuzhongrui/pub_master
Commit 2e49770c01
@@ -14,12 +14,14 @@
 # ============================================================================
 """Scipy-like interfaces in mindspore."""
 
-from . import optimize, linalg
-from .optimize import *
-from .linalg import *
+from . import optimize, sparse, linalg
+from .optimize import minimize, line_search
+from .sparse import cg, gmres
+from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve
 
 __all__ = []
 __all__.extend(optimize.__all__)
+__all__.extend(sparse.__all__)
 __all__.extend(linalg.__all__)
 
 __all__.sort()
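With this change, `cg` and `gmres` are re-exported at the package root alongside the existing `optimize` and `linalg` names. A quick sanity check of the new surface (a sketch; assumes a MindSpore build that includes this commit):

```python
# Sketch: the re-exports added by this commit.
import mindspore.scipy as msp

assert "cg" in msp.__all__ and "gmres" in msp.__all__
# Root-level names resolve to the same objects as the submodule names.
assert msp.gmres is msp.sparse.linalg.gmres
assert msp.cg is msp.sparse.cg
```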
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Sparse linear algebra submodule"""
+"""Sparse submodule"""
 from . import linalg
 from .linalg import cg, gmres
 
 __all__ = ["cg", "gmres"]
@@ -17,44 +17,36 @@ from ... import nn, Tensor, ops, ms_function
 from ... import numpy as mnp
 from ...ops import functional as F
 from ..linalg import solve_triangular
-
-from ..utils import _INT_ZERO, _normalize_matvec, _INT_ONE, _safe_normalize, _SafeNormalize
+from ..linalg import cho_factor, cho_solve
+from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
 
 
-class ArnoldiIteration(nn.Cell):
-    """ do the Arnoldi iteration"""
-    def __init__(self):
-        super(ArnoldiIteration, self).__init__()
-        self.safe_normalize = _SafeNormalize()
-        self.eps = ops.Eps()
-        self.sqrt_2 = F.pows(Tensor(2.0), 1/2.0)
-        self.matmul_t = ops.MatMul(True, False)
-
-    def construct(self, k, V, v, H):
-        v, h = self._gram_schmidt(V, v)
-
-        eps_v = self.eps(v[(0,) * v.ndim])
-        _, v_norm_0 = self.safe_normalize(v)
-        tol = eps_v * v_norm_0
-        unit_v, v_norm_1 = self.safe_normalize(v, tol)
-        V[..., k + 1] = unit_v
-
-        h[k + 1] = v_norm_1
-        H[k, :] = h
-        return V, H
-
-    def _gram_schmidt(self, Q, q):
-        """
-        do Gram–Schmidt process to normalize vector v
-        """
-        # transpose is not support float64 yet,
-        # so the following code is the same as h = mnp.dot(Q.T, q)
-        h = self.matmul_t(Q, q.reshape((q.shape[0], 1))).flatten()
-        Qh = mnp.dot(Q, h)
-        q = q - Qh
-
-        return q, h
+def gram_schmidt(Q, q):
+    """
+    Do the Gram-Schmidt process to orthogonalize q against the columns of Q.
+    """
+    # transpose does not support float64 yet,
+    # so the following code is equivalent to h = mnp.dot(Q.T, q)
+    h = ops.MatMul(True, False)(Q, q.reshape((q.shape[0], 1))).flatten()
+    Qh = mnp.dot(Q, h)
+    q = q - Qh
+    return q, h
+
+
+def arnoldi_iteration(k, A, M, V, H):
+    """Performs a single (the k'th) step of the Arnoldi process."""
+    v_ = V[..., k]
+    v = M(A(v_))
+    v, h = gram_schmidt(V, v)
+    eps_v = _eps(v)
+    _, v_norm_0 = _safe_normalize(v)
+    tol = eps_v * v_norm_0
+    unit_v, v_norm_1 = _safe_normalize(v, tol)
+    V[..., k + 1] = unit_v
+    h[k + 1] = v_norm_1
+    H[k, :] = h
+    breakdown = v_norm_1 == 0
+    return V, H, breakdown
 
 
 @ms_function
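For orientation, the free function `arnoldi_iteration` replaces the `ArnoldiIteration` cell so the same code serves both the incremental and batched paths. A minimal NumPy analogue of one step (my sketch, not the MindSpore implementation): it orthogonalizes `A @ V[:, k]` against the current basis, appends the normalized remainder as column `k + 1`, and records the coefficients in row `k` of `H`.

```python
import numpy as np

def arnoldi_step(k, matvec, V, H, eps=1e-12):
    """One Arnoldi step on a basis V (n x (restart+1)) and H (restart x (restart+1))."""
    v = matvec(V[:, k])
    h = V.T @ v                # projection coefficients onto the current basis
    v = v - V @ h              # Gram-Schmidt: subtract those components
    v_norm = np.linalg.norm(v)
    breakdown = v_norm < eps   # the new vector (almost) lies in span(V)
    V[:, k + 1] = 0.0 if breakdown else v / v_norm
    h[k + 1] = v_norm
    H[k, :] = h
    return V, H, breakdown

# Example: n = 4, restart = 3.
A = np.diag([1.0, 2.0, 3.0, 4.0])
V, H = np.zeros((4, 4)), np.zeros((3, 4))
V[:, 0] = 0.5                  # unit-norm starting vector
V, H, _ = arnoldi_step(0, lambda x: A @ x, V, H)
```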
@@ -110,10 +102,55 @@ class GivensRotation(nn.Cell):
         return R_row, givens
 
 
-kth_arnoldi_iteration = ArnoldiIteration()
 givens_rotation = GivensRotation()
 
 
+class BatchedGmres(nn.Cell):
+    """
+    Implements a single restart of GMRES. The ``restart``-dimensional Krylov
+    subspace is built explicitly, and this implementation solves a dense linear
+    problem instead of building a QR factorization during the Arnoldi process.
+    """
+
+    def __init__(self, A, M):
+        super(BatchedGmres, self).__init__()
+        self.A = A
+        self.M = M
+
+    def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
+        A = _normalize_matvec(self.A)
+        M = _normalize_matvec(self.M)
+        dtype = b.dtype
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
+        residual = M(b - A(x0))
+        unit_residual, residual_norm = _safe_normalize(residual)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and residual_norm > atol:
+            pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),)
+            V = mnp.pad(unit_residual[..., None], pad_width=pad_width)
+            H = mnp.eye(restart, restart + 1, dtype=dtype)
+            k_iter = _INT_ZERO
+            breakdown = _to_tensor(False)
+            while k_iter < restart and mnp.logical_not(breakdown):
+                V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H)
+                k_iter += 1
+            beta_vec = mnp.zeros((restart + 1,), dtype=dtype)
+            beta_vec[0] = residual_norm
+            a2 = mnp.dot(H, H.T)
+            b2 = mnp.dot(H, beta_vec)
+            c, lower = cho_factor(a2, lower=False)
+            factor = (c, lower)
+            y = cho_solve(factor, b2)
+            dx = mnp.dot(V[..., :-1], y)
+            x = x + dx
+            residual = b - A(x)
+            unit_residual, residual_norm = _safe_normalize(residual)
+            k += 1
+        return x
+
+
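The dense solve inside `construct` is the GMRES least-squares step: minimize ||beta_vec - H.T @ y|| over the Krylov coefficients y. Because H (restart x (restart+1)) has full row rank, the minimizer satisfies the normal equations (H @ H.T) y = H @ beta_vec, and H @ H.T is symmetric positive definite, which is why a Cholesky factorization applies. A NumPy/SciPy sketch of just that step (my illustration; the function name is mine):

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def lstsq_via_cholesky(H, beta_vec):
    """Solve min_y ||H.T @ y - beta_vec|| through the normal equations."""
    a2 = H @ H.T       # (restart, restart), symmetric positive definite
    b2 = H @ beta_vec  # right-hand side of the normal equations
    return cho_solve(cho_factor(a2, lower=False), b2)

# beta_vec = ||r0|| * e1, as in the construct method above.
H = np.array([[2., 1., 0., 0.], [0., 2., 1., 0.], [0., 0., 2., 1.]])
y = lstsq_via_cholesky(H, np.array([3., 0., 0., 0.]))
```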
 def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     """
     Single iteration for Gmres Algorithm with restart
@@ -130,8 +167,7 @@ def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     k = 0
     err = r_norm
     while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
-        v = M_mat_func(A_mat_func(V[:, k]))
-        V, H = kth_arnoldi_iteration(k, V, v, R)
+        V, H, _ = arnoldi_iteration(k, A_mat_func, M_mat_func, V, R)
         R[k, :], givens = givens_rotation(H[k, :], givens, k)
         beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
         err = mnp.absolute(beta_vec[k + 1])
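In the incremental path the Hessenberg matrix is reduced one column at a time with Givens rotations, so `mnp.absolute(beta_vec[k + 1])` is the current residual norm and no dense solve is needed inside the loop. A NumPy sketch of the 2x2 rotation in one common sign convention (my illustration; the real `givens_rotation` and `rotate_vectors` live elsewhere in this file):

```python
import numpy as np

def givens(a, b):
    """Return (c, s) such that the rotation zeroes b in the pair [a, b]."""
    r = np.hypot(a, b)
    return a / r, -b / r

def rotate_vector(v, k, c, s):
    """Apply the rotation to entries k and k+1 of v."""
    v = v.copy()
    v[k], v[k + 1] = c * v[k] - s * v[k + 1], s * v[k] + c * v[k + 1]
    return v

v = np.array([3.0, 4.0])
c, s = givens(v[0], v[1])
assert np.allclose(rotate_vector(v, 0, c, s), [5.0, 0.0])
```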
@@ -186,7 +222,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
             Krylov space starting from the solution found at the last iteration. If GMRES
             halts or is very slow, decreasing this parameter may help.
             Default is infinite.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance.
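Since `M` may now be a Tensor or a function, a simple Jacobi (inverse-diagonal) preconditioner can be passed either way. A hedged usage sketch (the matrix and the `m` helper are mine, not part of the API):

```python
import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import gmres

a = onp.array([[4, 1, 0], [1, 5, 2], [0, 2, 6]], dtype=onp.float32)
b = onp.array([1, 2, 3], dtype=onp.float32)
# Jacobi preconditioner: M approximates inv(A) by inverting diag(A).
m = onp.diag(1.0 / onp.diag(a)).astype(onp.float32)

x, info = gmres(Tensor(a), Tensor(b), M=Tensor(m))
```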
@@ -206,67 +242,66 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         >>> import numpy as onp
         >>> from mindspore.common import Tensor
-        >>> from mindspore.scipy.sparse import csc_matrix
+        >>> import mindspore.numpy as mnp
         >>> from mindspore.scipy.sparse.linalg import gmres
-        >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=np.float32)
-        >>> b = Tensor(onp.array([2, 4, -1], dtype=np.float32))
+        >>> A = Tensor(mnp.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=mnp.float32))
+        >>> b = Tensor(onp.array([2, 4, -1], dtype=onp.float32))
         >>> x, exitCode = gmres(A, b)
         >>> print(exitCode) # 0 indicates successful convergence
         0
-        >>> np.allclose(A.matvec(x).asnumpy(), b.asnumpy())
+        >>> onp.allclose(mnp.dot(A, x).asnumpy(), b.asnumpy())
         True
     """
 
     if x0 is None:
         x0 = mnp.zeros_like(b)
 
-    if M is None:
-        def M_mat_func(x):
-            return x
-    elif not callable(M):
-        def M_mat_func(x):
-            return mnp.dot(M, x)
-    else:
-        M_mat_func = M
-
-    if not callable(A):
-        def A_mat_func(x):
-            return mnp.dot(A, x)
-    else:
-        A_mat_func = A
-
     size = b.size
-
     if maxiter is None:
         maxiter = 10 * size  # copied from scipy
-    restart = min(restart, size)
-
-    _, b_norm = _safe_normalize(b)
-    atol = mnp.maximum(tol * b_norm, atol)
-
-    Mb = M_mat_func(b)
-    _, Mb_norm = _safe_normalize(Mb)
-    ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
-
+    if restart > size:
+        restart = size
+    x = x0
     if solve_method == 'incremental':
+        if M is None:
+            def M_mat_func(x):
+                return x
+        elif not callable(M):
+            def M_mat_func(x):
+                return mnp.dot(M, x)
+        else:
+            M_mat_func = M
+
+        if not callable(A):
+            def A_mat_func(x):
+                return mnp.dot(A, x)
+        else:
+            A_mat_func = A
+
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, atol)
+
+        Mb = M_mat_func(b)
+        _, Mb_norm = _safe_normalize(Mb)
+        ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
         # iterative gmres
         r = M_mat_func(b - A_mat_func(x0))
         r, r_norm = _safe_normalize(r)
-        x = x0
         k = 0
         while k < maxiter and r_norm > atol:
             x, r, r_norm = gmres_iter(
                 A_mat_func, b, x, r, r_norm, ptol, restart, M_mat_func)
             k += 1
-
-        _, x_norm = _safe_normalize(x)
-        info = mnp.where(mnp.isnan(x_norm), -1, 0)
     elif solve_method == 'batched':
-        raise NotImplementedError("batched method not implemented yet")
+        if M is None:
+            def identity(x):
+                return x
+
+            M = identity
+        x = BatchedGmres(A, M)(b, x, tol, atol, restart, maxiter)
     else:
         raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
                          .format(solve_method))
+
+    _, x_norm = _safe_normalize(x)
+    info = mnp.where(mnp.isnan(x_norm), _INT_NEG_ONE, _INT_ZERO)
     return x, info
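Both strategies stay behind the one public entry point; only `solve_method` selects the path. A usage sketch comparing them (my example data; the new tests are GPU-marked, so this presumably assumes a GPU-enabled build):

```python
import numpy as onp
from mindspore.common import Tensor
import mindspore.scipy as msp

onp.random.seed(0)
a = onp.random.rand(5, 5).astype(onp.float64) + 5 * onp.eye(5)  # keep it well conditioned
b = onp.random.rand(5)

x_inc, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), solve_method='incremental')
x_bat, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), solve_method='batched')
assert onp.allclose(x_inc.asnumpy(), x_bat.asnumpy(), atol=1e-5)
```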
@@ -344,7 +379,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None) -> (Tensor, N
             differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
         maxiter (int): Maximum number of iterations. Iteration will stop after maxiter
             steps even if the specified tolerance has not been achieved.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
            inverse of A. Effective preconditioning dramatically improves the
            rate of convergence, which implies that fewer iterations are needed
            to reach a given error tolerance.
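The same Tensor-or-function contract is now documented for `cg`. A minimal sketch on a symmetric positive definite system (my example data):

```python
import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.sparse.linalg import cg

a = onp.array([[4, 1], [1, 3]], dtype=onp.float32)  # symmetric positive definite
b = onp.array([1, 2], dtype=onp.float32)

x, info = cg(Tensor(a), Tensor(b))  # info == 0 on convergence
assert onp.allclose(onp.dot(a, x.asnumpy()), b, atol=1e-4)
```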
@@ -15,7 +15,7 @@
 """internal utility functions"""
 import numpy as onp
 from .. import nn, ops
-from ..numpy import where, isnan, zeros_like, dot
+from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor
 from ..common import dtype as mstype
@@ -25,6 +25,7 @@ from ..ops.primitive import constexpr
 from .._c_expression import typing
 
 grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
+_eps_net = ops.Eps()
 
 
 def _convert_64_to_32(tensor):
@@ -69,13 +70,16 @@ def _to_scalar(arr):
     raise ValueError("{} are not supported.".format(type(arr)))
 
 
+def _eps(x):
+    return _eps_net(x[(0,) * x.ndim])
+
+
 class _SafeNormalize(nn.Cell):
     """Normalize method that cast very small results to zero."""
 
     def __init__(self):
         """Initialize LineSearch."""
         super(_SafeNormalize, self).__init__()
-        self.eps = ops.Eps()
 
     def construct(self, x, threshold=None):
         x_sum2 = F.reduce_sum(F.pows(x, 2.0))
@@ -83,11 +87,13 @@ class _SafeNormalize(nn.Cell):
         if threshold is None:
             if x.dtype in mstype.float_type:
                 # pick the first element of x to get the eps
-                threshold = self.eps(x[(0,) * x.ndim])
+                threshold = _eps(x)
             else:
                 threshold = 0
-        normalized_x = where(norm > threshold, x / norm, zeros_like(x))
-        normalized_x = where(isnan(normalized_x), 0, normalized_x)
+        use_norm = greater(norm, threshold)
+        x_norm = x / norm
+        normalized_x = where(use_norm, x_norm, zeros_like(x))
+        norm = where(use_norm, norm, zeros_like(norm))
         return normalized_x, norm
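The rewritten `construct` drops the extra `isnan` pass: it compares the norm against a dtype-dependent epsilon (`_eps(x)` reads `ops.Eps()` from one element, much like `numpy.finfo(dtype).eps`) and zeroes both outputs below the threshold, which also keeps the returned norm well defined. A NumPy sketch of the same logic (my illustration):

```python
import numpy as np

def safe_normalize(x, threshold=None):
    norm = np.sqrt(np.sum(x ** 2))
    if threshold is None:
        # dtype epsilon, analogous to _eps(x) above
        threshold = np.finfo(x.dtype).eps if np.issubdtype(x.dtype, np.floating) else 0
    use_norm = norm > threshold
    normalized_x = x / norm if use_norm else np.zeros_like(x)
    return normalized_x, norm if use_norm else 0.0
```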
@@ -95,6 +101,9 @@ _safe_normalize = _SafeNormalize()
 
 _INT_ZERO = _to_tensor(0)
 _INT_ONE = _to_tensor(1)
+_INT_NEG_ONE = _to_tensor(-1)
+_FLOAT_ONE = _to_tensor(1.0)
+_FLOAT_TWO = _to_tensor(2.0, dtype=float)
 _BOOL_TRUE = _to_tensor(True)
 _BOOL_FALSE = _to_tensor(False)
@@ -70,9 +70,9 @@ def test_bfgs(dtype, func_x0):
     x0 = x0.astype(dtype)
     x0_tensor = Tensor(x0)
     ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
-                                   options=dict(maxiter=None, gtol=1e-6)).x
-    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
-    match_array(ms_res.asnumpy(), scipy_res, error=5)
+                                   options=dict(maxiter=None, gtol=1e-6))
+    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
+    match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
 
 
 @pytest.mark.level0
@@ -16,13 +16,13 @@
 
 import pytest
 import numpy as onp
 import scipy as osp
 from scipy.sparse.linalg import cg as osp_cg
 
 import mindspore.scipy as msp
 from mindspore import context
 from mindspore.common import Tensor
 from mindspore.scipy.sparse.linalg import cg as msp_cg
 
-from .utils import create_sym_pos_matrix
+from .utils import create_sym_pos_matrix, create_full_rank_matrix
 
 onp.random.seed(0)
@@ -109,3 +109,53 @@ def test_cg_against_numpy(dtype, shape):
     kw = {"atol": 1e-5, "rtol": 1e-5}
     onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
     onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [4, 5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [1, 2])
+def test_pynative_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the result match scipy
+    """
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    M = _fetch_preconditioner(preconditioner, a)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    M = Tensor(M) if M is not None else M
+
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
+
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
+                                       solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)
+
+
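`_fetch_preconditioner` is a helper defined elsewhere in this test module and is not part of this diff. Judging only by the parametrize labels, a plausible reading is the sketch below; this is my hypothetical reconstruction, not the actual helper:

```python
import numpy as onp

def _fetch_preconditioner(preconditioner, a):
    """Hypothetical sketch: map a label to a matrix M that stands in for inv(a)."""
    if preconditioner == 'identity':
        return onp.eye(a.shape[0], dtype=a.dtype)
    if preconditioner == 'exact':
        return onp.linalg.inv(a)
    if preconditioner == 'random':
        return onp.random.rand(*a.shape).astype(a.dtype)
    return None  # preconditioner is None
```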
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+def test_graph_batched_gmres_against_scipy(n, dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the result match scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, atol=0.0)
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, atol=0.0, solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)
@@ -30,7 +30,7 @@ def to_tensor(obj, dtype=None):
     return res
 
 
-def match_array(actual, expected, error=0):
+def match_array(actual, expected, error=0, err_msg=''):
     if isinstance(actual, int):
         actual = onp.asarray(actual)
@@ -38,9 +38,9 @@ def match_array(actual, expected, error=0):
         expected = onp.asarray(expected)
 
     if error > 0:
-        onp.testing.assert_almost_equal(actual, expected, decimal=error)
+        onp.testing.assert_almost_equal(actual, expected, decimal=error, err_msg=err_msg)
     else:
-        onp.testing.assert_equal(actual, expected)
+        onp.testing.assert_equal(actual, expected, err_msg=err_msg)
 
 
 def create_full_rank_matrix(shape, dtype):