diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py index 06c6c91b8cf..0dc8a127929 100644 --- a/mindspore/python/mindspore/scipy/sparse/linalg.py +++ b/mindspore/python/mindspore/scipy/sparse/linalg.py @@ -72,128 +72,157 @@ def _high_precision_cho_solve(a, b, data_type=mstype.float64): return y.astype(data_type) -class BatchedGmres(nn.Cell): +def _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M): """ - Implements a single restart of GMRES. The ``restart``-dimensional Krylov subspace - This implementation solves a dense linear problem instead of building - a QR factorization during the Arnoldi process. + batched gmres: solve the least squares problem from scratch at the end of each GMRES iteration. + It does not allow for early termination, but has much less overhead on GPUs. """ - - def __init__(self, A, M): - super(BatchedGmres, self).__init__() - self.A = A - self.M = M - - def construct(self, b, x0=None, tol=1e-5, atol=0.0, restart=20, maxiter=None): - # Constant tensor which avoids loop unrolling - _INT_ZERO = _to_tensor(0) - - A = _normalize_matvec(self.A) - M = _normalize_matvec(self.M) - dtype = b.dtype - _, b_norm = _safe_normalize(b) - atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype) - residual = M(b - A(x0)) + # Constant tensor which avoids loop unrolling + _INT_ZERO = _to_tensor(0) + dtype = b.dtype + _, b_norm = _safe_normalize(b) + atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype) + residual = M(b - A(x0)) + unit_residual, residual_norm = _safe_normalize(residual) + k = _INT_ZERO + x = x0 + while k < maxiter and residual_norm > atol: + pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),) + V = mnp.pad(unit_residual[..., None], pad_width=pad_width) + H = mnp.eye(restart, restart + 1, dtype=dtype) + k_iter = _INT_ZERO + breakdown = _to_tensor(False) + while k_iter < restart and mnp.logical_not(breakdown): + V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H) + k_iter += 1 + beta_vec = mnp.zeros((restart + 1,), dtype=dtype) + beta_vec[0] = residual_norm + y = _high_precision_cho_solve(H, beta_vec, data_type=dtype) + dx = mnp.dot(V[..., :-1], y) + x = x + dx + residual = M(b - A(x)) unit_residual, residual_norm = _safe_normalize(residual) + k += 1 + return x, F.select(residual_norm > atol, k, _INT_ZERO) + + +def _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M): + """ + incremental gmres: builds a QR decomposition for the Krylov subspace incrementally during + the GMRES process using Givens rotations. This improves numerical stability and gives a free estimate of + the residual norm that allows for early termination within a single "restart". + """ + _INT_ZERO = _to_tensor(0) + _, b_norm = _safe_normalize(b) + atol = mnp.maximum(tol * b_norm, atol) + + Mb = M(b) + _, Mb_norm = _safe_normalize(Mb) + ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm) + + r = M(b - A(x0)) + r, r_norm = _safe_normalize(r) + + iters = _INT_ZERO + while iters < maxiter and r_norm > atol: + V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),)) + dtype = mnp.result_type(b) + # Use eye() to avoid constructing a singular matrix in case of early + # Termination + R = mnp.eye(restart, restart + 1, dtype=dtype) + givens = mnp.zeros((restart, 2), dtype=dtype) + beta_vec = mnp.zeros((restart + 1), dtype=dtype) + beta_vec[0] = r_norm + k = _INT_ZERO - x = x0 - while k < maxiter and residual_norm > atol: - pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),) - V = mnp.pad(unit_residual[..., None], pad_width=pad_width) - H = mnp.eye(restart, restart + 1, dtype=dtype) - k_iter = _INT_ZERO - breakdown = _to_tensor(False) - while k_iter < restart and mnp.logical_not(breakdown): - V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H) - k_iter += 1 - beta_vec = mnp.zeros((restart + 1,), dtype=dtype) - beta_vec[0] = residual_norm - y = _high_precision_cho_solve(H, beta_vec, data_type=dtype) - dx = mnp.dot(V[..., :-1], y) - x = x + dx - residual = M(b - A(x)) - unit_residual, residual_norm = _safe_normalize(residual) + err = r_norm + while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)): + V, R, _ = arnoldi_iteration(k, A, M, V, R) + # Givens rotation + row_k = R[k, :].copy() + i = _INT_ZERO + while i < k: + row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1]) + i += 1 + + if row_k[k + 1] == 0: + givens[k, 0] = 1 + givens[k, 1] = 0 + else: + increase = mnp.absolute(row_k[k]) < mnp.absolute(row_k[k + 1]) + t = mnp.where(increase, -row_k[k] / row_k[k + 1], -row_k[k + 1] / row_k[k]) + r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2) + givens[k, 0] = mnp.where(increase, r * t, r) + givens[k, 1] = mnp.where(increase, r, r * t) + + R[k, :] = rotate_vectors(row_k, k, givens[k, 0], givens[k, 1]) + beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1]) + err = mnp.absolute(beta_vec[k + 1]) k += 1 - return x, F.select(residual_norm > atol, k, _INT_ZERO) + y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True) + dx = mnp.dot(V[:, :-1], y) + + x = x0 + dx + r = M(b - A(x)) + r, r_norm = _safe_normalize(r) + x0 = x + iters += 1 + return x0, F.select(r_norm > atol, iters, _INT_ZERO) -class IterativeGmres(nn.Cell): +class GMRES(nn.Cell): """ - Implements a iterative GMRES. While building the ``restart``-dimensional - Krylov subspace iteratively using Givens Rotation method, the algorithm - constructs a Triangular matrix R which could be more easily solved. + Given given A and b, GMRES solves the linear system: + + .. math:: + A x = b """ - def __init__(self, A, M): - super(IterativeGmres, self).__init__() + def __init__(self, A, M, solve_method): + super(GMRES, self).__init__() self.A = A self.M = M + self.solve_method = solve_method def construct(self, b, x0, tol, atol, restart, maxiter): # Constant tensor which avoids loop unrolling - _INT_ZERO = _to_tensor(0) - A = _normalize_matvec(self.A) M = _normalize_matvec(self.M) + x = x0 + info = _to_tensor(0) + if self.solve_method == 'batched': + x, info = _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M) + elif self.solve_method == "incremental": + x, info = _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M) + else: + _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", self.solve_method, + ".") + return x, info - _, b_norm = _safe_normalize(b) - atol = mnp.maximum(tol * b_norm, atol) - Mb = M(b) - _, Mb_norm = _safe_normalize(Mb) - ptol = Mb_norm * mnp.minimum(1.0, atol / b_norm) +class GMRESV2(nn.Cell): + """ + This is a new version of GMRES, which contains all parameters in a graph. + """ - r = M(b - A(x0)) - r, r_norm = _safe_normalize(r) + def __init__(self, solve_method): + super(GMRESV2, self).__init__() + self.solve_method = solve_method - iters = _INT_ZERO - while iters < maxiter and r_norm > atol: - V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),)) - dtype = mnp.result_type(b) - # use eye() to avoid constructing a singular matrix in case of early - # termination - R = mnp.eye(restart, restart + 1, dtype=dtype) - givens = mnp.zeros((restart, 2), dtype=dtype) - beta_vec = mnp.zeros((restart + 1), dtype=dtype) - beta_vec[0] = r_norm - - k = _INT_ZERO - err = r_norm - while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)): - V, R, _ = arnoldi_iteration(k, A, M, V, R) - # givens rotation - row_k = R[k, :].copy() - i = _INT_ZERO - while i < k: - row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1]) - i += 1 - - if row_k[k + 1] == 0: - givens[k, 0] = 1 - givens[k, 1] = 0 - else: - increase = mnp.absolute(row_k[k]) < mnp.absolute(row_k[k + 1]) - t = mnp.where(increase, -row_k[k] / row_k[k + 1], -row_k[k + 1] / row_k[k]) - r = 1 / F.sqrt(1 + mnp.absolute(t) ** 2) - givens[k, 0] = mnp.where(increase, r * t, r) - givens[k, 1] = mnp.where(increase, r, r * t) - - R[k, :] = rotate_vectors(row_k, k, givens[k, 0], givens[k, 1]) - beta_vec = rotate_vectors(beta_vec, k, givens[k, 0], givens[k, 1]) - err = mnp.absolute(beta_vec[k + 1]) - k += 1 - - y = solve_triangular(R[:, :-1], beta_vec[:-1], trans='T', lower=True) - dx = mnp.dot(V[:, :-1], y) - - x = x0 + dx - r = M(b - A(x)) - r, r_norm = _safe_normalize(r) - x0 = x - iters += 1 - - return x0, F.select(r_norm > atol, iters, _INT_ZERO) + def construct(self, A, b, x0, tol, atol, restart, maxiter, M): + A = _normalize_matvec(A) + M = _normalize_matvec(M) + x = x0 + info = _to_tensor(0) + if self.solve_method == 'batched': + x, info = _batch_gmres(A, x0, b, tol, atol, restart, maxiter, M) + elif self.solve_method == "incremental": + x, info = _incremental_gmres(A, x0, b, tol, atol, restart, maxiter, M) + else: + _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", self.solve_method, + ".") + return x, info def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None, @@ -292,13 +321,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None, _value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo') if restart > size: restart = size - - if solve_method == 'incremental': - x, info = IterativeGmres(A, M)(b, x0, tol, atol, restart, maxiter) - elif solve_method == 'batched': - x, info = BatchedGmres(A, M)(b, x0, tol, atol, restart, maxiter) + if not is_within_graph(A): + x, info = GMRES(A, M, solve_method)(b, x0, tol, atol, restart, maxiter) else: - _raise_value_error("solve_method should be in ('incremental' or 'batched'), but got ", solve_method, ".") + x, info = GMRESV2(solve_method)(A, b, x0, tol, atol, restart, maxiter, M) return x, info diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py index 745008eb21e..2975ebbe5d2 100644 --- a/tests/st/scipy_st/sparse/test_linalg.py +++ b/tests/st/scipy_st/sparse/test_linalg.py @@ -43,6 +43,12 @@ def _fetch_preconditioner(preconditioner, a): return M +def _is_valid_platform(tensor_type='Tensor'): + if tensor_type == "CSRTensor" and get_platform() != "linux": + return False + return True + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @@ -58,7 +64,7 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite Description: test cases for cg using function way in pynative/graph mode Expectation: the result match scipy """ - if tensor_type == "CSRTensor" and get_platform() != "linux": + if not _is_valid_platform(tensor_type): return onp.random.seed(0) a = create_sym_pos_matrix(shape, dtype) @@ -70,11 +76,11 @@ def test_cg_against_scipy(tensor_type, dtype, tol, shape, preconditioner, maxite b = Tensor(b) m = to_tensor((m, tensor_type)) if m is not None else m - # using PYNATIVE MODE + # Using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) - # using GRAPH MODE + # Using GRAPH MODE context.set_context(mode=context.GRAPH_MODE) msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) @@ -101,11 +107,11 @@ def test_cg_against_numpy(dtype, shape): b = onp.random.random(shape[:1]).astype(dtype) expected = onp.linalg.solve(a, b) - # using PYNATIVE MODE + # Using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) - # using GRAPH MODE + # Using GRAPH MODE context.set_context(mode=context.GRAPH_MODE) actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) @@ -146,11 +152,11 @@ def test_cg_against_scipy_graph(tensor_type, dtype, tol, shape, preconditioner, b = Tensor(b) m = to_tensor((m, tensor_type)) if m is not None else m - # using PYNATIVE MODE + # Using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) msp_res_dyn = Net()(a, b, m, maxiter, tol) - # using GRAPH MODE + # Using GRAPH MODE context.set_context(mode=context.GRAPH_MODE) msp_res_sta = Net()(a, b, m, maxiter, tol) @@ -339,33 +345,39 @@ def test_gmres_against_scipy_level1(n, dtype, error, preconditioner, solve_metho @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @pytest.mark.parametrize('n', [3, 7]) -@pytest.mark.parametrize('dtype,error', [(onp.float64, 1e-5), (onp.float32, 1e-4)]) +@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4), + ('CSRTensor', onp.float32, 1e-4)]) @pytest.mark.parametrize('restart', [1, 2]) @pytest.mark.parametrize('maxiter', [1, 2]) @pytest.mark.parametrize('preconditioner', ['identity', 'exact', 'random']) @pytest.mark.parametrize('solve_method', ['incremental', 'batched']) -def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, solve_method): +def test_gmres_against_scipy(n, tensor_type, dtype, error, restart, maxiter, preconditioner, solve_method): """ Feature: ALL TO ALL Description: test cases for [N x N] X [N X 1] Expectation: the result match scipy """ + if not _is_valid_platform(tensor_type): + return onp.random.seed(0) a = create_full_rank_matrix((n, n), dtype) b = onp.random.rand(n).astype(dtype) x0 = onp.zeros_like(b).astype(dtype) - M = _fetch_preconditioner(preconditioner, a) + m = _fetch_preconditioner(preconditioner, a) tol = float(onp.finfo(dtype=dtype).eps) atol = tol if preconditioner == 'random': restart = n maxiter = None - scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol) + scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol) # PyNative Mode context.set_context(mode=context.PYNATIVE_MODE) - M = Tensor(M) if M is not None else M - ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter, - M=M, atol=atol, solve_method=solve_method) + a = to_tensor((a, tensor_type)) + b = Tensor(b) + x0 = Tensor(x0) + m = to_tensor((m, tensor_type)) if m is not None else m + ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, + maxiter=maxiter, M=m, atol=atol, solve_method=solve_method) assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error) @@ -374,32 +386,57 @@ def test_gmres_against_scipy(n, dtype, error, restart, maxiter, preconditioner, @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @pytest.mark.parametrize('n', [3]) -@pytest.mark.parametrize('dtype,error', [(onp.float32, 1e-4)]) +@pytest.mark.parametrize('tensor_type, dtype, error', [('Tensor', onp.float64, 1e-5), ('Tensor', onp.float32, 1e-4), + ('CSRTensor', onp.float32, 1e-4)]) @pytest.mark.parametrize('preconditioner', ['random']) @pytest.mark.parametrize('solve_method', ['incremental', 'batched']) -def test_gmres_against_graph_scipy(n, dtype, error, preconditioner, solve_method): +def test_gmres_against_graph_scipy(n, tensor_type, dtype, error, preconditioner, solve_method): """ Feature: ALL TO ALL Description: test cases for [N x N] X [N X 1] Expectation: the result match scipy in graph """ + if not _is_valid_platform(tensor_type): + return + + # Input CSRTensor of gmres in mindspore graph mode is not supported, just ignored it. + if tensor_type == "CSRTensor": + return + + class TestNet(nn.Cell): + def __init__(self, solve_method): + super(TestNet, self).__init__() + self.solve_method = solve_method + + def construct(self, a, b, x0, tol, restart, maxiter, m, atol): + return msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, + atol=atol, solve_method=self.solve_method) + onp.random.seed(0) a = create_full_rank_matrix((n, n), dtype) b = onp.random.rand(n).astype(dtype) x0 = onp.zeros_like(b).astype(dtype) - M = _fetch_preconditioner(preconditioner, a) + m = _fetch_preconditioner(preconditioner, a) tol = float(onp.finfo(dtype=dtype).eps) atol = tol restart = n maxiter = None - scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol) + scipy_output, _ = osp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, M=m, atol=atol) # Graph Mode context.set_context(mode=context.GRAPH_MODE) - M = Tensor(M) if M is not None else M - ms_output, _ = msp.sparse.linalg.gmres(Tensor(a), Tensor(b), Tensor(x0), tol=tol, restart=restart, maxiter=maxiter, - M=M, atol=atol, solve_method=solve_method) + a = to_tensor((a, tensor_type)) + b = Tensor(b) + x0 = Tensor(x0) + m = to_tensor((m, tensor_type)) if m is not None else m + # Not in graph's construct + ms_output, _ = msp.sparse.linalg.gmres(a, b, x0, tol=tol, restart=restart, maxiter=maxiter, + M=m, atol=atol) assert onp.allclose(scipy_output, ms_output.asnumpy(), rtol=error, atol=error) + # With in graph's construct + ms_net_output, _ = TestNet(solve_method)(a, b, x0, tol, restart, maxiter, m, atol) + assert onp.allclose(scipy_output, ms_net_output.asnumpy(), rtol=error, atol=error) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training