forked from mindspore-Ecosystem/mindspore
!31127 Fix bprop of cg method; support returning CSRTensor.
Merge pull request !31127 from hezhenhao1/fix_cg
This commit is contained in:
commit
3f3c9480dc
|
@ -16,7 +16,7 @@
|
|||
from ... import nn, ms_function
|
||||
from ... import numpy as mnp
|
||||
from ...ops import functional as F
|
||||
from ...common import Tensor, dtype as mstype
|
||||
from ...common import Tensor, CSRTensor, dtype as mstype
|
||||
from ...ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..linalg import solve_triangular
|
||||
from ..linalg import cho_factor, cho_solve
|
||||
|
@ -367,11 +367,17 @@ class CGv2(nn.Cell):
|
|||
return x, F.select(_norm(r) > atol_, k, _INT_ZERO)
|
||||
|
||||
def bprop(self, A, b, x0, tol, atol, maxiter, M, out, dout):
    """Grad definition for `CGv2` Cell.

    Args:
        A: Coefficient matrix of the system; either a dense ``Tensor`` or a
            ``CSRTensor``.
        b: Right-hand side; its leading dimension gives the system size ``n``.
        x0, tol, atol, maxiter: Forwarded unchanged to ``self.construct``.
        M: Preconditioner. Anything that is not a ``Tensor`` or ``CSRTensor``
            is replaced by an ``n x n`` identity matrix of ``b``'s dtype.
        out: Forward outputs; ``out[0]`` is used as the forward solution.
        dout: Incoming gradients; ``dout[0]`` is used as the right-hand side
            of the backward solve.

    Returns:
        A 7-tuple ``(grad_a, grad_b, 0, 0, 0, 0, 0)`` where the trailing
        entries are ``zeros_like`` of ``x0``, ``tol``, ``atol``, ``maxiter``
        and ``M``. ``grad_a`` is a ``CSRTensor`` when ``A`` is one, otherwise
        a dense ``(n, n)`` tensor.
    """
    n = b.shape[0]
    # Fall back to an identity preconditioner when M is not a tensor type.
    if not isinstance(M, (Tensor, CSRTensor)):
        M = F.eye(n, n, b.dtype)

    # Re-run the solver with the incoming gradient as the right-hand side.
    grad_b, _ = self.construct(A, dout[0], x0, tol, atol, maxiter, M)

    # Negative outer product of grad_b (column) and the forward solution (row).
    dense_grad_a = -1 * F.reshape(grad_b, (n, 1)) * F.reshape(out[0], (1, n))

    if isinstance(A, CSRTensor):
        # Return the gradient in CSR form, reusing A's indptr/indices.
        # NOTE(review): csr_gather presumably picks the dense entries at A's
        # stored positions - confirm against the operator documentation.
        values = F.csr_gather(A.indptr, A.indices, dense_grad_a, A.shape)
        grad_a = CSRTensor(A.indptr, A.indices, values, A.shape)
    else:
        grad_a = dense_grad_a

    return (grad_a, grad_b, zeros_like(x0), zeros_like(tol), zeros_like(atol),
            zeros_like(maxiter), zeros_like(M))
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import mindspore.nn as nn
|
|||
import mindspore.scipy as msp
|
||||
from mindspore import context
|
||||
from mindspore.common import Tensor
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor
|
||||
from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, to_tensor, to_ndarray
|
||||
|
||||
|
||||
def _fetch_preconditioner(preconditioner, A):
|
||||
|
@ -209,8 +209,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
|||
# Function
|
||||
grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
|
||||
grad_a, grad_b = grad_net(a, b)[:2]
|
||||
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
|
||||
|
||||
# Cell
|
||||
class Net(nn.Cell):
|
||||
|
@ -225,8 +225,8 @@ def test_cg_grad(tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
|||
|
||||
grad_net = ops.GradOperation(get_all=True)(Net())
|
||||
grad_a, grad_b = grad_net(a, b)[:2]
|
||||
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -279,8 +279,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
|||
# Function
|
||||
grad_net = ops.GradOperation(get_all=True)(msp.sparse.linalg.cg)
|
||||
grad_a, grad_b = grad_net(a, b)[:2]
|
||||
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
|
||||
|
||||
# Cell
|
||||
class Net(nn.Cell):
|
||||
|
@ -295,8 +295,8 @@ def test_cg_grad_pynative(tensor_type, dtype, tol, a, b, grad_a, grad_b):
|
|||
|
||||
grad_net = ops.GradOperation(get_all=True)(Net())
|
||||
grad_a, grad_b = grad_net(a, b)[:2]
|
||||
onp.testing.assert_allclose(expect_grad_a, grad_a.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, grad_b.asnumpy(), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_a, to_ndarray(grad_a), **kw)
|
||||
onp.testing.assert_allclose(expect_grad_b, to_ndarray(grad_b), **kw)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import List
|
|||
from functools import cmp_to_key
|
||||
|
||||
import numpy as onp
|
||||
import scipy as osp
|
||||
import scipy.sparse.linalg
|
||||
from mindspore import Tensor, CSRTensor
|
||||
import mindspore.ops as ops
|
||||
import mindspore.numpy as mnp
|
||||
|
@ -38,14 +38,14 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32):
|
|||
if tensor_type == "Tensor":
|
||||
obj = onp.array(obj)
|
||||
elif tensor_type == "CSRTensor":
|
||||
obj = osp.sparse.csr_matrix(obj)
|
||||
obj = scipy.sparse.csr_matrix(obj)
|
||||
|
||||
if dtype is None:
|
||||
dtype = obj.dtype
|
||||
|
||||
if isinstance(obj, onp.ndarray):
|
||||
obj = Tensor(obj.astype(dtype))
|
||||
elif isinstance(obj, osp.sparse.csr_matrix):
|
||||
elif isinstance(obj, scipy.sparse.csr_matrix):
|
||||
obj = CSRTensor(indptr=Tensor(obj.indptr.astype(indice_dtype)),
|
||||
indices=Tensor(obj.indices.astype(indice_dtype)),
|
||||
values=Tensor(obj.data.astype(dtype)),
|
||||
|
@ -54,6 +54,19 @@ def to_tensor(obj, dtype=None, indice_dtype=onp.int32):
|
|||
return obj
|
||||
|
||||
|
||||
def to_ndarray(obj, dtype=None):
    """Convert a MindSpore ``Tensor`` or ``CSRTensor`` to a dense NumPy array.

    Args:
        obj: Value to convert. A ``Tensor`` is converted with ``asnumpy()``.
            A ``CSRTensor`` is rebuilt as a SciPy CSR matrix from its
            values/indices/indptr and then densified. Any other object is
            passed through unchanged.
        dtype: Optional target dtype. When not ``None`` the result is cast
            with ``astype(dtype)``.

    Returns:
        The converted (and optionally cast) object.
    """
    if isinstance(obj, Tensor):
        obj = obj.asnumpy()
    elif isinstance(obj, CSRTensor):
        csr_parts = (obj.values.asnumpy(), obj.indices.asnumpy(), obj.indptr.asnumpy())
        sparse_mat = scipy.sparse.csr_matrix(csr_parts, shape=obj.shape)
        obj = sparse_mat.toarray()

    return obj if dtype is None else obj.astype(dtype)
|
||||
|
||||
|
||||
def match_array(actual, expected, error=0, err_msg=''):
|
||||
if isinstance(actual, int):
|
||||
actual = onp.asarray(actual)
|
||||
|
|
Loading…
Reference in New Issue