From c3f823cb3e5b3bfa5424312a180d1515a2fadfa4 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Tue, 16 Nov 2021 16:04:21 +0800 Subject: [PATCH] scipy gmres batched interface --- mindspore/scipy/__init__.py | 8 +- mindspore/scipy/sparse/__init__.py | 3 +- mindspore/scipy/sparse/linalg.py | 185 ++++++++++++++---------- mindspore/scipy/utils.py | 19 ++- tests/st/scipy_st/test_optimize.py | 6 +- tests/st/scipy_st/test_sparse_linalg.py | 56 ++++++- tests/st/scipy_st/utils.py | 6 +- 7 files changed, 190 insertions(+), 93 deletions(-) diff --git a/mindspore/scipy/__init__.py b/mindspore/scipy/__init__.py index 26adff44d55..eddb8a51cfd 100644 --- a/mindspore/scipy/__init__.py +++ b/mindspore/scipy/__init__.py @@ -14,12 +14,14 @@ # ============================================================================ """Scipy-like interfaces in mindspore.""" -from . import optimize, linalg -from .optimize import * -from .linalg import * +from . import optimize, sparse, linalg +from .optimize import minimize, line_search +from .sparse import cg, gmres +from .linalg import block_diag, solve_triangular, inv, cho_factor, cholesky, cho_solve __all__ = [] __all__.extend(optimize.__all__) +__all__.extend(sparse.__all__) __all__.extend(linalg.__all__) __all__.sort() diff --git a/mindspore/scipy/sparse/__init__.py b/mindspore/scipy/sparse/__init__.py index 7ae2f200abe..aafef88d3d3 100644 --- a/mindspore/scipy/sparse/__init__.py +++ b/mindspore/scipy/sparse/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Sparse linear algebra submodule""" +"""Sparse submodule""" +from . import linalg from .linalg import cg, gmres __all__ = ["cg", "gmres"] diff --git a/mindspore/scipy/sparse/linalg.py b/mindspore/scipy/sparse/linalg.py index e0f5189dd3f..355f1ac9ccb 100644 --- a/mindspore/scipy/sparse/linalg.py +++ b/mindspore/scipy/sparse/linalg.py @@ -17,44 +17,36 @@ from ... import nn, Tensor, ops, ms_function from ... 
import numpy as mnp
 from ...ops import functional as F
 from ..linalg import solve_triangular
-
-from ..utils import _INT_ZERO, _normalize_matvec, _INT_ONE, _safe_normalize, _SafeNormalize
+from ..linalg import cho_factor, cho_solve
+from ..utils import _INT_ZERO, _INT_ONE, _INT_NEG_ONE, _normalize_matvec, _to_tensor, _safe_normalize, _eps
 
 
-class ArnoldiIteration(nn.Cell):
-    """ do the Arnoldi iteration"""
+def gram_schmidt(Q, q):
+    """
+    Do the Gram–Schmidt process to orthogonalize vector q against the columns of Q.
+    """
+    # transpose does not support float64 yet,
+    # so the following code is the same as h = mnp.dot(Q.T, q)
+    h = ops.MatMul(True, False)(Q, q.reshape((q.shape[0], 1))).flatten()
+    Qh = mnp.dot(Q, h)
+    q = q - Qh
+    return q, h
 
-    def __init__(self):
-        super(ArnoldiIteration, self).__init__()
-        self.safe_normalize = _SafeNormalize()
-        self.eps = ops.Eps()
-        self.sqrt_2 = F.pows(Tensor(2.0), 1/2.0)
-        self.matmul_t = ops.MatMul(True, False)
 
-    def construct(self, k, V, v, H):
-        v, h = self._gram_schmidt(V, v)
-
-        eps_v = self.eps(v[(0,) * v.ndim])
-        _, v_norm_0 = self.safe_normalize(v)
-        tol = eps_v * v_norm_0
-        unit_v, v_norm_1 = self.safe_normalize(v, tol)
-        V[..., k + 1] = unit_v
-
-        h[k + 1] = v_norm_1
-        H[k, :] = h
-        return V, H
-
-    def _gram_schmidt(self, Q, q):
-        """
-        do Gram–Schmidt process to normalize vector v
-        """
-        # transpose is not support float64 yet,
-        # so the following code is the same as h = mnp.dot(Q.T, q)
-        h = self.matmul_t(Q, q.reshape((q.shape[0], 1))).flatten()
-        Qh = mnp.dot(Q, h)
-        q = q - Qh
-
-        return q, h
+def arnoldi_iteration(k, A, M, V, H):
+    """ Performs a single (the k'th) step of the Arnoldi process."""
+    v_ = V[..., k]
+    v = M(A(v_))
+    v, h = gram_schmidt(V, v)
+    eps_v = _eps(v)
+    _, v_norm_0 = _safe_normalize(v)
+    tol = eps_v * v_norm_0
+    unit_v, v_norm_1 = _safe_normalize(v, tol)
+    V[..., k + 1] = unit_v
+    h[k + 1] = v_norm_1
+    H[k, :] = h
+    breakdown = v_norm_1 == 0
+    return V, H, breakdown
 
 
 @ms_function
@@ -110,10 +102,55 @@ class GivensRotation(nn.Cell):
         return R_row, givens
 
 
-kth_arnoldi_iteration = ArnoldiIteration()
 givens_rotation = GivensRotation()
 
 
+class BatchedGmres(nn.Cell):
+    """
+    Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace
+    is built, and the resulting least-squares problem is solved as a dense linear
+    problem instead of building a QR factorization during the Arnoldi process.
+    """
+
+    def __init__(self, A, M):
+        super(BatchedGmres, self).__init__()
+        self.A = A
+        self.M = M
+
+    def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None):
+        # Normalize A and M into matvec callables, mirroring the incremental path.
+        A = _normalize_matvec(self.A)
+        M = _normalize_matvec(self.M)
+        dtype = b.dtype
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype)
+        residual = M(b - A(x0))
+        unit_residual, residual_norm = _safe_normalize(residual)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and residual_norm > atol:
+            # Build the Krylov basis V and the (transposed) Hessenberg matrix H
+            # for one restart cycle, stopping early on breakdown.
+            pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),)
+            V = mnp.pad(unit_residual[..., None], pad_width=pad_width)
+            H = mnp.eye(restart, restart + 1, dtype=dtype)
+            k_iter = _INT_ZERO
+            breakdown = _to_tensor(False)
+            while k_iter < restart and mnp.logical_not(breakdown):
+                V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H)
+                k_iter += 1
+            beta_vec = mnp.zeros((restart + 1,), dtype=dtype)
+            beta_vec[0] = residual_norm
+            # Dense least-squares solve via the normal equations (H H^T) y = H beta,
+            # factored with Cholesky instead of the incremental QR path.
+            a2 = mnp.dot(H, H.T)
+            b2 = mnp.dot(H, beta_vec)
+            c, lower = cho_factor(a2, lower=False)
+            factor = (c, lower)
+            y = cho_solve(factor, b2)
+            dx = mnp.dot(V[..., :-1], y)
+            x = x + dx
+            residual = b - A(x)
+            unit_residual, residual_norm = _safe_normalize(residual)
+            k += 1
+        return x
+
+
 def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     """
     Single iteration for Gmres Algorithm with restart
@@ -130,8 +167,7 @@ def gmres_iter(A_mat_func, b, x0, r, r_norm, ptol, restart, M_mat_func):
     k = 0
     err = r_norm
     while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)):
-        v = M_mat_func(A_mat_func(V[:, k]))
-        V, H = kth_arnoldi_iteration(k, V, v, R)
+        V, H, _ = arnoldi_iteration(k, A_mat_func, M_mat_func, V, R)
         R[k, :], givens = givens_rotation(H[k, :], givens, k)
         beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1])
         err = mnp.absolute(beta_vec[k + 1])
@@ -186,7 +222,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
             Krylov space starting from the solution found at the last iteration. If GMRES
             halts or is very slow, decreasing this parameter may help.
             Default is infinite.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance.
@@ -206,67 +242,66 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
         >>> import numpy as onp
         >>> from mindspore.common import Tensor
         >>> import mindspore.numpy as mnp
-        >>> from mindspore.scipy.sparse import csc_matrix
         >>> from mindspore.scipy.sparse.linalg import gmres
-        >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=np.float32)
-        >>> b = Tensor(onp.array([2, 4, -1], dtype=np.float32))
+        >>> A = Tensor(mnp.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=mnp.float32))
+        >>> b = Tensor(onp.array([2, 4, -1], dtype=onp.float32))
         >>> x, exitCode = gmres(A, b)
         >>> print(exitCode)            # 0 indicates successful convergence
         0
-        >>> np.allclose(A.matvec(x).asnumpy(), b.asnumpy())
+        >>> onp.allclose(mnp.dot(A, x).asnumpy(), b.asnumpy())
         True
     """
     if x0 is None:
         x0 = mnp.zeros_like(b)
-
-    if M is None:
-        def M_mat_func(x):
-            return x
-    elif not callable(M):
-        def M_mat_func(x):
-            return mnp.dot(M, x)
-    else:
-        M_mat_func = M
-
-    if not callable(A):
-        def A_mat_func(x):
-            return mnp.dot(A, x)
-    else:
-        A_mat_func = A
-
     size = b.size
-
     if maxiter is None:
         maxiter = 10 * size  # copied from scipy
-    restart = min(restart, size)
-
-    _, b_norm = _safe_normalize(b)
-    atol = mnp.maximum(tol * b_norm, atol)
-
-    Mb = M_mat_func(b)
-    _, Mb_norm = _safe_normalize(Mb)
-    ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
-
+    if restart > size:
+        restart = size
+    x = x0
     if solve_method == 'incremental':
+        if M is None:
+            def M_mat_func(x):
+                return x
+        elif not callable(M):
+            def M_mat_func(x):
+                return mnp.dot(M, x)
+        else:
+            M_mat_func = M
+
+        if not callable(A):
+            def A_mat_func(x):
+                return mnp.dot(A, x)
+        else:
+            A_mat_func = A
+
+        _, b_norm = _safe_normalize(b)
+        atol = mnp.maximum(tol * b_norm, atol)
+
+        Mb = M_mat_func(b)
+        _, Mb_norm = _safe_normalize(Mb)
+        ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm)
         # iterative gmres
         r = M_mat_func(b - A_mat_func(x0))
         r, r_norm = _safe_normalize(r)
-        x = x0
         k = 0
         while k < maxiter and r_norm > atol:
             x, r, r_norm = gmres_iter(
                 A_mat_func, b, x, r, r_norm, ptol, restart, M_mat_func)
             k += 1
-
-        _, x_norm = _safe_normalize(x)
-        info = mnp.where(mnp.isnan(x_norm), -1, 0)
     elif solve_method == 'batched':
-        raise NotImplementedError("batched method not implemented yet")
+        if M is None:
+            def identity(x):
+                return x
+
+            M = identity
+        x = BatchedGmres(A, M)(b, x, tol, atol, restart, maxiter)
     else:
         raise ValueError("solve_method should be in ('incremental' or 'batched'), but got {}."
                          .format(solve_method))
-
+    _, x_norm = _safe_normalize(x)
+    info = mnp.where(mnp.isnan(x_norm), _INT_NEG_ONE, _INT_ZERO)
     return x, info
 
 
@@ -344,7 +379,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None) -> (Tensor, N
             differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
         maxiter (int): Maximum number of iterations. Iteration will stop after maxiter
             steps even if the specified tolerance has not been achieved.
-        M (Tensor): Preconditioner for A. The preconditioner should approximate the
+        M (Tensor or function): Preconditioner for A. The preconditioner should approximate the
             inverse of A. Effective preconditioning dramatically improves the
             rate of convergence, which implies that fewer iterations are needed
             to reach a given error tolerance.
diff --git a/mindspore/scipy/utils.py b/mindspore/scipy/utils.py
index 0fd12fbf2cb..7795e8da8af 100644
--- a/mindspore/scipy/utils.py
+++ b/mindspore/scipy/utils.py
@@ -15,7 +15,7 @@
 """internal utility functions"""
 import numpy as onp
 from ..
import nn, ops
-from ..numpy import where, isnan, zeros_like, dot
+from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
 from ..common import Tensor
 from ..common import dtype as mstype
@@ -25,6 +25,7 @@ from ..ops.primitive import constexpr
 from .._c_expression import typing
 
 grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
+_eps_net = ops.Eps()
 
 
 def _convert_64_to_32(tensor):
@@ -69,13 +70,16 @@ def _to_scalar(arr):
     raise ValueError("{} are not supported.".format(type(arr)))
 
 
+def _eps(x):
+    return _eps_net(x[(0,) * x.ndim])
+
+
 class _SafeNormalize(nn.Cell):
     """Normalize method that casts very small results to zero."""
 
     def __init__(self):
         """Initialize _SafeNormalize."""
         super(_SafeNormalize, self).__init__()
-        self.eps = ops.Eps()
 
     def construct(self, x, threshold=None):
         x_sum2 = F.reduce_sum(F.pows(x, 2.0))
@@ -83,11 +87,13 @@ class _SafeNormalize(nn.Cell):
         if threshold is None:
             if x.dtype in mstype.float_type:
                 # pick the first element of x to get the eps
-                threshold = self.eps(x[(0,) * x.ndim])
+                threshold = _eps(x)
             else:
                 threshold = 0
-        normalized_x = where(norm > threshold, x / norm, zeros_like(x))
-        normalized_x = where(isnan(normalized_x), 0, normalized_x)
+        use_norm = greater(norm, threshold)
+        x_norm = x / norm
+        normalized_x = where(use_norm, x_norm, zeros_like(x))
+        norm = where(use_norm, norm, zeros_like(norm))
         return normalized_x, norm
 
 
@@ -95,6 +101,9 @@ _safe_normalize = _SafeNormalize()
 
 _INT_ZERO = _to_tensor(0)
 _INT_ONE = _to_tensor(1)
+_INT_NEG_ONE = _to_tensor(-1)
+_FLOAT_ONE = _to_tensor(1.0)
+_FLOAT_TWO = _to_tensor(2.0, dtype=float)
 _BOOL_TRUE = _to_tensor(True)
 _BOOL_FALSE = _to_tensor(False)
diff --git a/tests/st/scipy_st/test_optimize.py b/tests/st/scipy_st/test_optimize.py
index e91797fd9d7..20934245a3e 100644
--- a/tests/st/scipy_st/test_optimize.py
+++ b/tests/st/scipy_st/test_optimize.py
@@ -70,9 +70,9 @@ def test_bfgs(dtype, func_x0):
     x0 = x0.astype(dtype)
     x0_tensor = Tensor(x0)
     ms_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
-                                   options=dict(maxiter=None, gtol=1e-6)).x
-    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS').x
-    match_array(ms_res.asnumpy(), scipy_res, error=5)
+                                   options=dict(maxiter=None, gtol=1e-6))
+    scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
+    match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
 
 
 @pytest.mark.level0
diff --git a/tests/st/scipy_st/test_sparse_linalg.py b/tests/st/scipy_st/test_sparse_linalg.py
index 8becc463829..6478029653a 100644
--- a/tests/st/scipy_st/test_sparse_linalg.py
+++ b/tests/st/scipy_st/test_sparse_linalg.py
@@ -16,13 +16,13 @@
 import pytest
 import numpy as onp
+import scipy as osp
 from scipy.sparse.linalg import cg as osp_cg
-
+import mindspore.scipy as msp
 from mindspore import context
 from mindspore.common import Tensor
 from mindspore.scipy.sparse.linalg import cg as msp_cg
-
-from .utils import create_sym_pos_matrix
+from .utils import create_sym_pos_matrix, create_full_rank_matrix
 
 onp.random.seed(0)
 
@@ -109,3 +109,53 @@ def test_cg_against_numpy(dtype, shape):
     kw = {"atol": 1e-5, "rtol": 1e-5}
     onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
     onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [4, 5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [1, 2])
+def test_pynative_batched_gmres_against_scipy(n, dtype, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the result matches scipy
+    """
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    M = _fetch_preconditioner(preconditioner, a)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    M = Tensor(M) if M is not None else M
+
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, maxiter=maxiter, atol=1e-6)
+
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, maxiter=maxiter, M=M, atol=1e-6,
+                                       solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('n', [5, 6])
+@pytest.mark.parametrize('dtype', [onp.float64])
+def test_graph_batched_gmres_against_scipy(n, dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for gmres
+    Expectation: the result matches scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    shape = (n, n)
+    a = create_full_rank_matrix(shape, dtype)
+    b = onp.random.rand(n).astype(dtype=dtype)
+    tensor_a = Tensor(a)
+    tensor_b = Tensor(b)
+    osp_x, _ = osp.sparse.linalg.gmres(a, b, atol=0.0)
+    msp_x, _ = msp.sparse.linalg.gmres(tensor_a, tensor_b, atol=0.0, solve_method='batched')
+    assert onp.allclose(msp_x.asnumpy(), osp_x)
diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py
index acca59f2b9e..b4bc18e65bc 100644
--- a/tests/st/scipy_st/utils.py
+++ b/tests/st/scipy_st/utils.py
@@ -30,7 +30,7 @@ def to_tensor(obj, dtype=None):
     return res
 
 
-def match_array(actual, expected, error=0):
+def match_array(actual, expected, error=0, err_msg=''):
     if isinstance(actual, int):
         actual = onp.asarray(actual)
 
@@ -38,9 +38,9 @@ def match_array(actual, expected, error=0):
         expected = onp.asarray(expected)
 
     if error > 0:
-        onp.testing.assert_almost_equal(actual, expected, decimal=error)
+        onp.testing.assert_almost_equal(actual, expected, decimal=error, err_msg=err_msg)
     else:
-        onp.testing.assert_equal(actual, expected)
+        onp.testing.assert_equal(actual, expected, err_msg=err_msg)
 
 
 def create_full_rank_matrix(shape, dtype):
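
Usage sketch (illustrative, not part of the patch): the new batched path is selected with the ``solve_method='batched'`` keyword added above, mirroring the updated docstring example. The 3x3 system below is arbitrary; per the new ``info`` computation, ``exitCode`` is 0 on success and -1 if the solution contains NaNs.

    >>> import numpy as onp
    >>> from mindspore.common import Tensor
    >>> from mindspore.scipy.sparse.linalg import gmres
    >>> A = Tensor(onp.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=onp.float32))
    >>> b = Tensor(onp.array([2, 4, -1], dtype=onp.float32))
    >>> # 'batched' dispatches to the new BatchedGmres cell; the default
    >>> # 'incremental' keeps the existing QR-based restart loop.
    >>> x, exitCode = gmres(A, b, solve_method='batched')
    >>> print(exitCode)
    0
    >>> onp.allclose(onp.dot(A.asnumpy(), x.asnumpy()), b.asnumpy())
    True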