!31127 Fix bprop of cg method, support to return CSRTensor.

Merge pull request !31127 from hezhenhao1/fix_cg
This commit is contained in:
i-robot 2022-03-11 01:03:59 +00:00 committed by Gitee
commit 3f3c9480dc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 34 additions and 15 deletions

View File

@ -16,7 +16,7 @@
from ... import nn, ms_function
from ... import numpy as mnp
from ...ops import functional as F
from ...common import Tensor, dtype as mstype
from ...common import Tensor, CSRTensor, dtype as mstype
from ...ops.composite.multitype_ops.zeros_like_impl import zeros_like
from ..linalg import solve_triangular
from ..linalg import cho_factor, cho_solve
@ -367,11 +367,17 @@ class CGv2(nn.Cell):
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
def bprop(self, A, b, x0, tol, atol, maxiter, M, out, dout):
"""Grad definition for `CGv2` Cell."""
n = b.shape[0]
if not isinstance(M, Tensor):
if not isinstance(M, (Tensor, CSRTensor)):
M = F.eye(n, n, b.dtype)
grad_b, _ = self.construct(A, dout[0], x0, tol, atol, maxiter, M)
grad_a = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n))
if isinstance(A, CSRTensor):
grad_a_dense = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n))
values = F.csr_gather(A.indptr, A.indices, grad_a_dense, A.shape)
grad_a = CSRTensor(A.indptr, A.indices, values, A.shape)
else:
grad_a = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n))
return grad_a, grad_b, zeros_like(x0), zeros_like(tol), zeros_like(atol), zeros_like(maxiter), zeros_like(M)

View File

@ -22,7 +22,7 @@ import mindspore.nn as nn
import mindspore.scipy as msp
from mindspore import context
from mindspore.common import Tensor
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor, to_ndarray
def _fetch_preconditioner(preconditioner, A):
@ -209,8 +209,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
# Function
grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
grad_a, grad_b = grad_net(a, b)[:2]
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
# Cell
class Net(nn.Cell):
@ -225,8 +225,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
grad_net = ops.GradOperation(get_all=True)(Net())
grad_a, grad_b = grad_net(a, b)[:2]
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
@pytest.mark.level0
@ -279,8 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
# Function
grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
grad_a, grad_b = grad_net(a, b)[:2]
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
# Cell
class Net(nn.Cell):
@ -295,8 +295,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
grad_net = ops.GradOperation(get_all=True)(Net())
grad_a, grad_b = grad_net(a, b)[:2]
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
@pytest.mark.level0

View File

@ -17,7 +17,7 @@ from typing import List
from functools import cmp_to_key
import numpy as onp
import scipy as osp
import scipy.sparse.linalg
from mindspore import Tensor, CSRTensor
import mindspore.ops as ops
import mindspore.numpy as mnp
@ -38,14 +38,14 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32):
if tensor_type == "Tensor":
obj = onp.array(obj)
elif tensor_type == "CSRTensor":
obj = osp.sparse.csr_matrix(obj)
obj = scipy.sparse.csr_matrix(obj)
if dtype is None:
dtype = obj.dtype
if isinstance(obj, onp.ndarray):
obj = Tensor(obj.astype(dtype))
elif isinstance(obj, osp.sparse.csr_matrix):
elif isinstance(obj, scipy.sparse.csr_matrix):
obj = CSRTensor(indptr=Tensor(obj.indptr.astype(indice_dtype)),
indices=Tensor(obj.indices.astype(indice_dtype)),
values=Tensor(obj.data.astype(dtype)),
@ -54,6 +54,19 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32):
return obj
def to_ndarray(obj, dtype=None):
if isinstance(obj, Tensor):
obj = obj.asnumpy()
elif isinstance(obj, CSRTensor):
obj = scipy.sparse.csr_matrix((obj.values.asnumpy(), obj.indices.asnumpy(), obj.indptr.asnumpy()),
shape=obj.shape)
obj = obj.toarray()
if dtype is not None:
obj = obj.astype(dtype)
return obj
def match_array(actual, expected, error=0, err_msg=''):
if isinstance(actual, int):
actual = onp.asarray(actual)