Add CSRTensor support and test cases for cg method.

hezhenhao1 2022-03-02 11:47:31 +08:00
parent c1036cbfaf
commit 885210b634
5 changed files with 205 additions and 47 deletions


@@ -21,7 +21,7 @@ from ..linalg import solve_triangular
from ..linalg import cho_factor, cho_solve
from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
_sparse_check
from ..utils_const import _raise_value_error, _raise_type_error
from ..utils_const import _raise_value_error, _raise_type_error, _nullable_const
def gram_schmidt(Q, q):
@@ -323,6 +323,45 @@ class CG(nn.Cell):
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
class CGv2(nn.Cell):
"""
A variant of CG that receives all operands, including `A` and `M`, as graph inputs instead of binding them at construction time.
"""
def __init__(self):
super(CGv2, self).__init__()
def construct(self, A, M, b, x0, tol, atol, maxiter):
# Constant tensor, to avoid loop unrolling in graph mode
_INT_ZERO = _to_tensor(0)
A = _normalize_matvec(A)
M = _normalize_matvec(M)
atol_ = mnp.maximum(atol, tol * _norm(b))
r = b - A(x0)
z = p = M(r)
rho = mnp.dot(r, z)
k = _INT_ZERO
x = x0
while k < maxiter and _norm(r) > atol_:
q = A(p)
alpha = rho / mnp.dot(p, q)
x = x + alpha * p
r = r - alpha * q
z = M(r)
rho_ = mnp.dot(r, z)
beta = rho_ / rho
p = z + beta * p
rho = rho_
k += 1
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
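For reference, the loop above is the standard preconditioned conjugate-gradient recurrence. A minimal NumPy sketch of the same iteration, including the status flag returned at the end (illustrative only, not part of this commit):

import numpy as onp

def cg_reference(A, b, x0, M, tol, atol, maxiter):
    # Mirrors CGv2.construct with dense NumPy operands.
    atol_ = max(atol, tol * onp.linalg.norm(b))
    x, r = x0, b - A @ x0
    z = p = M @ r
    rho = r @ z
    k = 0
    while k < maxiter and onp.linalg.norm(r) > atol_:
        q = A @ p
        alpha = rho / (p @ q)
        x = x + alpha * p
        r = r - alpha * q
        z = M @ r
        rho_new = r @ z
        p = z + (rho_new / rho) * p
        rho = rho_new
        k += 1
    # Like F.select(_norm(r) > atol_, k, _INT_ZERO): 0 signals convergence.
    return x, k if onp.linalg.norm(r) > atol_ else 0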
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
"""Use Conjugate Gradient iteration to solve the linear system:
@@ -343,7 +382,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
- `cg` is not supported on the Windows platform yet.
Args:
A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
A (Union[Tensor, CSRTensor, function]): 2D Tensor, CSRTensor or function that calculates the linear
map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`.
As a function, `A` must return a Tensor with the same structure and shape as its input.
b (Tensor): Right hand side of the linear system representing a single vector. Can be
@@ -372,8 +411,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
TypeError: If `atol` is not float.
TypeError: If `maxiter` is not int.
ValueError: If `callback` is not None.
TypeError: If `A` is not Tensor or Function.
TypeError: If `M` is not None, Tensor or Function.
TypeError: If `A` is not Tensor, CSRTensor, or Function.
TypeError: If `M` is not None, Tensor, CSRTensor, or Function.
TypeError: If `b` is not Tensor.
TypeError: If `x0` is not None or Tensor.
ValueError: If `b` is not 1- or 2-dimensional.
@@ -411,9 +450,12 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
_type_check(func_name, atol, float, 'atol')
_type_check(func_name, maxiter, int, 'maxiter')
_value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
_sparse_check(func_name, A, M, b, x0)
A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
if not _nullable_const(A):
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
else:
x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
return x, info
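With this dispatch, `cg` keeps the original `CG` path for operands known at Cell-construction time and falls back to `CGv2` when `A` is a graph variable. A hedged usage sketch with a CSRTensor operand, mirroring the new test further below (assumes a MindSpore build that provides `CSRTensor` and `mindspore.scipy`):

import numpy as onp
import mindspore.scipy as msp
from mindspore import context, Tensor, CSRTensor

context.set_context(mode=context.GRAPH_MODE)
n = 7
# Diagonal SPD system in CSR format (int32 indices and float32 values,
# as required by _sparse_check).
indptr = Tensor(onp.arange(n + 1, dtype=onp.int32))
indices = Tensor(onp.arange(n, dtype=onp.int32))
values = Tensor(onp.random.random(n).astype(onp.float32) + 1.0)
a = CSRTensor(indptr, indices, values, (n, n))
b = Tensor(onp.random.random(n).astype(onp.float32))
x, info = msp.sparse.linalg.cg(a, b, tol=1e-5, atol=0.0)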


@@ -17,7 +17,7 @@ from .. import ops
from .. import numpy as mnp
from ..numpy import where, zeros_like, dot, greater
from ..ops import functional as F
from ..common import Tensor
from ..common import Tensor, CSRTensor
from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
from ..ops.composite import GradOperation
@@ -85,19 +85,22 @@ def _safe_normalize(x, threshold=None):
return normalized_x, norm
def sparse_dot(a, b):
"""Returns the dot product of CSRTensor and generic Tensor(vector)."""
b_aligned = F.reshape(b, (b.shape[0], -1))
res = F.csr_mv(a, b_aligned)
res = F.reshape(res, a.shape[:-1] + b.shape[1:])
return res
def _normalize_matvec(f):
"""Normalize an argument for computing matrix-vector products."""
if _callable_const(F.typeof(f)):
return f
if isinstance(f, Tensor):
if f.ndim != 2 or f.shape[0] != f.shape[1]:
_raise_value_error(
'linear operator must be a square matrix, but has shape: ', f.shape, ".")
return F.partial(dot, f)
_raise_value_error(
'linear operator must be either a function or Tensor: but got ', F.typeof(f), ".")
if isinstance(f, CSRTensor):
return F.partial(sparse_dot, f)
return f
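`sparse_dot` reshapes the right-hand side to a column matrix because `csr_mv` operates on 2-D input, then restores the trailing dimensions of `b`. A rough SciPy analogue of that shape handling (illustrative; `sparse_dot_reference` is not part of this commit):

import numpy as onp
import scipy.sparse as osp_sparse

def sparse_dot_reference(a, b):
    # a: (m, n) CSR matrix; b: array of shape (n,) or (n, k).
    b_aligned = b.reshape(b.shape[0], -1)       # align b as columns for the mat-vec
    res = a @ b_aligned                         # CSR matrix-vector product
    return onp.asarray(res).reshape(a.shape[:-1] + b.shape[1:])

a = osp_sparse.eye(3, format='csr')
assert sparse_dot_reference(a, onp.ones(3)).shape == (3,)
assert sparse_dot_reference(a, onp.ones((3, 1))).shape == (3, 1)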
@@ -119,11 +122,11 @@ def _nd_transpose(a):
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
@@ -132,24 +135,25 @@ def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
return _super_check((arg.dtype, arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
def _square_check(func_name, arg, arg_name='a'):
arg_shape = arg.shape
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
return func_name
return arg
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
arg1_shape, arg1_dtype = arg1.shape, arg1.dtype
arg2_shape, arg2_dtype = arg2.shape, arg2.dtype
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
return func_name
return arg1, arg2
def _sparse_check(func_name, a, m, b, x0):
"""Used for cg, bicgstab and gmres method."""
@@ -163,14 +167,28 @@ def _sparse_check(func_name, a, m, b, x0):
if b.ndim not in (1, 2) or (b.ndim == 2 and b.shape[1] != 1):
_raise_value_error(
"For: '", func_name, "', the shape of b should be like (N,) or (N, 1), but got ", b.shape, ".")
_super_check((b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64]),
(func_name, 'b', "data type"), "in", "attr", None, False)
_dtype_check(func_name, b, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b')
_super_check((b.dtype, x0.dtype), (func_name, 'b', 'x0', 'data type'), '==', 'match', None, True)
_super_check((b.shape, x0.shape), (func_name, 'b', 'x0', 'shape'), '==', 'match', None, True)
if not _callable_const(F.typeof(a)):
_solve_check(func_name, a, b, 'A', 'b', True)
def _check(arg, arg_name):
if _callable_const(F.typeof(arg)):
return arg
if not _callable_const(F.typeof(m)):
_solve_check(func_name, m, b, 'M', 'b', True)
return func_name
_solve_check(func_name, arg, b, arg_name, 'b', True)
if isinstance(arg, CSRTensor):
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
_dtype_check(func_name, arg.values, [mstype.float32], arg_name)
else:
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
if F.dtype(arg) in (mstype.int32, mstype.int64):
arg = F.cast(arg, mstype.float64)
return arg
a = _check(a, 'A')
m = _check(m, 'M')
if F.dtype(b) in (mstype.int32, mstype.int64):
b = F.cast(b, mstype.float64)
x0 = F.cast(x0, mstype.float64)
return a, m, b, x0
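Note the promotion at the end: dense operands with integer dtypes are cast to float64 so the iteration runs in floating point, while CSRTensor operands must already carry int32 indices and float32 values. A hedged illustration of the dense promotion (assumes the `cg` entry point shown earlier):

import numpy as onp
from mindspore import Tensor
import mindspore.scipy as msp

a = Tensor((2 * onp.eye(4)).astype(onp.int32))  # int32 SPD matrix
b = Tensor(onp.ones(4, dtype=onp.int32))
x, _ = msp.sparse.linalg.cg(a, b)               # solved in float64 after the cast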


@@ -26,6 +26,15 @@ def _callable_const(x):
return isinstance(x, mstype.function_type)
@constexpr
def _nullable_const(x):
"""
Returns True if x is None. Used to check whether the call happens inside a
MindSpore graph: in graph mode, a MindSpore variable reaches a constexpr
function as None.
"""
return x is None
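`_nullable_const` relies on a graph-mode property: `@constexpr` functions are evaluated at compile time, and a traced (non-constant) argument reaches them as None. A hedged sketch of the dispatch pattern this enables, mirroring the `cg` change in the first file (names here are illustrative, not MindSpore API):

from mindspore.ops import constexpr

@constexpr
def _is_graph_variable(x):
    # At compile time a traced MindSpore variable arrives here as None.
    return x is None

def dispatch(a, m, b, x0, tol, atol, maxiter):
    if not _is_graph_variable(a):
        # a is a Python-level constant: bind it at Cell construction.
        return CG(a, m)(b, x0, tol, atol, maxiter)
    # a is a graph variable (e.g. inside another Cell): pass it through construct.
    return CGv2()(a, m, b, x0, tol, atol, maxiter)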
@constexpr
def _type_convert(new_type, obj):
"""


@@ -18,10 +18,11 @@ import numpy as onp
import scipy as osp
import scipy.sparse.linalg
import mindspore.nn as nn
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix
from mindspore.common import Tensor, CSRTensor
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix
def _fetch_preconditioner(preconditioner, A):
@@ -46,38 +47,39 @@ def _fetch_preconditioner(preconditioner, A):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [1, 3])
def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
def test_cg_against_scipy(dtype, tol, shape, preconditioner, maxiter):
"""
Feature: ALL TO ALL
Description: test cases for cg
Expectation: the result matches scipy
"""
onp.random.seed(0)
dtype, tol = dtype_tol
A = create_sym_pos_matrix(shape, dtype)
a = create_sym_pos_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
M = _fetch_preconditioner(preconditioner, A)
osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
m = _fetch_preconditioner(preconditioner, a)
osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
A = Tensor(A)
a = Tensor(a)
b = Tensor(b)
M = Tensor(M) if M is not None else M
m = Tensor(m) if m is not None else m
# using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
# using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)
onp.testing.assert_allclose(osp_res[0], msp_res_dyn[0].asnumpy(), **kw)
onp.testing.assert_allclose(osp_res[0], msp_res_sta[0].asnumpy(), **kw)
assert osp_res[1] == msp_res_dyn[1].asnumpy().item()
assert osp_res[1] == msp_res_sta[1].asnumpy().item()
@pytest.mark.level0
@@ -93,23 +95,97 @@ def test_cg_against_numpy(dtype, shape):
Expectation: the result matches numpy
"""
onp.random.seed(0)
A = create_sym_pos_matrix(shape, dtype)
a = create_sym_pos_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
expected = onp.linalg.solve(A, b)
expected = onp.linalg.solve(a, b)
# using PYNATIVE MODE
context.set_context(mode=context.PYNATIVE_MODE)
actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
# using GRAPH MODE
context.set_context(mode=context.GRAPH_MODE)
actual_sta, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))
kw = {"atol": 1e-5, "rtol": 1e-5}
onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('shape', [(7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [3])
def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
"""
Feature: ALL TO ALL
Description: test cases for cg within a Cell object
Expectation: the result matches scipy
"""
context.set_context(mode=context.GRAPH_MODE)
class TestNet(nn.Cell):
def construct(self, a, b, m, maxiter, tol):
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
onp.random.seed(0)
a = create_sym_pos_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
m = _fetch_preconditioner(preconditioner, a)
osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
a = Tensor(a)
b = Tensor(b)
m = Tensor(m) if m is not None else m
msp_res = TestNet()(a, b, m, maxiter, tol)
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
assert osp_res[1] == msp_res[1].asnumpy().item()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5)])
@pytest.mark.parametrize('shape', [(7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'random'])
@pytest.mark.parametrize('maxiter', [3])
def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
"""
Feature: ALL TO ALL
Description: test cases for cg with CSRTensor inputs
Expectation: the result matches scipy.
"""
context.set_context(mode=context.GRAPH_MODE)
class TestNet(nn.Cell):
def construct(self, a, b, m, maxiter, tol):
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
onp.random.seed(0)
# scipy
a = create_sym_pos_sparse_matrix(shape, dtype)
b = onp.random.random(shape[:1]).astype(dtype)
m = _fetch_preconditioner(preconditioner, a)
osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
# mindspore
a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape)
b = Tensor(b)
m = Tensor(m) if m is not None else m
msp_res = TestNet()(a, b, m, maxiter, tol)
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
assert osp_res[1] == msp_res[1].asnumpy().item()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training


@@ -17,6 +17,7 @@ from typing import List
from functools import cmp_to_key
import numpy as onp
import scipy.sparse.linalg
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.numpy as mnp
@@ -98,6 +99,18 @@ def create_sym_pos_matrix(shape, dtype):
return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)
def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32):
if len(shape) != 2 or shape[0] != shape[1]:
raise ValueError(
'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)
n = shape[-1]
indptr = onp.arange(n + 1).astype(indice_dtype)
indices = onp.arange(n).astype(indice_dtype)
values = onp.random.random(n).astype(dtype)
return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape)
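The helper builds a purely diagonal CSR matrix with random entries drawn from [0, 1), which is symmetric and (almost surely) positive definite. A quick illustrative check, not part of the test suite:

import numpy as onp

m = create_sym_pos_sparse_matrix((4, 4), onp.float32)
dense = m.toarray()
assert onp.array_equal(dense, onp.diag(onp.diag(dense)))  # only the diagonal is populated
assert (onp.diag(dense) >= 0).all()                       # entries drawn from [0, 1)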
def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
# some utils
def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]: