From 885210b6344ea7319545b17e5c7cadf6644e2d03 Mon Sep 17 00:00:00 2001
From: hezhenhao1 <hezhenhao1@huawei.com>
Date: Wed, 2 Mar 2022 11:47:31 +0800
Subject: [PATCH] Add CSRTensor support and test cases for the cg method.

---
 .../python/mindspore/scipy/sparse/linalg.py |  54 ++++++++-
 mindspore/python/mindspore/scipy/utils.py   |  64 ++++++----
 .../python/mindspore/scipy/utils_const.py   |   9 ++
 tests/st/scipy_st/sparse/test_linalg.py     | 112 +++++++++++++++---
 tests/st/scipy_st/utils.py                  |  13 ++
 5 files changed, 205 insertions(+), 47 deletions(-)

diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py
index 91a103c95cb..a3ea62b5b79 100644
--- a/mindspore/python/mindspore/scipy/sparse/linalg.py
+++ b/mindspore/python/mindspore/scipy/sparse/linalg.py
@@ -21,7 +21,7 @@
 from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
 from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
     _sparse_check
-from ..utils_const import _raise_value_error, _raise_type_error
+from ..utils_const import _raise_value_error, _raise_type_error, _nullable_const


 def gram_schmidt(Q, q):
@@ -323,6 +323,45 @@
         return x, F.select(_norm(r) > atol_, k, _INT_ZERO)


+class CGv2(nn.Cell):
+    """
+    A new version of CG which takes all of its parameters, including the operators `A` and `M`, as graph inputs instead of cell attributes.
+    """
+
+    def __init__(self):
+        super(CGv2, self).__init__()
+
+    def construct(self, A, M, b, x0, tol, atol, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
+        A = _normalize_matvec(A)
+        M = _normalize_matvec(M)
+
+        atol_ = mnp.maximum(atol, tol * _norm(b))
+
+        r = b - A(x0)
+        z = p = M(r)
+        rho = mnp.dot(r, z)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and _norm(r) > atol_:
+            q = A(p)
+            alpha = rho / mnp.dot(p, q)
+            x = x + alpha * p
+            r = r - alpha * q
+
+            z = M(r)
+            rho_ = mnp.dot(r, z)
+            beta = rho_ / rho
+            p = z + beta * p
+            rho = rho_
+
+            k += 1
+
+        return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
+
+
 def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
     """Use Conjugate Gradient iteration to solve the linear system:
@@ -343,7 +382,7 @@
     - `cg` is not supported on Windows platform yet.

     Args:
-        A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
+        A (Union[Tensor, CSRTensor, function]): 2D Tensor, CSRTensor or function that calculates the linear
             map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`. As a function, `A` must return
             a Tensor with the same structure and shape as its input matrix.
         b (Tensor): Right hand side of the linear system representing a single vector. Can be
@@ -372,8 +411,8 @@
         TypeError: If `atol` is not float.
        TypeError: If `maxiter` is not int.
        ValueError: If `callback` is not None.
-        TypeError: If `A` is not Tensor or Function.
-        TypeError: If `M` is not None, Tensor or Function.
+        TypeError: If `A` is not Tensor, CSRTensor, or Function.
+        TypeError: If `M` is not None, Tensor, CSRTensor, or Function.
        TypeError: If `b` is not Tensor.
        TypeError: If `x0` is not None or Tensor.
        ValueError: If `b` is not 1 or 2 dimensional.
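Note for reviewers: the loop in the new CGv2 cell is the textbook preconditioned conjugate-gradient recurrence, with the same stopping rule and the same info convention as the existing CG cell (info is 0 on convergence, otherwise the iteration count). A minimal dense NumPy sketch of the identical iteration, for checking the algebra only (pcg_reference and m_inv are names invented here, not part of this patch):

    import numpy as np

    def pcg_reference(a, b, x0, m_inv, tol=1e-5, atol=0.0, maxiter=1000):
        # Stopping threshold, as in CG/CGv2: max(atol, tol * ||b||).
        atol_ = max(atol, tol * np.linalg.norm(b))
        x = x0
        r = b - a @ x0              # initial residual
        z = p = m_inv @ r           # preconditioned residual / first search direction
        rho = r @ z
        k = 0
        while k < maxiter and np.linalg.norm(r) > atol_:
            q = a @ p
            alpha = rho / (p @ q)   # step size along the search direction
            x = x + alpha * p
            r = r - alpha * q
            z = m_inv @ r
            rho_new = r @ z
            p = z + (rho_new / rho) * p   # conjugate direction update
            rho = rho_new
            k += 1
        # info convention: 0 on convergence, else the iteration count reached.
        return x, (k if np.linalg.norm(r) > atol_ else 0)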
@@ -411,9 +450,12 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
     _type_check(func_name, atol, float, 'atol')
     _type_check(func_name, maxiter, int, 'maxiter')
     _value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
-    _sparse_check(func_name, A, M, b, x0)
+    A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)

-    x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    if not _nullable_const(A):
+        x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    else:
+        x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
     return x, info


diff --git a/mindspore/python/mindspore/scipy/utils.py b/mindspore/python/mindspore/scipy/utils.py
index a329cedc519..3e656768bdf 100644
--- a/mindspore/python/mindspore/scipy/utils.py
+++ b/mindspore/python/mindspore/scipy/utils.py
@@ -17,7 +17,7 @@
 from .. import ops
 from .. import numpy as mnp
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
-from ..common import Tensor
+from ..common import Tensor, CSRTensor
 from ..common import dtype as mstype
 from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
 from ..ops.composite import GradOperation
@@ -85,19 +85,22 @@
     return normalized_x, norm


+def sparse_dot(a, b):
+    """Return the dot product of a CSRTensor and a dense Tensor (vector)."""
+    b_aligned = F.reshape(b, (b.shape[0], -1))
+    res = F.csr_mv(a, b_aligned)
+    res = F.reshape(res, a.shape[:-1] + b.shape[1:])
+    return res
+
+
 def _normalize_matvec(f):
     """Normalize an argument for computing matrix-vector products."""
-    if _callable_const(F.typeof(f)):
-        return f
     if isinstance(f, Tensor):
         if f.ndim != 2 or f.shape[0] != f.shape[1]:
             _raise_value_error(
                 'linear operator must be a square matrix, but has shape: ', f.shape, ".")
         return F.partial(dot, f)
-    _raise_value_error(
-        'linear operator must be either a function or Tensor: but got ', F.typeof(f), ".")
+    if isinstance(f, CSRTensor):
+        return F.partial(sparse_dot, f)
+    return f
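The new sparse_dot helper is what _normalize_matvec now binds for CSRTensor operands: csr_mv expects a 2-D right-hand side, so a 1-D vector is promoted to shape (N, 1), multiplied, and reshaped back. A rough scipy.sparse equivalent of the computation, for illustration only (sparse_dot_equiv is a made-up name, and this assumes F.csr_mv behaves like an ordinary CSR matrix-vector product):

    import numpy as np
    import scipy.sparse as sps

    def sparse_dot_equiv(a_csr, b):
        # Mirror F.reshape(b, (b.shape[0], -1)): promote (N,) to (N, 1).
        b_aligned = b.reshape(b.shape[0], -1)
        # CSR matrix-vector product; this is the role played by F.csr_mv.
        res = a_csr.dot(b_aligned)
        # Mirror the final F.reshape: a.shape[:-1] + b.shape[1:] is (N,) here.
        return res.reshape(a_csr.shape[:-1] + b.shape[1:])

    identity = sps.eye(3, format='csr')
    assert np.allclose(sparse_dot_equiv(identity, np.ones(3)), np.ones(3))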
@@ -119,11 +122,11 @@ def _nd_transpose(a):
 def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)


 def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)


 def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
@@ -132,24 +135,25 @@ def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):

 def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
-    return _super_check((arg.dtype, arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
+    return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)


 def _square_check(func_name, arg, arg_name='a'):
     arg_shape = arg.shape
     _super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
     _super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
-    return func_name
+    return arg


 def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
-    arg1_shape, arg1_dtype = arg1.shape, arg1.dtype
-    arg2_shape, arg2_dtype = arg2.shape, arg2.dtype
+    arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
+    arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
     _square_check(func_name, arg1, arg1_name)
     _super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
     _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
     _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
-    return func_name
+    return arg1, arg2
+

 def _sparse_check(func_name, a, m, b, x0):
     """Used for cg, bicgstab and gmres method."""
@@ -163,14 +167,28 @@ def _sparse_check(func_name, a, m, b, x0):
     if b.ndim not in (1, 2) or (b.ndim == 2 and b.shape[1] != 1):
         _raise_value_error(
             "For: '", func_name, "', the shape of b should be like (N,) or (N, 1), but got ", b.shape, ".")
-    _super_check((b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64]),
-                 (func_name, 'b', "data type"), "in", "attr", None, False)
+    _dtype_check(func_name, b, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b')
     _super_check((b.dtype, x0.dtype), (func_name, 'b', 'x0', 'data type'), '==', 'match', None, True)
     _super_check((b.shape, x0.shape), (func_name, 'b', 'x0', 'shape'), '==', 'match', None, True)

-    if not _callable_const(F.typeof(a)):
-        _solve_check(func_name, a, b, 'A', 'b', True)
+    def _check(arg, arg_name):
+        if _callable_const(F.typeof(arg)):
+            return arg

-    if not _callable_const(F.typeof(m)):
-        _solve_check(func_name, m, b, 'M', 'b', True)
-    return func_name
+        _solve_check(func_name, arg, b, arg_name, 'b', True)
+        if isinstance(arg, CSRTensor):
+            _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.values, [mstype.float32], arg_name)
+        else:
+            _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
+            if F.dtype(arg) in (mstype.int32, mstype.int64):
+                arg = F.cast(arg, mstype.float64)
+        return arg
+
+    a = _check(a, 'A')
+    m = _check(m, 'M')
+    if F.dtype(b) in (mstype.int32, mstype.int64):
+        b = F.cast(b, mstype.float64)
+        x0 = F.cast(x0, mstype.float64)
+    return a, m, b, x0
diff --git a/mindspore/python/mindspore/scipy/utils_const.py b/mindspore/python/mindspore/scipy/utils_const.py
index 618378e6198..8a8f05d268b 100644
--- a/mindspore/python/mindspore/scipy/utils_const.py
+++ b/mindspore/python/mindspore/scipy/utils_const.py
@@ -26,6 +26,15 @@
     return isinstance(x, mstype.function_type)


+@constexpr
+def _nullable_const(x):
+    """
+    Return True if x is None, which is used to check whether the call happens
+    inside a MindSpore graph: in graph mode, x is None inside a constexpr
+    whenever x is a MindSpore graph variable rather than a Python constant.
+    """
+    return x is None
+
+
 @constexpr
 def _type_convert(new_type, obj):
     """
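The constexpr probe above is what lets cg choose between CG and CGv2: constexpr functions are evaluated at compile time, and in graph mode a non-constant argument reaches them as None. So _nullable_const(A) is False when cg is called directly with concrete tensors (the CG path, operators captured as cell attributes) and True when cg is traced inside another Cell (the CGv2 path, operators kept as graph inputs). A sketch of the two call sites (Solver is an invented name; the pattern mirrors the TestNet cells in the new tests below):

    import mindspore.nn as nn
    import mindspore.scipy as msp

    class Solver(nn.Cell):
        def construct(self, a, b):
            # Traced in GRAPH_MODE: `a` is a graph variable here, so inside cg
            # the constexpr _nullable_const(a) receives None and the CGv2 path
            # is taken, with `a` staying a graph input.
            return msp.sparse.linalg.cg(a, b)

    # Called at the top level with concrete Tensor/CSRTensor inputs,
    # _nullable_const(A) sees the actual value and returns False, so the
    # original CG(A, M) cell is built with A and M baked in as attributes:
    # x, info = msp.sparse.linalg.cg(a, b)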
+ """ + return x is None + + @constexpr def _type_convert(new_type, obj): """ diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py index 397c75fe054..15900e5f1e4 100644 --- a/tests/st/scipy_st/sparse/test_linalg.py +++ b/tests/st/scipy_st/sparse/test_linalg.py @@ -18,10 +18,11 @@ import numpy as onp import scipy as osp import scipy.sparse.linalg +import mindspore.nn as nn import mindspore.scipy as msp from mindspore import context -from mindspore.common import Tensor -from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix +from mindspore.common import Tensor, CSRTensor +from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix def _fetch_preconditioner(preconditioner, A): @@ -46,38 +47,39 @@ def _fetch_preconditioner(preconditioner, A): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) +@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) @pytest.mark.parametrize('shape', [(4, 4), (7, 7)]) @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random']) @pytest.mark.parametrize('maxiter', [1, 3]) -def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter): +def test_cg_against_scipy(dtype, tol, shape, preconditioner, maxiter): """ Feature: ALL TO ALL Description: test cases for cg Expectation: the result match scipy """ onp.random.seed(0) - dtype, tol = dtype_tol - A = create_sym_pos_matrix(shape, dtype) + a = create_sym_pos_matrix(shape, dtype) b = onp.random.random(shape[:1]).astype(dtype) - M = _fetch_preconditioner(preconditioner, A) - osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + m = _fetch_preconditioner(preconditioner, a) + osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) - A = Tensor(A) + a = Tensor(a) b = Tensor(b) - M = Tensor(M) if M is not None else M + m = Tensor(m) if m is not None else m # using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) - msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) # using GRAPH MODE context.set_context(mode=context.GRAPH_MODE) - msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) kw = {"atol": tol, "rtol": tol} - onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw) - onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw) + onp.testing.assert_allclose(osp_res[0], msp_res_dyn[0].asnumpy(), **kw) + onp.testing.assert_allclose(osp_res[0], msp_res_sta[0].asnumpy(), **kw) + assert osp_res[1] == msp_res_dyn[1].asnumpy().item() + assert osp_res[1] == msp_res_sta[1].asnumpy().item() @pytest.mark.level0 @@ -93,23 +95,97 @@ def test_cg_against_numpy(dtype, shape): Expectation: the result match numpy """ onp.random.seed(0) - A = create_sym_pos_matrix(shape, dtype) + a = create_sym_pos_matrix(shape, dtype) b = onp.random.random(shape[:1]).astype(dtype) - expected = onp.linalg.solve(A, b) + expected = onp.linalg.solve(a, b) # using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) - actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b)) + actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) # 
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
+@pytest.mark.parametrize('shape', [(7, 7)])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [3])
+def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for cg within a Cell object
+    Expectation: the result matches scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class TestNet(nn.Cell):
+        def construct(self, a, b, m, maxiter, tol):
+            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    onp.random.seed(0)
+    a = create_sym_pos_matrix(shape, dtype)
+    b = onp.random.random(shape[:1]).astype(dtype)
+    m = _fetch_preconditioner(preconditioner, a)
+    osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    a = Tensor(a)
+    b = Tensor(b)
+    m = Tensor(m) if m is not None else m
+    msp_res = TestNet()(a, b, m, maxiter, tol)
+
+    kw = {"atol": tol, "rtol": tol}
+    onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
+    assert osp_res[1] == msp_res[1].asnumpy().item()
+ """ + context.set_context(mode=context.GRAPH_MODE) + + class TestNet(nn.Cell): + def construct(self, a, b, m, maxiter, tol): + return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) + + onp.random.seed(0) + + # scipy + a = create_sym_pos_sparse_matrix(shape, dtype) + b = onp.random.random(shape[:1]).astype(dtype) + m = _fetch_preconditioner(preconditioner, a) + osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) + + # mindspore + a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape) + b = Tensor(b) + m = Tensor(m) if m is not None else m + msp_res = TestNet()(a, b, m, maxiter, tol) + + kw = {"atol": tol, "rtol": tol} + onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw) + assert osp_res[1] == msp_res[1].asnumpy().item() + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_gpu_training diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py index 73e2f935cc5..ceff2452b59 100644 --- a/tests/st/scipy_st/utils.py +++ b/tests/st/scipy_st/utils.py @@ -17,6 +17,7 @@ from typing import List from functools import cmp_to_key import numpy as onp +import scipy.sparse.linalg from mindspore import Tensor import mindspore.ops as ops import mindspore.numpy as mnp @@ -98,6 +99,18 @@ def create_sym_pos_matrix(shape, dtype): return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype) +def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32): + if len(shape) != 2 or shape[0] != shape[1]: + raise ValueError( + 'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape) + + n = shape[-1] + indptr = onp.arange(n + 1).astype(indice_dtype) + indices = onp.arange(n).astype(indice_dtype) + values = onp.random.random(n).astype(dtype) + return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape) + + def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate): # some utils def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]: