diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py index d73de043107..5bc34272543 100644 --- a/mindspore/python/mindspore/scipy/sparse/linalg.py +++ b/mindspore/python/mindspore/scipy/sparse/linalg.py @@ -16,7 +16,7 @@ from ... import nn, ms_function from ... import numpy as mnp from ...ops import functional as F -from ...common import Tensor, dtype as mstype +from ...common import Tensor, CSRTensor, dtype as mstype from ...ops.composite.multitype_ops.zeros_like_impl import zeros_like from ..linalg import solve_triangular from ..linalg import cho_factor, cho_solve @@ -367,11 +367,17 @@ class CGv2(nn.Cell): return x, F.select(_norm(r) > atol_, k, _INT_ZERO) def bprop(self, A, b, x0, tol, atol, maxiter, M, out, dout): + """Grad definition for `CGv2` Cell.""" n = b.shape[0] - if not isinstance(M, Tensor): + if not isinstance(M, (Tensor, CSRTensor)): M = F.eye(n, n, b.dtype) grad_b, _ = self.construct(A, dout[0], x0, tol, atol, maxiter, M) - grad_a = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n)) + if isinstance(A, CSRTensor): + grad_a_dense = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n)) + values = F.csr_gather(A.indptr, A.indices, grad_a_dense, A.shape) + grad_a = CSRTensor(A.indptr, A.indices, values, A.shape) + else: + grad_a = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n)) return grad_a, grad_b, zeros_like(x0), zeros_like(tol), zeros_like(atol), zeros_like(maxiter), zeros_like(M) diff --git a/tests/st/scipy_st/sparse/test_linalg.py b/tests/st/scipy_st/sparse/test_linalg.py index 67c72716235..3d8324eb97b 100644 --- a/tests/st/scipy_st/sparse/test_linalg.py +++ b/tests/st/scipy_st/sparse/test_linalg.py @@ -22,7 +22,7 @@ import mindspore.nn as nn import mindspore.scipy as msp from mindspore import context from mindspore.common import Tensor -from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor +from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor, to_ndarray def _fetch_preconditioner(preconditioner, A): @@ -209,8 +209,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b): # Function grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg) grad_a, grad_b = grad_net(a, b)[:2] - onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw) - onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw) + onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw) + onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw) # Cell class Net(nn.Cell): @@ -225,8 +225,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b): grad_net = ops.GradOperation(get_all=True)(Net()) grad_a, grad_b = grad_net(a, b)[:2] - onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw) - onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw) + onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw) + onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw) @pytest.mark.level0 @@ -279,8 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b): # Function grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg) grad_a, grad_b = grad_net(a, b)[:2] - onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw) - onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw) + onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw) + onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw) # Cell class Net(nn.Cell): @@ -295,8 +295,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b): grad_net = ops.GradOperation(get_all=True)(Net()) grad_a, grad_b = grad_net(a, b)[:2] - onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw) - onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw) + onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw) + onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw) @pytest.mark.level0 diff --git a/tests/st/scipy_st/utils.py b/tests/st/scipy_st/utils.py index 024d15819de..79612237748 100644 --- a/tests/st/scipy_st/utils.py +++ b/tests/st/scipy_st/utils.py @@ -17,7 +17,7 @@ from typing import List from functools import cmp_to_key import numpy as onp -import scipy as osp +import scipy.sparse.linalg from mindspore import Tensor, CSRTensor import mindspore.ops as ops import mindspore.numpy as mnp @@ -38,14 +38,14 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32): if tensor_type == "Tensor": obj = onp.array(obj) elif tensor_type == "CSRTensor": - obj = osp.sparse.csr_matrix(obj) + obj = scipy.sparse.csr_matrix(obj) if dtype is None: dtype = obj.dtype if isinstance(obj, onp.ndarray): obj = Tensor(obj.astype(dtype)) - elif isinstance(obj, osp.sparse.csr_matrix): + elif isinstance(obj, scipy.sparse.csr_matrix): obj = CSRTensor(indptr=Tensor(obj.indptr.astype(indice_dtype)), indices=Tensor(obj.indices.astype(indice_dtype)), values=Tensor(obj.data.astype(dtype)), @@ -54,6 +54,19 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32): return obj +def to_ndarray(obj, dtype=None): + if isinstance(obj, Tensor): + obj = obj.asnumpy() + elif isinstance(obj, CSRTensor): + obj = scipy.sparse.csr_matrix((obj.values.asnumpy(), obj.indices.asnumpy(), obj.indptr.asnumpy()), + shape=obj.shape) + obj = obj.toarray() + + if dtype is not None: + obj = obj.astype(dtype) + return obj + + def match_array(actual, expected, error=0, err_msg=''): if isinstance(actual, int): actual = onp.asarray(actual)