!30967 fix Cholesky bprop implementation to support symmetric gradients

Merge pull request !30967 from zhuzhongrui/pub_master2
i-robot 2022-03-09 01:13:50 +00:00 committed by Gitee
commit f7068a46ea
3 changed files with 11 additions and 11 deletions


@@ -74,9 +74,7 @@ def get_bprop_cholesky(self):
         dout_middle = matrix_set_diag(dout_middle, middle_diag)
         dout_middle = _matrix_band_part(dout_middle, -1, 0)
         grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
-        grad_a = _matrix_band_part(grad_a + _adjoint(grad_a), -1, 0)
-        middle_diag = 0.5 * grad_a.diagonal(0, -2, -1)
-        grad_a = matrix_set_diag(grad_a, middle_diag)
+        grad_a = 0.5 * (grad_a + _adjoint(grad_a))
         return (grad_a,)
     return bprop

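Note on the fix: the old bprop forced the gradient into the lower triangle (band part plus halved diagonal), while the new line returns the symmetrized gradient 0.5 * (grad_a + _adjoint(grad_a)), matching the convention that the Cholesky input is a full symmetric matrix. A minimal NumPy/SciPy sketch of the same rule (the helper name cholesky_bprop is illustrative, not MindSpore API):

import numpy as np
from scipy.linalg import solve_triangular

def cholesky_bprop(l, dout):
    # Reverse-mode rule for L = cholesky(A): map the upstream gradient
    # dL (dout) onto dA, mirroring the ops in the hunk above.
    middle = l.T @ dout                              # _adjoint(l) @ dout
    middle[np.diag_indices_from(middle)] *= 0.5      # matrix_set_diag(..., 0.5 * diag)
    middle = np.tril(middle)                         # _matrix_band_part(..., -1, 0)
    l_inverse = solve_triangular(l, np.eye(l.shape[0]), lower=True)
    grad_a = l_inverse.T @ middle @ l_inverse
    return 0.5 * (grad_a + grad_a.T)                 # the new symmetrization step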


@@ -53,10 +53,10 @@ def test_cholesky_grad(shape, data_type):
     cholesky_net = CholeskyNet()
     a = create_sym_pos_matrix(shape, dtype)
     cholesky_net(Tensor(a))
-    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+    assert gradient_check(Tensor(a), cholesky_net, epsilon, symmetric=True) < error
     context.set_context(mode=context.PYNATIVE_MODE)
     cholesky_net(Tensor(a))
-    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+    assert gradient_check(Tensor(a), cholesky_net, epsilon, symmetric=True) < error


 @pytest.mark.level0

@@ -94,9 +94,9 @@ def test_cho_factor_grad(lower, shape, data_type):
     cho_factor_net = ChoFactorNet(lower)
     a = create_sym_pos_matrix(shape, dtype)
-    assert gradient_check(Tensor(a), cho_factor_net, epsilon, _enumerate_fn) < error
+    assert gradient_check(Tensor(a), cho_factor_net, epsilon, symmetric=True, enumerate_fn=_enumerate_fn) < error
     context.set_context(mode=context.PYNATIVE_MODE)
-    assert gradient_check(Tensor(a), cho_factor_net, epsilon, _enumerate_fn) < error
+    assert gradient_check(Tensor(a), cho_factor_net, epsilon, symmetric=True, enumerate_fn=_enumerate_fn) < error


 @pytest.mark.level0
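Why the tests now pass symmetric=True: a numerical gradient built by perturbing one entry at a time lands entirely in the lower triangle, because the Cholesky kernel reads only that triangle, while the symmetrized analytic gradient splits off-diagonal sensitivity across both triangles. A short self-contained NumPy demo of the mismatch (loss is a stand-in for the nets under test, not MindSpore code):

import numpy as np

def loss(a):
    # Stand-in scalar function of the Cholesky factor; np.linalg.cholesky
    # reads only the lower triangle of `a`, like the operator under test.
    return np.sum(np.linalg.cholesky(a))

a = np.array([[4.0, 1.0], [1.0, 3.0]])
eps, g = 1e-5, np.zeros((2, 2))
for (i, j), _ in np.ndenumerate(a):
    a_plus, a_minus = a.copy(), a.copy()
    a_plus[i, j] += eps
    a_minus[i, j] -= eps
    g[i, j] = (loss(a_plus) - loss(a_minus)) / (2 * eps)

# g has a zero upper triangle; 0.5 * (g + g.T) matches the symmetrized
# analytic gradient, which is exactly what symmetric=True compares against.
g_sym = 0.5 * (g + g.T)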


@@ -129,8 +129,8 @@ def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32):
     return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape)


-def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
-    # some utils
+def gradient_check(x, net, epsilon=1e-3, symmetric=False, enumerate_fn=onp.ndenumerate):
+    # Some utils
     def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]:
         return [_arg.asnumpy() for _arg in arg]

@@ -148,12 +148,12 @@ def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
     if isinstance(x, Tensor):
         x = [x]

-    # using automatic differentiation to calculate gradient
+    # Using automatic differentiation to calculate gradient
     grad_net = ops.GradOperation(get_all=True)(net)
     x_grad = grad_net(*x)
     x_grad = _tensor_to_numpy(x_grad)

-    # using the definition of a derivative to calculate gradient
+    # Using the definition of a derivative to calculate gradient
     x = _tensor_to_numpy(x)
     x_grad_approx = [onp.zeros_like(_x) for _x in x_grad]
     for outer, _x in enumerate(x):

@@ -168,6 +168,8 @@ def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
             x = _add_value(x, outer, inner, epsilon)
             x_grad_approx = _add_value(x_grad_approx, outer, inner, y_grad)

+    if symmetric:
+        x_grad_approx = [0.5 * (_x_grad + _x_grad.conj().T) for _x_grad in x_grad_approx]
     x_grad = _flatten(x_grad)
     x_grad_approx = _flatten(x_grad_approx)
     numerator = onp.linalg.norm(x_grad - x_grad_approx)
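For reference, the whole check reduces to the following single-array sketch. Central differences are one choice of estimator; `_add_value` and the denominator are not shown in the hunk, so the relative-error form below is an assumption, and gradient_check_sketch is a hypothetical condensation, not the MindSpore utility itself:

import numpy as onp

def gradient_check_sketch(x, f, grad_f, epsilon=1e-3, symmetric=False):
    # Compare an analytic gradient against a finite-difference estimate
    # and return a relative error that the tests bound by `error`.
    x_grad = grad_f(x)
    x_grad_approx = onp.zeros_like(x_grad)
    for idx, _ in onp.ndenumerate(x):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += epsilon
        x_minus[idx] -= epsilon
        x_grad_approx[idx] = (f(x_plus) - f(x_minus)) / (2 * epsilon)
    if symmetric:
        # The new branch: fold the estimate with its conjugate transpose so
        # it lives in the same symmetric convention as the fixed bprop.
        x_grad_approx = 0.5 * (x_grad_approx + x_grad_approx.conj().T)
    numerator = onp.linalg.norm(x_grad - x_grad_approx)
    denominator = onp.linalg.norm(x_grad) + onp.linalg.norm(x_grad_approx)
    return numerator / denominator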