forked from mindspore-Ecosystem/mindspore
Add CSRTensor support and test cases for cg method.
This commit is contained in:
parent c1036cbfaf
commit 885210b634
@@ -21,7 +21,7 @@ from ..linalg import solve_triangular
 from ..linalg import cho_factor, cho_solve
 from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
     _sparse_check
-from ..utils_const import _raise_value_error, _raise_type_error
+from ..utils_const import _raise_value_error, _raise_type_error, _nullable_const


 def gram_schmidt(Q, q):
@@ -323,6 +323,45 @@ class CG(nn.Cell):
         return x, F.select(_norm(r) > atol_, k, _INT_ZERO)


+class CGv2(nn.Cell):
+    """
+    This is a new version of CG, which contains all parameters in a graph.
+    """
+
+    def __init__(self):
+        super(CGv2, self).__init__()
+
+    def construct(self, A, M, b, x0, tol, atol, maxiter):
+        # Constant tensor which avoids loop unrolling
+        _INT_ZERO = _to_tensor(0)
+
+        A = _normalize_matvec(A)
+        M = _normalize_matvec(M)
+
+        atol_ = mnp.maximum(atol, tol * _norm(b))
+
+        r = b - A(x0)
+        z = p = M(r)
+        rho = mnp.dot(r, z)
+        k = _INT_ZERO
+        x = x0
+        while k < maxiter and _norm(r) > atol_:
+            q = A(p)
+            alpha = rho / mnp.dot(p, q)
+            x = x + alpha * p
+            r = r - alpha * q
+
+            z = M(r)
+            rho_ = mnp.dot(r, z)
+            beta = rho_ / rho
+            p = z + beta * p
+            rho = rho_
+
+            k += 1
+
+        return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
+
+
 def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
     """Use Conjugate Gradient iteration to solve the linear system:
@@ -343,7 +382,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
     - `cg` is not supported on Windows platform yet.

     Args:
-        A (Union[Tensor, function]): 2D Tensor or function that calculates the linear
+        A (Union[Tensor, CSRTensor, function]): 2D Tensor, CSRTensor or function that calculates the linear
             map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`.
             As function, `A` must return Tensor with the same structure and shape as its input matrix.
         b (Tensor): Right hand side of the linear system representing a single vector. Can be
@@ -372,8 +411,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
         TypeError: If `atol` is not float.
         TypeError: If `maxiter` is not int.
         ValueError: If `callback` is not None.
-        TypeError: If `A` is not Tensor or Function.
-        TypeError: If `M` is not None, Tensor or Function.
+        TypeError: If `A` is not Tensor, CSRTensor, or Function.
+        TypeError: If `M` is not None, Tensor, CSRTensor, or Function.
         TypeError: If `b` is not Tensor.
         TypeError: If `x0` is not None or Tensor.
         ValueError: If `b` is not 1 or 2 dimension.
@@ -411,9 +450,12 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
     _type_check(func_name, atol, float, 'atol')
     _type_check(func_name, maxiter, int, 'maxiter')
     _value_check(func_name, callback, None, 'callback', op='is', fmt='todo')
-    _sparse_check(func_name, A, M, b, x0)
+    A, M, b, x0 = _sparse_check(func_name, A, M, b, x0)

-    x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    if not _nullable_const(A):
+        x, info = CG(A, M)(b, x0, tol, atol, maxiter)
+    else:
+        x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
     return x, info
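With this dispatch in place, `cg` accepts a dense Tensor, a CSRTensor, or a callable for `A`. A minimal usage sketch assembled from the test cases further down (the CSRTensor construction mirrors `test_cg_against_scipy_sparse`; that test drives `cg` from inside a Cell, but per the dispatch above a direct call with a concrete `A` takes the `CG` path):

import numpy as onp
import scipy.sparse
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor, CSRTensor

context.set_context(mode=context.GRAPH_MODE)

# Dense operator: any symmetric positive definite matrix works.
a_dense = onp.array([[4.0, 1.0], [1.0, 3.0]], dtype=onp.float32)
b = Tensor(onp.array([1.0, 2.0], dtype=onp.float32))
x_dense, info = msp.sparse.linalg.cg(Tensor(a_dense), b, maxiter=20, tol=1e-5)

# Sparse operator: CSRTensor(indptr, indices, values, shape) with int32 indices
# and float32 values, as _sparse_check requires.
a_csr = scipy.sparse.csr_matrix(a_dense)
a_sparse = CSRTensor(Tensor(a_csr.indptr.astype(onp.int32)),
                     Tensor(a_csr.indices.astype(onp.int32)),
                     Tensor(a_csr.data), a_csr.shape)
x_sparse, info = msp.sparse.linalg.cg(a_sparse, b, maxiter=20, tol=1e-5)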
@@ -17,7 +17,7 @@ from .. import ops
 from .. import numpy as mnp
 from ..numpy import where, zeros_like, dot, greater
 from ..ops import functional as F
-from ..common import Tensor
+from ..common import Tensor, CSRTensor
 from ..common import dtype as mstype
 from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
 from ..ops.composite import GradOperation
@@ -85,19 +85,22 @@ def _safe_normalize(x, threshold=None):
     return normalized_x, norm


+def sparse_dot(a, b):
+    """Returns the dot product of CSRTensor and generic Tensor(vector)."""
+    b_aligned = F.reshape(b, (b.shape[0], -1))
+    res = F.csr_mv(a, b_aligned)
+    res = F.reshape(res, a.shape[:-1] + b.shape[1:])
+    return res
+
+
 def _normalize_matvec(f):
     """Normalize an argument for computing matrix-vector products."""
     if _callable_const(F.typeof(f)):
         return f

     if isinstance(f, Tensor):
         if f.ndim != 2 or f.shape[0] != f.shape[1]:
             _raise_value_error(
                 'linear operator must be a square matrix, but has shape: ', f.shape, ".")
         return F.partial(dot, f)

-    _raise_value_error(
-        'linear operator must be either a function or Tensor: but got ', F.typeof(f), ".")
+    if isinstance(f, CSRTensor):
+        return F.partial(sparse_dot, f)
+
+    return f
@@ -119,11 +122,11 @@ def _nd_transpose(a):


 def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)


 def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
-    return _super_check((arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
+    return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)


 def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
@@ -132,24 +135,25 @@ def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):


 def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
-    return _super_check((arg.dtype, arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
+    return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)


 def _square_check(func_name, arg, arg_name='a'):
     arg_shape = arg.shape
     _super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
     _super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
-    return func_name
+    return arg


 def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
-    arg1_shape, arg1_dtype = arg1.shape, arg1.dtype
-    arg2_shape, arg2_dtype = arg2.shape, arg2.dtype
+    arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
+    arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
     _square_check(func_name, arg1, arg1_name)
     _super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
     _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
     _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
-    return func_name
+    return arg1, arg2


 def _sparse_check(func_name, a, m, b, x0):
     """Used for cg, bicgstab and gmres method."""
@@ -163,14 +167,28 @@ def _sparse_check(func_name, a, m, b, x0):
     if b.ndim not in (1, 2) or (b.ndim == 2 and b.shape[1] != 1):
         _raise_value_error(
             "For: '", func_name, "', the shape of b should be like (N,) or (N, 1), but got ", b.shape, ".")
-    _super_check((b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64]),
-                 (func_name, 'b', "data type"), "in", "attr", None, False)
+    _dtype_check(func_name, b, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b')
     _super_check((b.dtype, x0.dtype), (func_name, 'b', 'x0', 'data type'), '==', 'match', None, True)
     _super_check((b.shape, x0.shape), (func_name, 'b', 'x0', 'shape'), '==', 'match', None, True)

-    if not _callable_const(F.typeof(a)):
-        _solve_check(func_name, a, b, 'A', 'b', True)
-
-    if not _callable_const(F.typeof(m)):
-        _solve_check(func_name, m, b, 'M', 'b', True)
-    return func_name
+    def _check(arg, arg_name):
+        if _callable_const(F.typeof(arg)):
+            return arg
+
+        _solve_check(func_name, arg, b, arg_name, 'b', True)
+        if isinstance(arg, CSRTensor):
+            _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
+            _dtype_check(func_name, arg.values, [mstype.float32], arg_name)
+        else:
+            _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
+        if F.dtype(arg) in (mstype.int32, mstype.int64):
+            arg = F.cast(arg, mstype.float64)
+        return arg
+
+    a = _check(a, 'A')
+    m = _check(m, 'M')
+    if F.dtype(b) in (mstype.int32, mstype.int64):
+        b = F.cast(b, mstype.float64)
+        x0 = F.cast(x0, mstype.float64)
+    return a, m, b, x0
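Note the dtype rules `_sparse_check` now enforces: a CSRTensor operand must carry int32 `indptr`/`indices` and float32 `values`, while dense integer operands (and an integer `b`/`x0`) are promoted to float64. A host-side pre-flight check mirroring the CSR rules for a scipy matrix (a sketch; the helper name is hypothetical):

import numpy as onp
import scipy.sparse

def preflight_csr_for_cg(a):
    # Raise early if a scipy CSR matrix would fail _sparse_check's CSRTensor rules.
    if a.indptr.dtype != onp.int32 or a.indices.dtype != onp.int32:
        raise TypeError('CSR indptr/indices must be int32')
    if a.data.dtype != onp.float32:
        raise TypeError('CSR values must be float32')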
@@ -26,6 +26,15 @@ def _callable_const(x):
     return isinstance(x, mstype.function_type)


+@constexpr
+def _nullable_const(x):
+    """
+    Returns True if x is None. Its purpose is to check whether the call happens inside a
+    MindSpore graph: in graph mode, x is None inside a constexpr when x is a MindSpore variable.
+    """
+    return x is None
+
+
 @constexpr
 def _type_convert(new_type, obj):
     """
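This small predicate is what drives the `cg` dispatch in the first file: `@constexpr` functions are evaluated at compile time, and in graph mode a traced MindSpore variable reaches them as None. So `_nullable_const(A)` distinguishes a direct call (concrete `A`, `CG(A, M)` path) from a call traced inside another Cell (`A` is a graph variable, `CGv2` path, keeping the whole iteration in one graph). A sketch of that behavior with a hypothetical `Probe` cell, written from the docstring's description (exact return types may vary across MindSpore versions):

import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import constexpr

@constexpr
def is_graph_variable(x):
    # In graph mode, a traced tensor reaches a constexpr as None.
    return x is None

class Probe(nn.Cell):
    def construct(self, x):
        return is_graph_variable(x)

context.set_context(mode=context.GRAPH_MODE)
print(Probe()(Tensor([1.0])))              # True: inside the graph, x arrives as None
print(is_graph_variable(Tensor([1.0])))    # False: eager call with a concrete value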
@@ -18,10 +18,11 @@ import numpy as onp
 import scipy as osp
 import scipy.sparse.linalg

+import mindspore.nn as nn
 import mindspore.scipy as msp
 from mindspore import context
-from mindspore.common import Tensor
-from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix
+from mindspore.common import Tensor, CSRTensor
+from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix


 def _fetch_preconditioner(preconditioner, A):
@@ -46,38 +47,39 @@ def _fetch_preconditioner(preconditioner, A):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
+@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
 @pytest.mark.parametrize('shape', [(4, 4), (7, 7)])
 @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
 @pytest.mark.parametrize('maxiter', [1, 3])
-def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter):
+def test_cg_against_scipy(dtype, tol, shape, preconditioner, maxiter):
     """
     Feature: ALL TO ALL
     Description: test cases for cg
     Expectation: the result match scipy
     """
     onp.random.seed(0)
-    dtype, tol = dtype_tol
-    A = create_sym_pos_matrix(shape, dtype)
+    a = create_sym_pos_matrix(shape, dtype)
     b = onp.random.random(shape[:1]).astype(dtype)
-    M = _fetch_preconditioner(preconditioner, A)
-    osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    m = _fetch_preconditioner(preconditioner, a)
+    osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)

-    A = Tensor(A)
+    a = Tensor(a)
     b = Tensor(b)
-    M = Tensor(M) if M is not None else M
+    m = Tensor(m) if m is not None else m

     # using PYNATIVE MODE
     context.set_context(mode=context.PYNATIVE_MODE)
-    msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)

     # using GRAPH MODE
     context.set_context(mode=context.GRAPH_MODE)
-    msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0]
+    msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)

     kw = {"atol": tol, "rtol": tol}
-    onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw)
-    onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw)
+    onp.testing.assert_allclose(osp_res[0], msp_res_dyn[0].asnumpy(), **kw)
+    onp.testing.assert_allclose(osp_res[0], msp_res_sta[0].asnumpy(), **kw)
+    assert osp_res[1] == msp_res_dyn[1].asnumpy().item()
+    assert osp_res[1] == msp_res_sta[1].asnumpy().item()


 @pytest.mark.level0
@@ -93,23 +95,97 @@ def test_cg_against_numpy(dtype, shape):
     Expectation: the result match numpy
     """
     onp.random.seed(0)
-    A = create_sym_pos_matrix(shape, dtype)
+    a = create_sym_pos_matrix(shape, dtype)
     b = onp.random.random(shape[:1]).astype(dtype)
-    expected = onp.linalg.solve(A, b)
+    expected = onp.linalg.solve(a, b)

     # using PYNATIVE MODE
     context.set_context(mode=context.PYNATIVE_MODE)
-    actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
+    actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))

     # using GRAPH MODE
     context.set_context(mode=context.GRAPH_MODE)
-    actual_sta, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b))
+    actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b))

     kw = {"atol": 1e-5, "rtol": 1e-5}
     onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw)
     onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
+@pytest.mark.parametrize('shape', [(7, 7)])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
+@pytest.mark.parametrize('maxiter', [3])
+def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for cg within Cell object
+    Expectation: the result match scipy
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class TestNet(nn.Cell):
+        def construct(self, a, b, m, maxiter, tol):
+            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    onp.random.seed(0)
+    a = create_sym_pos_matrix(shape, dtype)
+    b = onp.random.random(shape[:1]).astype(dtype)
+    m = _fetch_preconditioner(preconditioner, a)
+    osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    a = Tensor(a)
+    b = Tensor(b)
+    m = Tensor(m) if m is not None else m
+    msp_res = TestNet()(a, b, m, maxiter, tol)
+
+    kw = {"atol": tol, "rtol": tol}
+    onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
+    assert osp_res[1] == msp_res[1].asnumpy().item()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5)])
+@pytest.mark.parametrize('shape', [(7, 7)])
+@pytest.mark.parametrize('preconditioner', [None, 'identity', 'random'])
+@pytest.mark.parametrize('maxiter', [3])
+def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
+    """
+    Feature: ALL TO ALL
+    Description: test cases of CSRTensor for cg
+    Expectation: the result match scipy.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+
+    class TestNet(nn.Cell):
+        def construct(self, a, b, m, maxiter, tol):
+            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    onp.random.seed(0)
+
+    # scipy
+    a = create_sym_pos_sparse_matrix(shape, dtype)
+    b = onp.random.random(shape[:1]).astype(dtype)
+    m = _fetch_preconditioner(preconditioner, a)
+    osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
+
+    # mindspore
+    a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape)
+    b = Tensor(b)
+    m = Tensor(m) if m is not None else m
+    msp_res = TestNet()(a, b, m, maxiter, tol)
+
+    kw = {"atol": tol, "rtol": tol}
+    onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
+    assert osp_res[1] == msp_res[1].asnumpy().item()
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.platform_x86_gpu_training
@@ -17,6 +17,7 @@ from typing import List
 from functools import cmp_to_key

 import numpy as onp
+import scipy.sparse.linalg
 from mindspore import Tensor
 import mindspore.ops as ops
 import mindspore.numpy as mnp
@@ -98,6 +99,18 @@ def create_sym_pos_matrix(shape, dtype):
     return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype)


+def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32):
+    if len(shape) != 2 or shape[0] != shape[1]:
+        raise ValueError(
+            'Symmetric positive definite matrix must be a square matrix, but has shape: ', shape)
+
+    n = shape[-1]
+    indptr = onp.arange(n + 1).astype(indice_dtype)
+    indices = onp.arange(n).astype(indice_dtype)
+    values = onp.random.random(n).astype(dtype)
+    return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape)
+
+
 def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
     # some utils
     def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]:
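`create_sym_pos_sparse_matrix` above builds a diagonal CSR matrix whose nonzero entries are drawn from (0, 1), so it is symmetric positive definite by construction (almost surely). A quick sanity check using the helper (a sketch):

import numpy as onp

a = create_sym_pos_sparse_matrix((5, 5), onp.float32)
dense = a.toarray()
assert onp.allclose(dense, dense.T)              # symmetric (it is diagonal)
onp.linalg.cholesky(dense.astype(onp.float64))   # succeeds only for positive definite input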