From 885210b6344ea7319545b17e5c7cadf6644e2d03 Mon Sep 17 00:00:00 2001
From: hezhenhao1 <hezhenhao1@huawei.com>
Date: Wed, 2 Mar 2022 11:47:31 +0800
Subject: [PATCH] Add CSRTensor support and test cases for the cg method.

---
 .../python/mindspore/scipy/sparse/linalg.py |  54 ++++++++-
 mindspore/python/mindspore/scipy/utils.py   |  64 ++++++----
 .../python/mindspore/scipy/utils_const.py   |   9 ++
 tests/st/scipy_st/sparse/test_linalg.py     | 112 +++++++++++++++---
 tests/st/scipy_st/utils.py                  |  13 ++
 5 files changed, 205 insertions(+), 47 deletions(-)

diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py
index 91a103c95cb..a3ea62b5b79 100644
--- a/mindspore/python/mindspore/scipy/sparse/linalg.py
+++ b/mindspore/python/mindspore/scipy/sparse/linalg.py
@@ -21,7 +21,7 @@
 from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
 from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
     _sparse_check
-from ..utils_const import _raise_value_error, _raise_type_error
+from ..utils_const import _raise_value_error, _raise_type_error, _nullable_const


 def gram_schmidt(Q, q):
@@ -323,6 +323,45 @@
         return x, F.select(_norm(r) > atol_, k, _INT_ZERO)


+class CGv2(nn.Cell):
+    """
+    A new version of CG which takes all of its parameters, including the operators `A` and `M`, as graph inputs instead of cell attributes.
+    """
+
+    def __init__(self):
+        super(CGv2, self).__init__()
+
+    def construct(self, A, M, b, x0, tol, atol, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
+        A = _normalize_matvec(A)
+        M = _normalize_matvec(M)
+
+        atol_ = mnp.maximum(atol, tol * _norm(b))
+
+        r = b - A(x0)
+        z = p = M(r)
+        rho = mnp.dot(r, z)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and _norm(r) > atol_:
+            q = A(p)
+            alpha = rho / mnp.dot(p, q)
+            x = x + alpha * p
+            r = r - alpha * q
+
+            z = M(r)
+            rho_ = mnp.dot(r, z)
+            beta = rho_ / rho
+            p = z + beta * p
+            rho = rho_
+
+            k += 1
+
+        return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
+
+
 def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
     """Use Conjugate Gradient iteration to solve the linear system:
@@ -343,7 +382,7 @@
     - `cg` is not supported on Windows platform yet.

     Args:
-        A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
+        A (Union[Tensor, CSRTensor, function]): 2D Tensor, CSRTensor or function that calculates the linear
             map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`. As a function, `A` must return
             a Tensor with the same structure and shape as its input matrix.
         b (Tensor): Right hand side of the linear system representing a single vector. Can be
@@ -372,8 +411,8 @@
         TypeError: If `atol` is not float.
        TypeError: If `maxiter` is not int.
        ValueError: If `callback` is not None.
-        TypeError: If `A` is not Tensor or Function.
-        TypeError: If `M` is not None, Tensor or Function.
+        TypeError: If `A` is not Tensor, CSRTensor, or Function.
+        TypeError: If `M` is not None, Tensor, CSRTensor, or Function.
        TypeError: If `b` is not Tensor.
        TypeError: If `x0` is not None or Tensor.
        ValueError: If `b` is not 1 or 2 dimensional.
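Note for reviewers: the loop in the new CGv2 cell is the textbook preconditioned conjugate-gradient recurrence, with the same stopping rule and the same info convention as the existing CG cell (info is 0 on convergence, otherwise the iteration count). A minimal dense NumPy sketch of the identical iteration, for checking the algebra only (pcg_reference and m_inv are names invented here, not part of this patch):

    import numpy as np

    def pcg_reference(a, b, x0, m_inv, tol=1e-5, atol=0.0, maxiter=1000):
        # Stopping threshold, as in CG/CGv2: max(atol, tol * ||b||).
        atol_ = max(atol, tol * np.linalg.norm(b))
        x = x0
        r = b - a @ x0              # initial residual
        z = p = m_inv @ r           # preconditioned residual / first search direction
        rho = r @ z
        k = 0
        while k < maxiter and np.linalg.norm(r) > atol_:
            q = a @ p
            alpha = rho / (p @ q)   # step size along the search direction
            x = x + alpha * p
            r = r - alpha * q
            z = m_inv @ r
            rho_new = r @ z
            p = z + (rho_new / rho) * p   # conjugate direction update
            rho = rho_new
            k += 1
        # info convention: 0 on convergence, else the iteration count reached.
        return x, (k if np.linalg.norm(r) > atol_ else 0)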
@@ -411,9 +450,12 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
     _type_check(func_name, atol, float, 'atol')
     _type_check(func_name, maxiter, int, 'maxiter')
     _value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
-    _sparse_check(func_name, A, M, b, x0)
+    A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)

-    x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    if not _nullable_const(A):
+        x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    else:
+        x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
     return x, info


diff --git a/mindspore/python/mindspore/scipy/utils.py b/mindspore/python/mindspore/scipy/utils.py
index a329cedc519..3e656768bdf 100644
--- a/mindspore/python/mindspore/scipy/utils.py
+++ b/mindspore/python/mindspore/scipy/utils.py
@@ -17,7 +17,7 @@
 from .. import ops
 from .. import numpy as mnp
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
-from ..common import Tensor
+from ..common import Tensor, CSRTensor
 from ..common import dtype as mstype
 from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
 from ..ops.composite import GradOperation
@@ -85,19 +85,22 @@
     return normalized_x, norm


+def sparse_dot(a, b):
+    """Return the dot product of a CSRTensor and a dense Tensor (vector)."""
+    b_aligned = F.reshape(b, (b.shape[0], -1))
+    res = F.csr_mv(a, b_aligned)
+    res = F.reshape(res, a.shape[:-1] + b.shape[1:])
+    return res
+
+
 def _normalize_matvec(f):
     """Normalize an argument for computing matrix-vector products."""
-    if _callable_const(F.typeof(f)):
-        return f
     if isinstance(f, Tensor):
         if f.ndim != 2 or f.shape[0] != f.shape[1]:
             _raise_value_error(
                 'linear operator must be a square matrix, but has shape: ', f.shape, ".")
         return F.partial(dot, f)
-    _raise_value_error(
-        'linear operator must be either a function or Tensor: but got ', F.typeof(f), ".")
+    if isinstance(f, CSRTensor):
+        return F.partial(sparse_dot, f)
+    return f
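The new sparse_dot helper is what _normalize_matvec now binds for CSRTensor operands: csr_mv expects a 2-D right-hand side, so a 1-D vector is promoted to shape (N, 1), multiplied, and reshaped back. A rough scipy.sparse equivalent of the computation, for illustration only (sparse_dot_equiv is a made-up name, and this assumes F.csr_mv behaves like an ordinary CSR matrix-vector product):

    import numpy as np
    import scipy.sparse as sps

    def sparse_dot_equiv(a_csr, b):
        # Mirror F.reshape(b, (b.shape[0], -1)): promote (N,) to (N, 1).
        b_aligned = b.reshape(b.shape[0], -1)
        # CSR matrix-vector product; this is the role played by F.csr_mv.
        res = a_csr.dot(b_aligned)
        # Mirror the final F.reshape: a.shape[:-1] + b.shape[1:] is (N,) here.
        return res.reshape(a_csr.shape[:-1] + b.shape[1:])

    identity = sps.eye(3, format='csr')
    assert np.allclose(sparse_dot_equiv(identity, np.ones(3)), np.ones(3))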
@@ -119,11 +122,11 @@ def _nd_transpose(a):
 def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)


 def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)


 def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
@@ -132,24 +135,25 @@ def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):

 def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
-    return _super_check((arg.dtype, arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
+    return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)


 def _square_check(func_name, arg, arg_name='a'):
     arg_shape = arg.shape
     _super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
     _super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
-    return func_name
+    return arg


 def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
-    arg1_shape, arg1_dtype = arg1.shape, arg1.dtype
-    arg2_shape, arg2_dtype = arg2.shape, arg2.dtype
+    arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
+    arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
     _square_check(func_name, arg1, arg1_name)
     _super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
     _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
     _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
-    return func_name
+    return arg1, arg2
+

 def _sparse_check(func_name, a, m, b, x0):
     """Used for cg, bicgstab and gmres method."""
@@ -163,14 +167,28 @@ def _sparse_check(func_name, a, m, b, x0):
     if b.ndim not in (1, 2) or (b.ndim == 2 and b.shape[1] != 1):
         _raise_value_error(
             "For: '", func_name, "', the shape of b should be like (N,) or (N, 1), but got ", b.shape, ".")
-    _super_check((b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64]),
-                 (func_name, 'b', "data type"), "in", "attr", None, False)
+    _dtype_check(func_name, b, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b')
     _super_check((b.dtype, x0.dtype), (func_name, 'b', 'x0', 'data type'), '==', 'match', None, True)
     _super_check((b.shape, x0.shape), (func_name, 'b', 'x0', 'shape'), '==', 'match', None, True)

-    if not _callable_const(F.typeof(a)):
-        _solve_check(func_name, a, b, 'A', 'b', True)
+    def _check(arg, arg_name):
+        if _callable_const(F.typeof(arg)):
+            return arg

-    if not _callable_const(F.typeof(m)):
-        _solve_check(func_name, m, b, 'M', 'b', True)
-    return func_name
+        _solve_check(func_name, arg, b, arg_name, 'b', True)
+        if isinstance(arg, CSRTensor):
+            _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.values, [mstype.float32], arg_name)
+        else:
+            _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
+            if F.dtype(arg) in (mstype.int32, mstype.int64):
+                arg = F.cast(arg, mstype.float64)
+        return arg
+
+    a = _check(a, 'A')
+    m = _check(m, 'M')
+    if F.dtype(b) in (mstype.int32, mstype.int64):
+        b = F.cast(b, mstype.float64)
+        x0 = F.cast(x0, mstype.float64)
+    return a, m, b, x0
diff --git a/mindspore/python/mindspore/scipy/utils_const.py b/mindspore/python/mindspore/scipy/utils_const.py
index 618378e6198..8a8f05d268b 100644
--- a/mindspore/python/mindspore/scipy/utils_const.py
+++ b/mindspore/python/mindspore/scipy/utils_const.py
@@ -26,6 +26,15 @@
     return isinstance(x, mstype.function_type)


+@constexpr
+def _nullable_const(x):
+    """
+    Return True if x is None, which is used to check whether the call happens
+    inside a MindSpore graph: in graph mode, x is None inside a constexpr
+    whenever x is a MindSpore graph variable rather than a Python constant.
+    """
+    return x is None
+
+
 @constexpr
 def _type_convert(new_type, obj):
     """
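The constexpr probe above is what lets cg choose between CG and CGv2: constexpr functions are evaluated at compile time, and in graph mode a non-constant argument reaches them as None. So _nullable_const(A) is False when cg is called directly with concrete tensors (the CG path, operators captured as cell attributes) and True when cg is traced inside another Cell (the CGv2 path, operators kept as graph inputs). A sketch of the two call sites (Solver is an invented name; the pattern mirrors the TestNet cells in the new tests below):

    import mindspore.nn as nn
    import mindspore.scipy as msp

    class Solver(nn.Cell):
        def construct(self, a, b):
            # Traced in GRAPH_MODE: `a` is a graph variable here, so inside cg
            # the constexpr _nullable_const(a) receives None and the CGv2 path
            # is taken, with `a` staying a graph input.
            return msp.sparse.linalg.cg(a, b)

    # Called at the top level with concrete Tensor/CSRTensor inputs,
    # _nullable_const(A) sees the actual value and returns False, so the
    # original CG(A, M) cell is built with A and M baked in as attributes:
    # x, info = msp.sparse.linalg.cg(a, b)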
+ """ + return x is None + + @constexpr def _type_convert(new_type, obj): """ diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py index 397c75fe054..15900e5f1e4 100644 --- a/tests/st/scipy_st/sparse/test_linalg.py +++ b/tests/st/scipy_st/sparse/test_linalg.py @@ -18,10 +18,11 @@ import numpy as onp import scipy as osp import scipy.sparse.linalg +import mindspore.nn as nn import mindspore.scipy as msp from mindspore import context -from mindspore.common import Tensor -from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix +from mindspore.common import Tensor, CSRTensor +from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix def _fetch_preconditioner(preconditioner, A): @@ -46,38 +47,39 @@ def _fetch_preconditioner(preconditioner, A): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) +@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) @pytest.mark.parametrize('shape', [(4, 4), (7, 7)]) @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random']) @pytest.mark.parametrize('maxiter', [1, 3]) -def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter): +def test_cg_against_scipy(dtype, tol, shape, preconditioner, maxiter): """ Feature: ALL TO ALL Description: test cases for cg Expectation: the result match scipy """ onp.random.seed(0) - dtype, tol = dtype_tol - A = create_sym_pos_matrix(shape, dtype) + a = create_sym_pos_matrix(shape, dtype) b = onp.random.random(shape[:1]).astype(dtype) - M = _fetch_preconditioner(preconditioner, A) - osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + m = _fetch_preconditioner(preconditioner, a) + osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) - A = Tensor(A) + a = Tensor(a) b = Tensor(b) - M = Tensor(M) if M is not None else M + m = Tensor(m) if m is not None else m # using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) - msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) # using GRAPH MODE context.set_context(mode=context.GRAPH_MODE) - msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] + msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) kw = {"atol": tol, "rtol": tol} - onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw) - onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw) + onp.testing.assert_allclose(osp_res[0], msp_res_dyn[0].asnumpy(), **kw) + onp.testing.assert_allclose(osp_res[0], msp_res_sta[0].asnumpy(), **kw) + assert osp_res[1] == msp_res_dyn[1].asnumpy().item() + assert osp_res[1] == msp_res_sta[1].asnumpy().item() @pytest.mark.level0 @@ -93,23 +95,97 @@ def test_cg_against_numpy(dtype, shape): Expectation: the result match numpy """ onp.random.seed(0) - A = create_sym_pos_matrix(shape, dtype) + a = create_sym_pos_matrix(shape, dtype) b = onp.random.random(shape[:1]).astype(dtype) - expected = onp.linalg.solve(A, b) + expected = onp.linalg.solve(a, b) # using PYNATIVE MODE context.set_context(mode=context.PYNATIVE_MODE) - actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b)) + actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) # 
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
+@pytest.mark.parametrize('shape', [(7, 7)])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [3])
+def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for cg within a Cell object
+    Expectation: the result matches scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class TestNet(nn.Cell):
+        def construct(self, a, b, m, maxiter, tol):
+            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    onp.random.seed(0)
+    a = create_sym_pos_matrix(shape, dtype)
+    b = onp.random.random(shape[:1]).astype(dtype)
+    m = _fetch_preconditioner(preconditioner, a)
+    osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    a = Tensor(a)
+    b = Tensor(b)
+    m = Tensor(m) if m is not None else m
+    msp_res = TestNet()(a, b, m, maxiter, tol)
+
+    kw = {"atol": tol, "rtol": tol}
+    onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
+    assert osp_res[1] == msp_res[1].asnumpy().item()
+ """ + context.set_context(mode=context.GRAPH_MODE) + + class TestNet(nn.Cell): + def construct(self, a, b, m, maxiter, tol): + return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) + + onp.random.seed(0) + + # scipy + a = create_sym_pos_sparse_matrix(shape, dtype) + b = onp.random.random(shape[:1]).astype(dtype) + m = _fetch_preconditioner(preconditioner, a) + osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) + + # mindspore + a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape) + b = Tensor(b) + m = Tensor(m) if m is not None else m + msp_res = TestNet()(a, b, m, maxiter, tol) + + kw = {"atol": tol, "rtol": tol} + onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw) + assert osp_res[1] == msp_res[1].asnumpy().item() + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_gpu_training diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py index 73e2f935cc5..ceff2452b59 100644 --- a/tests/st/scipy_st/utils.py +++ b/tests/st/scipy_st/utils.py @@ -17,6 +17,7 @@ from typing import List from functools import cmp_to_key import numpy as onp +import scipy.sparse.linalg from mindspore import Tensor import mindspore.ops as ops import mindspore.numpy as mnp @@ -98,6 +99,18 @@ def create_sym_pos_matrix(shape, dtype): return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype) +def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32): + if len(shape) != 2 or shape[0] != shape[1]: + raise ValueError( + 'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape) + + n = shape[-1] + indptr = onp.arange(n + 1).astype(indice_dtype) + indices = onp.arange(n).astype(indice_dtype) + values = onp.random.random(n).astype(dtype) + return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape) + + def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate): # some utils def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]: