add cholesky grad for backend cpu.
parent 35efb12908
commit 4fddf02b08
mindspore/scipy/__init__.py
@@ -14,4 +14,4 @@
 # ============================================================================
 """Scipy-like interfaces in mindspore."""

-from . import linalg, optimize, sparse, ops_wrapper
+from . import linalg, optimize, sparse
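(With ops_wrapper no longer re-exported from the mindspore.scipy package, the ops_wrapper tests at the end of this commit switch to importing mindspore.scipy.ops_wrapper directly.)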
mindspore/scipy/linalg.py
@@ -24,8 +24,8 @@ from .ops import EighNet
 from ..ops import operations as P
 from ..ops import functional as F
 from ..common import dtype as mstype
-from .utils import float_types
-from .utils_const import _raise_value_error, _type_check
+from .utils import float_types, valid_data_types
+from .utils_const import _raise_value_error, _raise_type_error, _type_check

 __all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh', 'lu_factor', 'lu']

@@ -287,7 +287,10 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
     """
     _type_check('overwrite_a', overwrite_a, bool, 'cho_factor')
     _type_check('check_finite', check_finite, bool, 'cho_factor')
-    if F.dtype(a) not in float_types:
+    a_type = F.dtype(a)
+    if a_type not in valid_data_types:
+        _raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
+    if a_type not in float_types:
         a = F.cast(a, mstype.float64)
     a_shape = a.shape
     if len(a_shape) != 2:
@@ -342,9 +345,11 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
     """
     _type_check('overwrite_a', overwrite_a, bool, 'cholesky')
     _type_check('check_finite', check_finite, bool, 'cholesky')
-    if F.dtype(a) not in float_types:
+    a_type = F.dtype(a)
+    if a_type not in valid_data_types:
+        _raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
+    if a_type not in float_types:
         a = F.cast(a, mstype.float64)

     a_shape = a.shape
     if len(a_shape) != 2:
         _raise_value_error("input a to mindspore.scipy.linalg.cholesky must have 2 dimensions.")
@@ -391,6 +396,11 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
     _type_check('overwrite_b', overwrite_b, bool, 'cho_solve')
     _type_check('check_finite', check_finite, bool, 'cho_solve')
     (c, lower) = c_and_lower
+    c_type = F.dtype(c)
+    if c_type not in valid_data_types:
+        _raise_type_error("mindspore.scipy.linalg.cholesky only support int32, int64, float32, float64.")
+    if c_type not in float_types:
+        c = F.cast(c, mstype.float64)
     cholesky_solver_net = CholeskySolver(lower=lower)
     x = cholesky_solver_net(c, b)
     return x
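With these checks in place, cho_factor, cholesky, and cho_solve share one dtype policy: int32/int64 inputs are cast to float64 before factorization, and anything outside int32/int64/float32/float64 is rejected up front. A minimal usage sketch of that behavior (not part of the commit; assumes a MindSpore build with the scipy module enabled):

import numpy as onp
from mindspore import Tensor
import mindspore.scipy.linalg as msp_linalg

# Integer input: accepted by the new check and cast to float64 internally.
a = onp.array([[4, 2], [2, 3]])                  # int64 on most platforms; SPD
l = msp_linalg.cholesky(Tensor(a), lower=True)
print(l.asnumpy().dtype)                         # float64

# Unsupported dtype: now rejected early with a clear message.
# msp_linalg.cholesky(Tensor(a.astype(onp.float16)))   # -> TypeError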
mindspore/scipy/ops.py
@@ -112,8 +112,8 @@ class Cholesky(PrimitiveWithInfer):
     def __init__(self, lower=False, clean=True, split_dim=0):
         super().__init__("Cholesky")
         self.init_prim_io_names(inputs=['a'], outputs=['l'])
-        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
-        self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
+        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
+        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
         self.lower = lower
         self.add_prim_attr('lower', self.lower)
         self.clean = clean
@@ -181,7 +181,7 @@ class CholeskySolver(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, lower=False):
         super().__init__(name="CholeskySolver")
-        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
+        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
         self.init_prim_io_names(inputs=['A', 'b'], outputs=['y'])

     def __infer__(self, A, b):
@@ -419,5 +419,6 @@ class MatrixDiagPartNet(nn.Cell):
     def construct(self, a, k, padding_value):
         return self.matrix_diag_part(a, k, padding_value)

 # pylint: disable=C0413,W0611
-from .ops_grad import get_bprpo_eigh, get_bprpo_trsm
+from .ops_grad import get_bprop_cholesky, get_bprpo_eigh, get_bprpo_trsm
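The validator tweaks above fix a latent bug rather than behavior: the final argument of validator.check_value_type is the primitive name used to format error messages, so passing self.lower and self.clean (attributes that are not even defined at that point in __init__) is replaced with self.name.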
mindspore/scipy/ops_grad.py
@@ -14,7 +14,8 @@
 # ============================================================================
 """Grad implementation of operators for scipy submodule"""
 from .. import numpy as mnp
-from .ops import Eigh, Eig, SolveTriangular
+from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
+from .ops_wrapper import matrix_set_diag
 from .. import dtype as mstype
 from ..ops import operations as P
 from ..ops import functional as F
@@ -23,6 +24,7 @@ from ..ops._grad.grad_base import bprop_getters
 _matmul = P.MatMul(False, False)
 _real = P.Real()
 _conj = P.Conj()
+_matrix_band_part = MatrixBandPart()


 def _compute_f(w, epsilon=1E-20):
@@ -46,6 +48,28 @@ def _matrix_solve(a, b):
     return a + b


+@bprop_getters.register(Cholesky)
+def get_bprop_cholesky(self):
+    """Grad definition for `Cholesky` operation."""
+    inverse = P.MatrixInverse()
+    matmul = P.MatMul()
+
+    def bprop(a, out, dout):
+        l = out
+        l_inverse = inverse(l)
+        dout_middle = matmul(_adjoint(l), dout)
+        middle_diag = 0.5 * mnp.diag(dout_middle)
+        dout_middle = matrix_set_diag(dout_middle, middle_diag)
+        dout_middle = _matrix_band_part(dout_middle, -1, 0)
+        grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
+        grad_a = mnp.tril(grad_a + _adjoint(grad_a))
+        middle_diag = 0.5 * mnp.diag(grad_a)
+        grad_a = matrix_set_diag(grad_a, middle_diag)
+        return (grad_a,)
+
+    return bprop
+
+
 @bprop_getters.register(Eig)
 def get_bprpo_eig(self):
     """Grad definition for `Eig` operation."""
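In math terms, with L = cholesky(a) and output cotangent W (the dout above), the new bprop computes P = Φ(LᵀW), where Φ keeps the lower triangle and halves the diagonal, maps it back through M = L⁻ᵀ P L⁻¹, and returns tril(M + Mᵀ) with the diagonal halved again. This is the standard reverse-mode rule for the Cholesky factorization. Below is a NumPy sketch (not part of the commit) that mirrors the rule for real inputs and checks one entry against central differences:

import numpy as np

def phi(x):
    """Lower triangle with the diagonal halved (the band_part + set_diag steps above)."""
    return np.tril(x) - 0.5 * np.diag(np.diag(x))

def cholesky_bprop(l, w):
    """Mirror of the bprop above for real matrices: l = cholesky(a), w = dout."""
    l_inv = np.linalg.inv(l)
    m = l_inv.T @ phi(l.T @ w) @ l_inv
    s = m + m.T
    return np.tril(s) - 0.5 * np.diag(np.diag(s))

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4))
a = x @ x.T + 4.0 * np.eye(4)                # symmetric positive definite
w = np.tril(rng.standard_normal((4, 4)))     # cotangent of the lower-triangular output
g = cholesky_bprop(np.linalg.cholesky(a), w)

# Check one entry against a central difference over a symmetric perturbation.
i, j, eps = 2, 0, 1e-6
e = np.zeros_like(a)
e[i, j] = e[j, i] = 1.0
f = lambda mat: np.sum(w * np.linalg.cholesky(mat))
fd = (f(a + eps * e) - f(a - eps * e)) / (2 * eps)
print(abs(fd - g[i, j]))                     # agrees to ~1e-9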
mindspore/scipy/utils.py
@@ -100,7 +100,8 @@ _FLOAT_ONE = _to_tensor(1.0)
 _FLOAT_TWO = _to_tensor(2.0, dtype=float)
 _BOOL_TRUE = _to_tensor(True)
 _BOOL_FALSE = _to_tensor(False)

+valid_data_types = (mstype.int32, mstype.int64, mstype.float32, mstype.float64)
 float_types = (mstype.float32, mstype.float64)
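valid_data_types is the whitelist consumed by the new dtype checks in linalg.py above; float_types, unchanged, still decides when an input must be cast up to float64.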
tests/st/scipy_st/test_ops.py
@@ -18,8 +18,42 @@ import numpy as onp
 import mindspore.nn as nn
 import mindspore.ops as ops
 from mindspore import context, Tensor
-from mindspore.scipy.ops import Eigh, SolveTriangular
-from tests.st.scipy_st.utils import create_random_rank_matrix, gradient_check
+from mindspore.scipy.ops import Eigh, Cholesky, SolveTriangular
+from tests.st.scipy_st.utils import create_random_rank_matrix, create_sym_pos_matrix, gradient_check


+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape', [(8, 8)])
+@pytest.mark.parametrize('data_type', [(onp.float32, 1e-2, 1e-3), (onp.float64, 1e-4, 1e-7)])
+def test_cholesky_grad(shape, data_type):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for grad implementation of cholesky operator in graph mode and pynative mode.
+    Expectation: the result match gradient checking.
+    """
+    onp.random.seed(0)
+    context.set_context(mode=context.GRAPH_MODE)
+    dtype, epsilon, error = data_type
+
+    class CholeskyNet(nn.Cell):
+        def __init__(self):
+            super(CholeskyNet, self).__init__()
+            self.mean = ops.ReduceMean()
+            self.cholesky = Cholesky(lower=True, clean=True)
+
+        def construct(self, a):
+            c = self.cholesky(a)
+            return self.mean(c)
+
+    cholesky_net = CholeskyNet()
+    a = create_sym_pos_matrix(shape, dtype)
+    cholesky_net(Tensor(a))
+    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+    context.set_context(mode=context.PYNATIVE_MODE)
+    cholesky_net(Tensor(a))
+    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+
+
 @pytest.mark.level0
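gradient_check comes from the shared test utilities rather than this commit; judging from the call sites above, it takes an input tensor, a scalar-output network, and a finite-difference step, and returns an error measure that the test bounds. A hypothetical sketch of such a checker (the name, metric, and internals here are assumptions, not the real tests/st/scipy_st/utils.py code):

import numpy as onp
import mindspore.ops as ops
from mindspore import Tensor

def fd_gradient_check(x, net, epsilon):
    """Max abs gap between autodiff and central-difference gradients of a
    scalar-output net. Hypothetical stand-in for the real gradient_check."""
    analytic = ops.GradOperation()(net)(x).asnumpy()   # reverse-mode gradient w.r.t. x
    a = x.asnumpy()
    numeric = onp.zeros_like(analytic)
    it = onp.nditer(a, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        a_plus, a_minus = a.copy(), a.copy()
        a_plus[idx] += epsilon
        a_minus[idx] -= epsilon
        numeric[idx] = (net(Tensor(a_plus)).asnumpy() -
                        net(Tensor(a_minus)).asnumpy()) / (2 * epsilon)
        it.iternext()
    return onp.abs(analytic - numeric).max()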
tests/st/scipy_st/test_ops_wrapper.py
@@ -15,7 +15,7 @@
 """st for scipy.ops_wrapper."""
 import pytest
 import numpy as onp
-import mindspore.scipy as msp
+import mindspore.scipy.ops_wrapper as ops_wrapper
 from mindspore import context, Tensor
 from mindspore.scipy.ops import MatrixBandPartNet
 from tests.st.scipy_st.utils import match_matrix
@@ -310,7 +310,7 @@ def test_matrix_set_diag(data_type):
     mask = banded_mat[0] == 0
     input_mat = onp.random.randint(10, size=mask.shape)
     expected_diag_matrix = input_mat * mask + banded_mat[0]
-    output = msp.ops_wrapper.matrix_set_diag(
+    output = ops_wrapper.matrix_set_diag(
         Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
     match_matrix(output, Tensor(expected_diag_matrix))

@@ -333,7 +333,7 @@ def test_graph_matrix_set_diag(data_type):
     mask = banded_mat[0] == 0
     input_mat = onp.random.randint(10, size=mask.shape)
     expected_diag_matrix = input_mat * mask + banded_mat[0]
-    output = msp.ops_wrapper.matrix_set_diag(
+    output = ops_wrapper.matrix_set_diag(
         Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
     match_matrix(output, Tensor(expected_diag_matrix))