forked from mindspore-Ecosystem/mindspore
!30967 fix Cholesky bprop implementation to support symmetric gradients
Merge pull request !30967 from zhuzhongrui/pub_master2
Commit f7068a46ea

@@ -74,9 +74,7 @@ def get_bprop_cholesky(self):
         dout_middle = matrix_set_diag(dout_middle, middle_diag)
         dout_middle = _matrix_band_part(dout_middle, -1, 0)
         grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
-        grad_a = _matrix_band_part(grad_a + _adjoint(grad_a), -1, 0)
-        middle_diag = 0.5 * grad_a.diagonal(0, -2, -1)
-        grad_a = matrix_set_diag(grad_a, middle_diag)
+        grad_a = 0.5 * (grad_a + _adjoint(grad_a))
         return (grad_a,)

     return bprop

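Note: the hunk above replaces the lower-triangular projection of the Cholesky gradient (band_part plus halved diagonal) with a plain symmetrization, so bprop now returns 0.5 * (grad_a + _adjoint(grad_a)), a symmetric matrix. A minimal standalone sketch in plain NumPy, assuming the standard Cholesky-gradient preamble for the lines not visible in this hunk (inverting L and projecting dout); cholesky_grad_symmetric, w, and v are illustrative names, not MindSpore code:

import numpy as np

def cholesky_grad_symmetric(a, dout):
    # Assumed preamble (not visible in the hunk): invert L and project dout,
    # then apply the visible steps: keep the lower band with a halved diagonal,
    # sandwich between L^{-T} and L^{-1}, and finally symmetrize as in the new line.
    l = np.linalg.cholesky(a)
    l_inv = np.linalg.inv(l)
    p = l.T @ dout
    p = np.tril(p, -1) + 0.5 * np.diag(np.diag(p))
    grad_a = l_inv.T @ p @ l_inv
    return 0.5 * (grad_a + grad_a.T)

rng = np.random.default_rng(0)
n = 5
m = rng.standard_normal((n, n))
a = m @ m.T + n * np.eye(n)                      # symmetric positive-definite input
w = np.tril(rng.standard_normal((n, n)))         # upstream gradient w.r.t. L
v = rng.standard_normal((n, n))
v = 0.5 * (v + v.T)                              # symmetric perturbation direction

loss = lambda x: float(np.sum(w * np.linalg.cholesky(x)))
eps = 1e-6
numeric = (loss(a + eps * v) - loss(a - eps * v)) / (2 * eps)
analytic = float(np.sum(cholesky_grad_symmetric(a, w) * v))
print(abs(numeric - analytic))                   # tiny: finite-difference error only
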
@@ -53,10 +53,10 @@ def test_cholesky_grad(shape, data_type):
     cholesky_net = CholeskyNet()
     a = create_sym_pos_matrix(shape, dtype)
     cholesky_net(Tensor(a))
-    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+    assert gradient_check(Tensor(a), cholesky_net, epsilon, symmetric=True) < error
     context.set_context(mode=context.PYNATIVE_MODE)
     cholesky_net(Tensor(a))
-    assert gradient_check(Tensor(a), cholesky_net, epsilon) < error
+    assert gradient_check(Tensor(a), cholesky_net, epsilon, symmetric=True) < error


 @pytest.mark.level0

@@ -94,9 +94,9 @@ def test_cho_factor_grad(lower, shape, data_type):
 
     cho_factor_net = ChoFactorNet(lower)
     a = create_sym_pos_matrix(shape, dtype)
-    assert gradient_check(Tensor(a), cho_factor_net, epsilon, _enumerate_fn) < error
+    assert gradient_check(Tensor(a), cho_factor_net, epsilon, symmetric=True, enumerate_fn=_enumerate_fn) < error
     context.set_context(mode=context.PYNATIVE_MODE)
-    assert gradient_check(Tensor(a), cho_factor_net, epsilon, _enumerate_fn) < error
+    assert gradient_check(Tensor(a), cho_factor_net, epsilon, symmetric=True, enumerate_fn=_enumerate_fn) < error


 @pytest.mark.level0

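Note on the call-site change: because the new symmetric parameter is inserted before enumerate_fn in the gradient_check signature (see the hunk further below), the previous positional call gradient_check(Tensor(a), cho_factor_net, epsilon, _enumerate_fn) would now bind _enumerate_fn to symmetric, so the test switches to keyword arguments. A tiny hypothetical stub (not the real utility) showing the binding difference:

import numpy as onp

def gradient_check(x, net, epsilon=1e-3, symmetric=False, enumerate_fn=onp.ndenumerate):
    # Stub: only reports how the optional arguments were bound.
    return {"symmetric": symmetric, "enumerate_fn": enumerate_fn}

# Old-style positional call: the enumerate function would land in `symmetric`.
print(gradient_check(None, None, 1e-3, onp.ndenumerate))
# Keyword call, as in the updated test.
print(gradient_check(None, None, 1e-3, symmetric=True, enumerate_fn=onp.ndenumerate))
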
@@ -129,8 +129,8 @@ def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32):
     return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape)


-def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
-    # some utils
+def gradient_check(x, net, epsilon=1e-3, symmetric=False, enumerate_fn=onp.ndenumerate):
+    # Some utils
     def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]:
         return [_arg.asnumpy() for _arg in arg]

@@ -148,12 +148,12 @@ def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
     if isinstance(x, Tensor):
         x = [x]

-    # using automatic differentiation to calculate gradient
+    # Using automatic differentiation to calculate gradient
     grad_net = ops.GradOperation(get_all=True)(net)
     x_grad = grad_net(*x)
     x_grad = _tensor_to_numpy(x_grad)

-    # using the definition of a derivative to calculate gradient
+    # Using the definition of a derivative to calculate gradient
     x = _tensor_to_numpy(x)
     x_grad_approx = [onp.zeros_like(_x) for _x in x_grad]
     for outer, _x in enumerate(x):

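For context on the "automatic differentiation" branch above: ops.GradOperation(get_all=True)(net) wraps a cell and returns a callable that yields the gradients with respect to every input, which gradient_check then converts to NumPy. A hedged sketch of that usage with a throwaway cell (SquareSumNet is an illustrative stand-in, not part of the test file):

import numpy as onp
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class SquareSumNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.reduce_sum = ops.ReduceSum()

    def construct(self, x):
        return self.reduce_sum(x * x)

net = SquareSumNet()
grad_net = ops.GradOperation(get_all=True)(net)        # same pattern as in gradient_check
x = Tensor(onp.array([0.0, 1.0, 2.0], dtype=onp.float32))
print(grad_net(x))                                      # gradient of sum(x*x) is 2*x: (0., 2., 4.)
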
@@ -168,6 +168,8 @@ def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate):
             x = _add_value(x, outer, inner, epsilon)
             x_grad_approx = _add_value(x_grad_approx, outer, inner, y_grad)

+    if symmetric:
+        x_grad_approx = [0.5 * (_x_grad + _x_grad.conj().T) for _x_grad in x_grad_approx]
     x_grad = _flatten(x_grad)
     x_grad_approx = _flatten(x_grad_approx)
     numerator = onp.linalg.norm(x_grad - x_grad_approx)
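Note on the new symmetric branch: gradient_check perturbs entries one at a time via enumerate_fn, and for an op that (like NumPy's cholesky) reads only the lower triangle of its input, that entry-wise estimate is supported on the lower triangle and is not symmetric, whereas the updated bprop returns a symmetric gradient. Symmetrizing the estimate with 0.5 * (g + g.conj().T), exactly as the added lines do, makes the two comparable. A standalone NumPy sketch of that effect (loss, w, and v are illustrative stand-ins, not test code):

import numpy as np

rng = np.random.default_rng(1)
n = 4
m = rng.standard_normal((n, n))
a = m @ m.T + n * np.eye(n)                                # symmetric positive-definite input
w = np.tril(rng.standard_normal((n, n)))
loss = lambda x: float(np.sum(w * np.linalg.cholesky(x)))  # np.linalg.cholesky reads only the lower triangle

# Entry-wise central differences, mirroring gradient_check with onp.ndenumerate.
eps = 1e-5
g_fd = np.zeros_like(a)
for (i, j), _ in np.ndenumerate(a):
    e = np.zeros_like(a)
    e[i, j] = eps
    g_fd[i, j] = (loss(a + e) - loss(a - e)) / (2 * eps)

print(np.linalg.norm(np.triu(g_fd, 1)))                    # 0.0: upper-triangle entries never reach the loss

# The symmetrization applied under symmetric=True.
g_sym = 0.5 * (g_fd + g_fd.conj().T)

# g_sym reproduces directional derivatives along symmetric perturbations,
# which is the convention of the new bprop output.
v = rng.standard_normal((n, n))
v = 0.5 * (v + v.T)
numeric = (loss(a + eps * v) - loss(a - eps * v)) / (2 * eps)
print(abs(numeric - float(np.sum(g_sym * v))))             # tiny: finite-difference error only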