forked from mindspore-Ecosystem/mindspore
!30847 Add grad implementation and test cases for cg method.
Merge pull request !30847 from hezhenhao1/fix_cg
This commit is contained in:
commit
409d0f6715
|
@ -16,7 +16,8 @@
|
|||
from ... import nn, ms_function
|
||||
from ... import numpy as mnp
|
||||
from ...ops import functional as F
|
||||
from ...common import dtype as mstype
|
||||
from ...common import Tensor, dtype as mstype
|
||||
from ...ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..linalg import solve_triangular
|
||||
from ..linalg import cho_factor, cho_solve
|
||||
from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \
|
||||
|
@ -322,6 +323,10 @@ class CG(nn.Cell):
|
|||
|
||||
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
|
||||
|
||||
def bprop(self, b, x0, tol, atol, maxiter, out, dout):
    """Backward pass of the CG solve.

    The gradient with respect to ``b`` is computed by running this cell's
    own ``construct`` with the upstream gradient of the solution
    (``dout[0]``) as the right-hand side. All remaining inputs receive a
    zero gradient.
    """
    # Only the first element of the solver result is used as a gradient.
    solved = self.construct(dout[0], x0, tol, atol, maxiter)
    x0_grad = zeros_like(x0)
    tol_grad = zeros_like(tol)
    atol_grad = zeros_like(atol)
    maxiter_grad = zeros_like(maxiter)
    return solved[0], x0_grad, tol_grad, atol_grad, maxiter_grad
|
||||
|
||||
|
||||
class CGv2(nn.Cell):
|
||||
"""
|
||||
|
@ -331,7 +336,7 @@ class CGv2(nn.Cell):
|
|||
def __init__(self):
|
||||
super(CGv2, self).__init__()
|
||||
|
||||
def construct(self, A, M, b, x0, tol, atol, maxiter):
|
||||
def construct(self, A, b, x0, tol, atol, maxiter, M):
|
||||
# Constant tensor which avoids loop unrolling
|
||||
_INT_ZERO = _to_tensor(0)
|
||||
|
||||
|
@ -361,6 +366,14 @@ class CGv2(nn.Cell):
|
|||
|
||||
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
|
||||
|
||||
def bprop(self, A, b, x0, tol, atol, maxiter, M, out, dout):
    """Backward pass of the CG solve when ``A`` is passed as an input.

    ``grad_b`` is obtained by running ``construct`` again with the upstream
    gradient of the solution (``dout[0]``) as the right-hand side.
    ``grad_a`` is the negated outer product of ``grad_b`` and the forward
    solution ``out[0]``. Every other input receives a zero gradient.
    """
    # Fall back to the identity when no tensor preconditioner was given.
    if not isinstance(M, Tensor):
        M = lambda x: x
    size = b.shape[0]
    grad_b, _ = self.construct(A, dout[0], x0, tol, atol, maxiter, M)
    # Column vector times row vector broadcasts to a (size, size) matrix.
    neg_column = -1 * F.reshape(grad_b, (size, 1))
    solution_row = F.reshape(out[0], (1, size))
    grad_a = neg_column * solution_row
    return (grad_a, grad_b, zeros_like(x0), zeros_like(tol), zeros_like(atol),
            zeros_like(maxiter), zeros_like(M))
|
||||
|
||||
|
||||
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None):
|
||||
"""Use Conjugate Gradient iteration to solve the linear system:
|
||||
|
@ -455,7 +468,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None
|
|||
if not _nullable_const(A):
|
||||
x, info = CG(A, M)(b, x0, tol, atol, maxiter)
|
||||
else:
|
||||
x, info = CGv2()(A, M, b, x0, tol, atol, maxiter)
|
||||
x, info = CGv2()(A, b, x0, tol, atol, maxiter, M)
|
||||
return x, info
|
||||
|
||||
|
||||
|
|
|
@ -17,12 +17,13 @@ import pytest
|
|||
import numpy as onp
|
||||
import scipy as osp
|
||||
import scipy.sparse.linalg
|
||||
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
import mindspore.scipy as msp
|
||||
from mindspore import context
|
||||
from mindspore.common import Tensor, CSRTensor
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix, \
|
||||
to_tensor
|
||||
|
||||
|
||||
def _fetch_preconditioner(preconditioner, A):
|
||||
|
@ -128,7 +129,7 @@ def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
|
|||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
class TestNet(nn.Cell):
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, b, m, maxiter, tol):
|
||||
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
|
||||
|
||||
|
@ -141,7 +142,7 @@ def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
|
|||
a = Tensor(a)
|
||||
b = Tensor(b)
|
||||
m = Tensor(m) if m is not None else m
|
||||
msp_res = TestNet()(a, b, m, maxiter, tol)
|
||||
msp_res = Net()(a, b, m, maxiter, tol)
|
||||
|
||||
kw = {"atol": tol, "rtol": tol}
|
||||
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
|
||||
|
@ -163,7 +164,7 @@ def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
|
|||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
class TestNet(nn.Cell):
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, b, m, maxiter, tol):
|
||||
return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)
|
||||
|
||||
|
@ -179,13 +180,83 @@ def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
|
|||
a = CSRTensor(Tensor(a.indptr), Tensor(a.indices), Tensor(a.data), shape)
|
||||
b = Tensor(b)
|
||||
m = Tensor(m) if m is not None else m
|
||||
msp_res = TestNet()(a, b, m, maxiter, tol)
|
||||
msp_res = Net()(a, b, m, maxiter, tol)
|
||||
|
||||
kw = {"atol": tol, "rtol": tol}
|
||||
onp.testing.assert_allclose(osp_res[0], msp_res[0].asnumpy(), **kw)
|
||||
assert osp_res[1] == msp_res[1].asnumpy().item()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
# NOTE(review): the original list contained ('CSRTensor', onp.float32, 1e-5) twice, so the same
# configuration ran twice. The duplicate is replaced by the dense float32 case, which is assumed
# to be the intended one -- confirm.
@pytest.mark.parametrize('tensor_type, dtype, tol', [('Tensor', onp.float32, 1e-5),
                                                     ('CSRTensor', onp.float32, 1e-5),
                                                     ('Tensor', onp.float64, 1e-8)])
@pytest.mark.parametrize('a, b, grad_a, grad_b', [
    ([[1.96822833, 0.82204467, 1.03749232, 0.88915326, 0.44986806, 1.11167143],
      [0.82204467, 2.25216591, 1.40235719, 0.70838919, 0.81377919, 1.06000368],
      [1.03749232, 1.40235719, 2.90618746, 0.7126087, 0.81029544, 1.28673025],
      [0.88915326, 0.70838919, 0.7126087, 2.17515263, 0.40443765, 1.02082996],
      [0.44986806, 0.81377919, 0.81029544, 0.40443765, 1.60570668, 0.62292701],
      [1.11167143, 1.06000368, 1.28673025, 1.02082996, 0.62292701, 2.30795277]],
     [0.79363745, 0.58000418, 0.1622986, 0.70075235, 0.96455108, 0.50000836],
     [[-0.07867674, -0.01521201, 0.06394698, -0.03854052, -0.13523701, 0.01326866],
      [-0.03508505, -0.00678363, 0.02851647, -0.01718673, -0.06030749, 0.00591702],
      [-0.00586019, -0.00113306, 0.00476305, -0.00287067, -0.01007304, 0.00098831],
      [-0.07704304, -0.01489613, 0.06261914, -0.03774023, -0.13242886, 0.01299314],
      [-0.14497008, -0.02802971, 0.11782896, -0.07101491, -0.24918826, 0.02444888],
      [-0.01868565, -0.00361284, 0.01518735, -0.00915334, -0.03211867, 0.00315129]],
     [0.22853142, 0.10191113, 0.01702201, 0.22378603, 0.42109291, 0.054276]),
    ([[1.85910724, 0.73233206, 0.65960803, 1.03821349, 0.55277616],
      [0.73233206, 1.69548841, 0.59992146, 1.01518264, 0.50824059],
      [0.65960803, 0.59992146, 1.98169091, 1.45565213, 0.47901749],
      [1.03821349, 1.01518264, 1.45565213, 3.3133049, 0.75598147],
      [0.55277616, 0.50824059, 0.47901749, 0.75598147, 1.46831254]],
     [0.59674531, 0.226012, 0.10694568, 0.22030621, 0.34982629],
     [[-0.07498642, 0.00167461, 0.01353184, 0.01008293, -0.03770084],
      [-0.09940184, 0.00221986, 0.01793778, 0.01336592, -0.04997616],
      [-0.09572781, 0.00213781, 0.01727477, 0.01287189, -0.04812897],
      [0.03135044, -0.00070012, -0.00565741, -0.00421549, 0.01576203],
      [-0.14053766, 0.00313851, 0.02536103, 0.01889718, -0.07065797]],
     [0.23398106, 0.31016481, 0.29870068, -0.09782316, 0.43852141]),
])
def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
    """
    Feature: gradient of msp.sparse.linalg.cg
    Description: test cases for grad implementation of cg in graph mode (currently)
    Expectation: the gradients w.r.t. a and b match the expected values
    """
    context.set_context(mode=context.GRAPH_MODE)

    a = to_tensor((a, tensor_type), dtype)
    b = Tensor(onp.array(b, dtype=dtype))
    expect_grad_a = onp.array(grad_a, dtype=dtype)
    expect_grad_b = onp.array(grad_b, dtype=dtype)
    kw = {"atol": tol, "rtol": tol}

    # Function: differentiate the public cg entry point directly.
    grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
    # Distinct names keep the parametrized expected values from being shadowed.
    actual_grad_a, actual_grad_b = grad_net(a, b)[:2]
    onp.testing.assert_allclose(expect_grad_a, actual_grad_a.asnumpy(), **kw)
    onp.testing.assert_allclose(expect_grad_b, actual_grad_b.asnumpy(), **kw)

    # Cell: differentiate cg when it is called from inside an nn.Cell.
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.sum = ops.ReduceSum()
            self.cg = msp.sparse.linalg.cg

        def construct(self, a, b):
            x, _ = self.cg(a, b)
            return self.sum(x)

    grad_net = ops.GradOperation(get_all=True)(Net())
    actual_grad_a, actual_grad_b = grad_net(a, b)[:2]
    onp.testing.assert_allclose(expect_grad_a, actual_grad_a.asnumpy(), **kw)
    onp.testing.assert_allclose(expect_grad_b, actual_grad_b.asnumpy(), **kw)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
|
@ -14,25 +14,43 @@
|
|||
# ============================================================================
|
||||
"""utility functions for mindspore.scipy st tests"""
|
||||
from typing import List
|
||||
from functools import cmp_to_key
|
||||
from functools import cmp_to_key, partial
|
||||
|
||||
import numpy as onp
|
||||
import scipy as osp
|
||||
import scipy.sparse.linalg
|
||||
from mindspore import Tensor
|
||||
from mindspore import Tensor, CSRTensor
|
||||
import mindspore.ops as ops
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
def to_tensor(obj, dtype=None):
    """
    Initialize a Tensor or a CSRTensor from host data.

    Args:
        obj: One of three kinds of input:
            1. A tuple or list of the form ``(data, tensor_type)``, where
               ``tensor_type`` is the string 'Tensor' or 'CSRTensor'.
            2. A numpy.ndarray, which is converted to a Tensor.
            3. A scipy.sparse.csr_matrix, which is converted to a CSRTensor.
        dtype: Numpy dtype the values are cast to. Defaults to the dtype of
            the (converted) ``obj``.

    Returns:
        A Tensor when the data is a numpy.ndarray, otherwise a CSRTensor built
        from the matrix's ``indptr``, ``indices``, ``data`` and ``shape``.

    Raises:
        ValueError: If ``tensor_type`` is neither 'Tensor' nor 'CSRTensor'.
    """
    if isinstance(obj, (tuple, list)):
        data, tensor_type = obj
        if tensor_type == "Tensor":
            obj = onp.array(data)
        elif tensor_type == "CSRTensor":
            obj = osp.sparse.csr_matrix(data)
        else:
            # Fail here with a clear message instead of an unrelated AttributeError later on.
            raise ValueError("tensor_type should be 'Tensor' or 'CSRTensor', but got {!r}".format(tensor_type))

    if dtype is None:
        dtype = obj.dtype

    if isinstance(obj, onp.ndarray):
        return Tensor(input_data=obj.astype(dtype))

    # Only the stored values are cast; indptr and indices keep their own dtype.
    return CSRTensor(indptr=Tensor(obj.indptr), indices=Tensor(obj.indices),
                     values=Tensor(obj.data.astype(dtype)), shape=obj.shape)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue