!30664 Fix grads of Eigh and SolveTriangular operator in PYNATIVE mode.

Merge pull request !30664 from hezhenhao1/fix_grad
This commit is contained in:
i-robot 2022-03-01 01:19:07 +00:00 committed by Gitee
commit 545a6de696
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 14 additions and 12 deletions

View File

@ -121,15 +121,14 @@ def get_bprpo_eigh(self):
_raise_type_error(
"For 'Eigh' operation, the data type of input 'a' don't support the complex64 or complex128.")
if not is_compute_v:
w, grad_w = out, dout
_, v = eigh(a)
grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
grad_a = _matmul(v * F.expand_dims(dout, -2), _adjoint(v))
else:
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
vh_gv = _matmul(_adjoint(v), grad_v)
f = _compute_f(w)
mid_part = _diag(grad_w) + f * vh_gv
grad_a = _matmul(v, _matmul(mid_part, _adjoint(v)))
vh = _adjoint(out[1])
vh_gv = _matmul(vh, dout[1])
f = _compute_f(out[0])
mid_part = _diag(dout[0]) + f * vh_gv
grad_a = _matmul(out[1], _matmul(mid_part, vh))
# The forward implementation only focus on lower part or upper part,
# so we only retain the corresponding part.
@ -154,11 +153,10 @@ def get_bprpo_trsm(self):
solve_triangular = SolveTriangular(is_lower, is_unit_diagonal, bp_trans)
def bprop(a, b, out, dout):
x, grad_x = out, dout
row_size = F.shape(a)[-2]
grad_b = solve_triangular(a, grad_x)
grad_b = solve_triangular(a, dout)
grad_b_align = F.reshape(grad_b, (row_size, -1))
x_align = F.reshape(x, (row_size, -1))
x_align = F.reshape(out, (row_size, -1))
if bp_trans in ["T", "C"]:
grad_a = _matmul(grad_b_align, _adjoint(x_align))
else:

View File

@ -152,7 +152,6 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
Expectation: the result match gradient checking.
"""
onp.random.seed(0)
context.set_context(mode=context.GRAPH_MODE)
dtype, epsilon, error = data_type
class Net(nn.Cell):
@ -176,6 +175,9 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
net = Net()
a = create_random_rank_matrix(shape, dtype)
context.set_context(mode=context.GRAPH_MODE)
assert gradient_check(Tensor(a), net, epsilon) < error
context.set_context(mode=context.PYNATIVE_MODE)
assert gradient_check(Tensor(a), net, epsilon) < error
@ -196,7 +198,6 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
"""
a_shape, b_shape = shapes
onp.random.seed(0)
context.set_context(mode=context.GRAPH_MODE)
dtype, epsilon, error = data_type
class Net(nn.Cell):
@ -213,4 +214,7 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
net = Net()
a = (onp.random.random(a_shape) + onp.eye(a_shape[-1])).astype(dtype)
b = onp.random.random(b_shape).astype(dtype)
context.set_context(mode=context.GRAPH_MODE)
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
context.set_context(mode=context.PYNATIVE_MODE)
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error