add cholesky grad for CPU backend.

z00512249 2022-01-24 14:59:19 +08:00
parent 35efb12908
commit 4fddf02b08
7 changed files with 86 additions and 17 deletions

View File

@@ -14,4 +14,4 @@
# ============================================================================
"""Scipy-like interfaces in mindspore."""
from . import linalg, optimize, sparse, ops_wrapper
from . import linalg, optimize, sparse
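Since ops_wrapper is no longer re-exported from the mindspore.scipy package, callers import the submodule directly, as the updated tests in the last file below do. A minimal usage sketch (assuming matrix_set_diag's k argument defaults to the main diagonal):

import numpy as onp
import mindspore.scipy.ops_wrapper as ops_wrapper
from mindspore import Tensor

# set the main diagonal of a 3x3 matrix; other arguments keep their defaults
mat = Tensor(onp.zeros((3, 3), onp.float32))
diag = Tensor(onp.array([1., 2., 3.], onp.float32))
out = ops_wrapper.matrix_set_diag(mat, diag)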

View File

@@ -24,8 +24,8 @@ from .ops import EighNet
from ..ops import operations as P
from ..ops import functional as F
from ..common import dtype as mstype
from .utils import float_types
from .utils_const import _raise_value_error, _type_check
from .utils import float_types, valid_data_types
from .utils_const import _raise_value_error, _raise_type_error, _type_check
__all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh', 'lu_factor', 'lu']
@@ -287,7 +287,10 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
    """
    _type_check('overwrite_a', overwrite_a, bool, 'cho_factor')
    _type_check('check_finite', check_finite, bool, 'cho_factor')
    if F.dtype(a) not in float_types:
    a_type = F.dtype(a)
    if a_type not in valid_data_types:
        _raise_type_error("mindspore.scipy.linalg.cho_factor only supports int32, int64, float32, float64.")
    if a_type not in float_types:
        a = F.cast(a, mstype.float64)
    a_shape = a.shape
    if len(a_shape) != 2:
@@ -342,9 +345,11 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
    """
    _type_check('overwrite_a', overwrite_a, bool, 'cholesky')
    _type_check('check_finite', check_finite, bool, 'cholesky')
    if F.dtype(a) not in float_types:
    a_type = F.dtype(a)
    if a_type not in valid_data_types:
        _raise_type_error("mindspore.scipy.linalg.cholesky only supports int32, int64, float32, float64.")
    if a_type not in float_types:
        a = F.cast(a, mstype.float64)
    a_shape = a.shape
    if len(a_shape) != 2:
        _raise_value_error("input a to mindspore.scipy.linalg.cholesky must have 2 dimensions.")
@@ -391,6 +396,11 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
    _type_check('overwrite_b', overwrite_b, bool, 'cho_solve')
    _type_check('check_finite', check_finite, bool, 'cho_solve')
    (c, lower) = c_and_lower
    c_type = F.dtype(c)
    if c_type not in valid_data_types:
        _raise_type_error("mindspore.scipy.linalg.cho_solve only supports int32, int64, float32, float64.")
    if c_type not in float_types:
        c = F.cast(c, mstype.float64)
    cholesky_solver_net = CholeskySolver(lower=lower)
    x = cholesky_solver_net(c, b)
    return x
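A quick illustration of the new dtype handling in the linalg wrappers (a hedged sketch of the expected behaviour, not output taken from this change): integer inputs are now accepted and cast to float64 before factorization, while dtypes outside the listed set raise a TypeError.

import numpy as onp
from mindspore import Tensor
from mindspore.scipy.linalg import cholesky

a = Tensor(onp.array([[4, 2], [2, 3]], onp.int32))  # int32 input is now accepted
l = cholesky(a, lower=True)                         # cast to float64 internally, then factored
# a dtype outside int32/int64/float32/float64 (e.g. float16) raises a TypeError instead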

View File

@@ -112,8 +112,8 @@ class Cholesky(PrimitiveWithInfer):
    def __init__(self, lower=False, clean=True, split_dim=0):
        super().__init__("Cholesky")
        self.init_prim_io_names(inputs=['a'], outputs=['l'])
        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
        self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
        self.lower = lower
        self.add_prim_attr('lower', self.lower)
        self.clean = clean
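The two replaced lines passed self.lower / self.clean as the last argument before those attributes existed; check_value_type expects the primitive name there, so the fix passes self.name. A minimal sketch of the corrected pattern, assuming the usual Validator.check_value_type(arg_name, arg_value, valid_types, prim_name) signature:

from mindspore._checkparam import Validator as validator

# returns the value unchanged when it is a bool; otherwise raises a TypeError
# whose message names the primitive that rejected the argument
lower = validator.check_value_type("lower", True, [bool], "Cholesky")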
@@ -181,7 +181,7 @@ class CholeskySolver(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, lower=False):
        super().__init__(name="CholeskySolver")
        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])

    def __infer__(self, A, b):
@@ -419,5 +419,6 @@ class MatrixDiagPartNet(nn.Cell):
    def construct(self, a, k, padding_value):
        return self.matrix_diag_part(a, k, padding_value)


# pylint: disable=C0413,W0611
from .ops_grad import get_bprpo_eigh, get_bprpo_trsm
from .ops_grad import get_bprop_cholesky, get_bprpo_eigh, get_bprpo_trsm

View File

@@ -14,7 +14,8 @@
# ============================================================================
"""Grad implementation of operators for scipy submodule"""
from .. import numpy as mnp
from .ops import Eigh, Eig, SolveTriangular
from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
from .ops_wrapper import matrix_set_diag
from .. import dtype as mstype
from ..ops import operations as P
from ..ops import functional as F
@@ -23,6 +24,7 @@ from ..ops._grad.grad_base import bprop_getters
_matmul = P.MatMul(False, False)
_real = P.Real()
_conj = P.Conj()
_matrix_band_part = MatrixBandPart()
def _compute_f(w, epsilon=1E-20):
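The new module-level _matrix_band_part operator is used in the Cholesky bprop below as _matrix_band_part(x, -1, 0) to keep only the lower triangle. Assuming MatrixBandPart follows the usual band-part convention (num_lower = -1 keeps every subdiagonal, num_upper = 0 drops everything above the main diagonal), a NumPy equivalent would be:

import numpy as onp

def band_part(x, num_lower, num_upper):
    # NumPy sketch of the usual band-part semantics: keep element (i, j) when
    # (num_lower < 0 or i - j <= num_lower) and (num_upper < 0 or j - i <= num_upper)
    m, n = x.shape[-2:]
    i = onp.arange(m)[:, None]
    j = onp.arange(n)[None, :]
    keep = ((num_lower < 0) | (i - j <= num_lower)) & ((num_upper < 0) | (j - i <= num_upper))
    return onp.where(keep, x, 0)

# band_part(x, -1, 0) keeps only the lower triangle, matching _matrix_band_part(x, -1, 0)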
@@ -46,6 +48,28 @@ def _matrix_solve(a, b):
    return a + b
@bprop_getters.register(Cholesky)
def get_bprop_cholesky(self):
    """Grad definition for `Cholesky` operation."""
    inverse = P.MatrixInverse()
    matmul = P.MatMul()

    def bprop(a, out, dout):
        l = out
        l_inverse = inverse(l)
        # P = Phi(L^H * dL): lower triangle of L^H * dL with the diagonal halved
        dout_middle = matmul(_adjoint(l), dout)
        middle_diag = 0.5 * mnp.diag(dout_middle)
        dout_middle = matrix_set_diag(dout_middle, middle_diag)
        dout_middle = _matrix_band_part(dout_middle, -1, 0)
        # G = L^{-H} * P * L^{-1}
        grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
        # dA = Phi(G + G^H): symmetrize, keep the lower triangle, halve the diagonal
        grad_a = mnp.tril(grad_a + _adjoint(grad_a))
        middle_diag = 0.5 * mnp.diag(grad_a)
        grad_a = matrix_set_diag(grad_a, middle_diag)
        return (grad_a,)

    return bprop
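For reference, the bprop above implements the standard reverse-mode rule for the Cholesky factorization A = L L^H (cf. Iain Murray, "Differentiation of the Cholesky decomposition", 2016). Writing \bar{L} for the incoming gradient dout and \Phi(X) for the lower triangle of X with its diagonal halved, the code computes:

\[
P = \Phi\big(L^{H}\bar{L}\big), \qquad
G = L^{-H} P\, L^{-1}, \qquad
\bar{A} = \Phi\big(G + G^{H}\big).
\]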
@bprop_getters.register(Eig)
def get_bprpo_eig(self):
    """Grad definition for `Eig` operation."""

View File

@@ -100,7 +100,7 @@ _FLOAT_ONE = _to_tensor(1.0)
_FLOAT_TWO = _to_tensor(2.0, dtype=float)
_BOOL_TRUE = _to_tensor(True)
_BOOL_FALSE = _to_tensor(False)
valid_data_types = (mstype.int32, mstype.int64, mstype.float32, mstype.float64)
float_types = (mstype.float32, mstype.float64)

View File

@@ -18,8 +18,42 @@ import numpy as onp
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor
from mindspore.scipy.ops import Eigh, SolveTriangular
from tests.st.scipy_st.utils import create_random_rank_matrix, gradient_check
from mindspore.scipy.ops import Eigh, Cholesky, SolveTriangular
from tests.st.scipy_st.utils import create_random_rank_matrix, create_sym_pos_matrix, gradient_check
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(8, 8)])
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-2, 1e-3), (onp.float64, 1e-4, 1e-7)])
def test_cholesky_grad(shape, data_type):
    """
    Feature: ALL TO ALL
    Description: test cases for the grad implementation of the Cholesky operator in graph mode and pynative mode.
    Expectation: the result matches gradient checking.
    """
    onp.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE)
    dtype, epsilon, error = data_type

    class CholeskyNet(nn.Cell):
        def __init__(self):
            super(CholeskyNet, self).__init__()
            self.mean = ops.ReduceMean()
            self.cholesky = Cholesky(lower=True, clean=True)

        def construct(self, a):
            c = self.cholesky(a)
            return self.mean(c)

    cholesky_net = CholeskyNet()
    a = create_sym_pos_matrix(shape, dtype)
    cholesky_net(Tensor(a))
    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error

    context.set_context(mode=context.PYNATIVE_MODE)
    cholesky_net(Tensor(a))
    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
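The assertions compare MindSpore's analytic gradient against a finite-difference estimate. A minimal sketch of what such a check can look like (the helper names below are hypothetical; the real gradient_check lives in tests/st/scipy_st/utils.py and may differ in detail):

import numpy as onp
from mindspore import Tensor, ops

def numeric_gradient(x, net, epsilon):
    # central differences on the scalar output of `net`
    x = x.asnumpy()
    grad = onp.zeros_like(x)
    for idx in onp.ndindex(x.shape):
        shift = onp.zeros_like(x)
        shift[idx] = epsilon
        f_plus = net(Tensor(x + shift)).asnumpy()
        f_minus = net(Tensor(x - shift)).asnumpy()
        grad[idx] = (f_plus - f_minus) / (2 * epsilon)
    return grad

def gradient_error(x, net, epsilon):
    # relative error between the analytic gradient and the numeric one
    analytic = ops.GradOperation()(net)(x).asnumpy()
    numeric = numeric_gradient(x, net, epsilon)
    return onp.linalg.norm(analytic - numeric) / onp.linalg.norm(analytic + numeric)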
@pytest.mark.level0

View File

@@ -15,7 +15,7 @@
"""st for scipy.ops_wrapper."""
import pytest
import numpy as onp
import mindspore.scipy as msp
import mindspore.scipy.ops_wrapper as ops_wrapper
from mindspore import context, Tensor
from mindspore.scipy.ops import MatrixBandPartNet
from tests.st.scipy_st.utils import match_matrix
@@ -310,7 +310,7 @@ def test_matrix_set_diag(data_type):
    mask = banded_mat[0] == 0
    input_mat = onp.random.randint(10, size=mask.shape)
    expected_diag_matrix = input_mat * mask + banded_mat[0]
    output = msp.ops_wrapper.matrix_set_diag(
    output = ops_wrapper.matrix_set_diag(
        Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
    match_matrix(output, Tensor(expected_diag_matrix))
@@ -333,7 +333,7 @@ def test_graph_matrix_set_diag(data_type):
    mask = banded_mat[0] == 0
    input_mat = onp.random.randint(10, size=mask.shape)
    expected_diag_matrix = input_mat * mask + banded_mat[0]
    output = msp.ops_wrapper.matrix_set_diag(
    output = ops_wrapper.matrix_set_diag(
        Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
    match_matrix(output, Tensor(expected_diag_matrix))