!30847 Add grad implementation and test cases for cg method.

Merge pull request !30847 from hezhenhao1/fix_cg
This commit is contained in:
i-robot 2022-03-08 06:40:01 +00:00 committed by Gitee
commit 409d0f6715
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 119 additions and 17 deletions

View File

@ -16,7 +16,8 @@
from ... import nn, ms_function
from ... import numpy as mnp
from ...ops import functional as F
from ...common import dtype as mstype
from ...common import Tensor, dtype as mstype
from ...ops.composite.multitype_ops.zeros_like_impl import zeros_like
from ..linalg import solve_triangular
from ..linalg import cho_factor, cho_solve
from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
@ -322,6 +323,10 @@ class CG(nn.Cell):
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
def bprop(self, b, x0, tol, atol, maxiter, out, dout):
    """
    Custom backward pass for the CG solve.

    The gradient with respect to `b` is obtained by running `construct` again with the
    cotangent of the solution (`dout[0]`) as the right-hand side, reusing `x0`, `tol`,
    `atol` and `maxiter` from the forward call.
    NOTE(review): reusing the forward solver for the backward solve presumably relies on
    the operator being symmetric -- confirm.

    Returns:
        Tuple (grad_b, zeros, zeros, zeros, zeros): only `b` receives a non-zero gradient;
        `x0`, `tol`, `atol` and `maxiter` get zeros of their own kind.
    """
    x_cotangent = dout[0]
    # The second output of construct (the info flag) is irrelevant for the gradient.
    grad_b, _ = self.construct(x_cotangent, x0, tol, atol, maxiter)
    grad_x0 = zeros_like(x0)
    grad_tol = zeros_like(tol)
    grad_atol = zeros_like(atol)
    grad_maxiter = zeros_like(maxiter)
    return grad_b, grad_x0, grad_tol, grad_atol, grad_maxiter
class CGv2(nn.Cell):
"""
@ -331,7 +336,7 @@ class CGv2(nn.Cell):
def __init__(self):
super(CGv2, self).__init__()
def construct(self, A, M, b, x0, tol, atol, maxiter):
def construct(self, A, b, x0, tol, atol, maxiter, M):
# Constant tensor which avoids loop unrolling
_INT_ZERO = _to_tensor(0)
@ -361,6 +366,14 @@ class CGv2(nn.Cell):
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
def bprop(self, A, b, x0, tol, atol, maxiter, M, out, dout):
    """
    Custom backward pass for the CG solve when `A` is passed as an input.

    `grad_b` is computed by running `construct` again with the cotangent of the solution
    (`dout[0]`) as the right-hand side. `grad_a` is the negated outer product of `grad_b`
    (as a column) and the forward solution `out[0]` (as a row).

    Returns:
        Tuple (grad_a, grad_b, zeros for x0, tol, atol, maxiter and M).
    """
    # A preconditioner that is not a Tensor is swapped for the identity map before re-solving.
    if not isinstance(M, Tensor):
        M = lambda x: x
    size = b.shape[0]
    x_cotangent = dout[0]
    grad_b, _ = self.construct(A, x_cotangent, x0, tol, atol, maxiter, M)
    # Keep the original evaluation order: (-1 * column) * row.
    negated_column = -1 * F.reshape(grad_b, (size, 1))
    solution_row = F.reshape(out[0], (1, size))
    grad_a = negated_column * solution_row
    grad_x0 = zeros_like(x0)
    grad_tol = zeros_like(tol)
    grad_atol = zeros_like(atol)
    grad_maxiter = zeros_like(maxiter)
    grad_m = zeros_like(M)
    return grad_a, grad_b, grad_x0, grad_tol, grad_atol, grad_maxiter, grad_m
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
"""Use Conjugate Gradient iteration to solve the linear system:
@ -455,7 +468,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
if not _nullable_const(A):
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
else:
x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
x, info = CGv2()(A, b, x0, tol, atol, maxiter, M)
return x, info

View File

@ -17,12 +17,13 @@ import pytest
import numpy as onp
import scipy as osp
import scipy.sparse.linalg
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor, CSRTensor
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix, \
to_tensor
def _fetch_preconditioner(preconditioner, A):
@ -128,7 +129,7 @@ def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
"""
context.set_context(mode=context.GRAPH_MODE)
class TestNet(nn.Cell):
class Net(nn.Cell):
def construct(self, a, b, m, maxiter, tol):
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
@ -141,7 +142,7 @@ def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
a = Tensor(a)
b = Tensor(b)
m = Tensor(m) if m is not None else m
msp_res = TestNet()(a, b, m, maxiter, tol)
msp_res = Net()(a, b, m, maxiter, tol)
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
@ -163,7 +164,7 @@ def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
"""
context.set_context(mode=context.GRAPH_MODE)
class TestNet(nn.Cell):
class Net(nn.Cell):
def construct(self, a, b, m, maxiter, tol):
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
@ -179,13 +180,83 @@ def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape)
b = Tensor(b)
m = Tensor(m) if m is not None else m
msp_res = TestNet()(a, b, m, maxiter, tol)
msp_res = Net()(a, b, m, maxiter, tol)
kw = {"atol": tol, "rtol": tol}
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
assert osp_res[1] == msp_res[1].asnumpy().item()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('tensor_type, dtype, tol', [('CSRTensor', onp.float32, 1e-5),
                                                     ('Tensor', onp.float64, 1e-8)])
@pytest.mark.parametrize('a, b, grad_a, grad_b', [
    ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
      [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
      [1.03749232, 1.40235719, 2.90618746, 0.7126087, 0.81029544, 1.28673025],
      [0.88915326, 0.70838919, 0.7126087, 2.17515263, 0.40443765, 1.02082996],
      [0.44986806, 0.81377919, 0.81029544, 0.40443765, 1.60570668, 0.62292701],
      [1.11167143, 1.06000368, 1.28673025, 1.02082996, 0.62292701, 2.30795277]],
     [0.79363745, 0.58000418, 0.1622986, 0.70075235, 0.96455108, 0.50000836],
     [[-0.07867674, -0.01521201, 0.06394698, -0.03854052, -0.13523701, 0.01326866],
      [-0.03508505, -0.00678363, 0.02851647, -0.01718673, -0.06030749, 0.00591702],
      [-0.00586019, -0.00113306, 0.00476305, -0.00287067, -0.01007304, 0.00098831],
      [-0.07704304, -0.01489613, 0.06261914, -0.03774023, -0.13242886, 0.01299314],
      [-0.14497008, -0.02802971, 0.11782896, -0.07101491, -0.24918826, 0.02444888],
      [-0.01868565, -0.00361284, 0.01518735, -0.00915334, -0.03211867, 0.00315129]],
     [0.22853142, 0.10191113, 0.01702201, 0.22378603, 0.42109291, 0.054276]),
    ([[1.85910724, 0.73233206, 0.65960803, 1.03821349, 0.55277616],
      [0.73233206, 1.69548841, 0.59992146, 1.01518264, 0.50824059],
      [0.65960803, 0.59992146, 1.98169091, 1.45565213, 0.47901749],
      [1.03821349, 1.01518264, 1.45565213, 3.3133049, 0.75598147],
      [0.55277616, 0.50824059, 0.47901749, 0.75598147, 1.46831254]],
     [0.59674531, 0.226012, 0.10694568, 0.22030621, 0.34982629],
     [[-0.07498642, 0.00167461, 0.01353184, 0.01008293, -0.03770084],
      [-0.09940184, 0.00221986, 0.01793778, 0.01336592, -0.04997616],
      [-0.09572781, 0.00213781, 0.01727477, 0.01287189, -0.04812897],
      [0.03135044, -0.00070012, -0.00565741, -0.00421549, 0.01576203],
      [-0.14053766, 0.00313851, 0.02536103, 0.01889718, -0.07065797]],
     [0.23398106, 0.31016481, 0.29870068, -0.09782316, 0.43852141]),
])
def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
    """
    Feature: gradient of msp.sparse.linalg.cg
    Description: compare the gradients of cg with respect to `a` and `b` against precomputed
                 reference values in graph mode, both when differentiating the function directly
                 and when it is wrapped in a Cell that sums the solution.
    Expectation: the gradients match the reference values within `tol`.
    """
    context.set_context(mode=context.GRAPH_MODE)

    a = to_tensor((a, tensor_type), dtype)
    b = Tensor(onp.array(b, dtype=dtype))
    expect_grad_a = onp.array(grad_a, dtype=dtype)
    expect_grad_b = onp.array(grad_b, dtype=dtype)
    kw = {"atol": tol, "rtol": tol}

    # Function: differentiate cg directly; only the gradients of the first two inputs are checked.
    grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
    actual_grad_a, actual_grad_b = grad_net(a, b)[:2]
    onp.testing.assert_allclose(expect_grad_a, actual_grad_a.asnumpy(), **kw)
    onp.testing.assert_allclose(expect_grad_b, actual_grad_b.asnumpy(), **kw)

    # Cell: differentiate a network that reduces the solution to a scalar.
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.sum = ops.ReduceSum()
            self.cg = msp.sparse.linalg.cg

        def construct(self, a, b):
            x, _ = self.cg(a, b)
            return self.sum(x)

    grad_net = ops.GradOperation(get_all=True)(Net())
    actual_grad_a, actual_grad_b = grad_net(a, b)[:2]
    onp.testing.assert_allclose(expect_grad_a, actual_grad_a.asnumpy(), **kw)
    onp.testing.assert_allclose(expect_grad_b, actual_grad_b.asnumpy(), **kw)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training

View File

@ -14,25 +14,43 @@
# ============================================================================
"""utility functions for mindspore.scipy st tests"""
from typing import List
from functools import cmp_to_key
from functools import cmp_to_key, partial
import numpy as onp
import scipy as osp
import scipy.sparse.linalg
from mindspore import Tensor
from mindspore import Tensor, CSRTensor
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common import dtype as mstype
def to_tensor(obj, dtype=None):
    """
    This function is used to initialize Tensor or CSRTensor.

    Args:
        obj: can be one of three types:
            1. tuple or list
                Must be in the format (data, tensor_type), where tensor_type is the string
                'Tensor' or 'CSRTensor'. The data is converted with onp.array or
                osp.sparse.csr_matrix respectively.
            2. numpy.ndarray
            3. scipy.sparse.csr_matrix
        dtype: type the values are cast to. When None, `obj.dtype` is used.

    Returns:
        A Tensor when `obj` is (or was converted to) a numpy.ndarray; otherwise a CSRTensor
        built from the indptr, indices, data and shape attributes of `obj`.

    Raises:
        ValueError: if `obj` is a tuple or list whose tensor_type is neither 'Tensor' nor
            'CSRTensor'.
    """
    if isinstance(obj, (tuple, list)):
        obj, tensor_type = obj
        if tensor_type == "Tensor":
            obj = onp.array(obj)
        elif tensor_type == "CSRTensor":
            obj = osp.sparse.csr_matrix(obj)
        else:
            # Fail here with a clear message instead of an AttributeError on `obj.dtype` below.
            raise ValueError("For 'to_tensor', tensor type must be 'Tensor' or 'CSRTensor', "
                             "but got {!r}.".format(tensor_type))

    if dtype is None:
        dtype = obj.dtype

    if isinstance(obj, onp.ndarray):
        tensor_fn = partial(Tensor, input_data=obj.astype(dtype))
    else:
        # Only the stored values are cast to `dtype`; indptr and indices keep their own type.
        tensor_fn = partial(CSRTensor, indptr=Tensor(obj.indptr), indices=Tensor(obj.indices),
                            values=Tensor(obj.data.astype(dtype)), shape=obj.shape)
    res = tensor_fn()
    return res