forked from mindspore-Ecosystem/mindspore
!30664 Fix grads of Eigh and SolveTriangular operators in PYNATIVE mode.
Merge pull request !30664 from hezhenhao1/fix_grad
This commit is contained in:
commit 545a6de696
@@ -121,15 +121,14 @@ def get_bprpo_eigh(self):
             _raise_type_error(
                 "For 'Eigh' operation, the data type of input 'a' don't support the complex64 or complex128.")
         if not is_compute_v:
-            w, grad_w = out, dout
             _, v = eigh(a)
-            grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
+            grad_a = _matmul(v * F.expand_dims(dout, -2), _adjoint(v))
         else:
-            w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
-            vh_gv = _matmul(_adjoint(v), grad_v)
-            f = _compute_f(w)
-            mid_part = _diag(grad_w) + f * vh_gv
-            grad_a = _matmul(v, _matmul(mid_part, _adjoint(v)))
+            vh = _adjoint(out[1])
+            vh_gv = _matmul(vh, dout[1])
+            f = _compute_f(out[0])
+            mid_part = _diag(dout[0]) + f * vh_gv
+            grad_a = _matmul(out[1], _matmul(mid_part, vh))
 
         # The forward implementation only focus on lower part or upper part,
         # so we only retain the corresponding part.
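Reviewer note: the rewritten branches compute the same eigendecomposition VJP as before, grad_a = V · (diag(dout[0]) + F ∘ (Vᴴ · dout[1])) · Vᴴ with V = out[1], w = out[0] and F[i, j] = 1 / (w[j] - w[i]) off the diagonal; the only change is that out and dout are indexed directly instead of being unpacked into local names, which appears to be what the PYNATIVE-mode fix amounts to. For reference, a minimal NumPy sketch of that formula follows; eigh_vjp and its internals are illustrative stand-ins (not MindSpore's _compute_f/_diag/_adjoint helpers), and it assumes a real symmetric input with distinct eigenvalues.

import numpy as np

def eigh_vjp(a, dw, dv):
    """VJP of (w, v) = np.linalg.eigh(a) for cotangents (dw, dv)."""
    w, v = np.linalg.eigh(a)
    vh = v.conj().T                         # plays the role of _adjoint(out[1])
    diff = w[None, :] - w[:, None]          # diff[i, j] = w[j] - w[i]
    np.fill_diagonal(diff, 1.0)             # avoid dividing by zero on the diagonal
    f = 1.0 / diff                          # plays the role of _compute_f(out[0])
    np.fill_diagonal(f, 0.0)
    mid_part = np.diag(dw) + f * (vh @ dv)  # cf. _diag(dout[0]) + f * vh_gv
    return v @ mid_part @ vh                # grad_a before the lower/upper masking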
@@ -154,11 +153,10 @@ def get_bprpo_trsm(self):
     solve_triangular = SolveTriangular(is_lower, is_unit_diagonal, bp_trans)
 
     def bprop(a, b, out, dout):
-        x, grad_x = out, dout
         row_size = F.shape(a)[-2]
-        grad_b = solve_triangular(a, grad_x)
+        grad_b = solve_triangular(a, dout)
         grad_b_align = F.reshape(grad_b, (row_size, -1))
-        x_align = F.reshape(x, (row_size, -1))
+        x_align = F.reshape(out, (row_size, -1))
         if bp_trans in ["T", "C"]:
             grad_a = _matmul(grad_b_align, _adjoint(x_align))
         else:
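Reviewer note: same pattern for the triangular solve. grad_b comes from solving the adjoint system with the pre-built SolveTriangular(is_lower, is_unit_diagonal, bp_trans), and the removed x/grad_x aliases are replaced by out and dout directly. A minimal SciPy sketch of this VJP for the lower-triangular, trans='N' case is below; trsm_vjp is a hypothetical name, the snippet assumes 2-D b and dout, and it is not the MindSpore bprop itself.

import numpy as np
from scipy.linalg import solve_triangular

def trsm_vjp(a, b, dout):
    """VJP of x = solve_triangular(a, b, lower=True) with 2-D b and dout."""
    x = solve_triangular(a, b, lower=True)                      # forward result (== out)
    grad_b = solve_triangular(a, dout, lower=True, trans='T')   # solves a^T grad_b = dout
    grad_a = -np.tril(grad_b @ x.T)                             # only tril(a) enters the solve
    return grad_a, grad_b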
@@ -152,7 +152,6 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
     Expectation: the result match gradient checking.
     """
     onp.random.seed(0)
-    context.set_context(mode=context.GRAPH_MODE)
     dtype, epsilon, error = data_type
 
     class Net(nn.Cell):
@@ -176,6 +175,9 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
 
     net = Net()
     a = create_random_rank_matrix(shape, dtype)
+    context.set_context(mode=context.GRAPH_MODE)
     assert gradient_check(Tensor(a), net, epsilon) < error
+    context.set_context(mode=context.PYNATIVE_MODE)
+    assert gradient_check(Tensor(a), net, epsilon) < error
 
 
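Reviewer note: the test now runs the same check in both GRAPH_MODE and PYNATIVE_MODE instead of pinning GRAPH_MODE once at the top of the function, which is exactly the execution mode this fix targets. For readers unfamiliar with the helper, gradient_check presumably compares the network's autodiff gradient against a finite-difference estimate; the sketch below only illustrates that idea with a central difference and is not the actual test utility.

import numpy as np

def finite_diff_grad(fn, a, epsilon):
    """Central-difference estimate of d(sum(fn(a)))/da for a NumPy array a."""
    grad = np.zeros_like(a)
    it = np.nditer(a, flags=['multi_index'])
    for _ in it:
        idx = it.multi_index
        a_plus, a_minus = a.copy(), a.copy()
        a_plus[idx] += epsilon
        a_minus[idx] -= epsilon
        grad[idx] = (np.sum(fn(a_plus)) - np.sum(fn(a_minus))) / (2 * epsilon)
    return grad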
@@ -196,7 +198,6 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
     """
     a_shape, b_shape = shapes
     onp.random.seed(0)
-    context.set_context(mode=context.GRAPH_MODE)
     dtype, epsilon, error = data_type
 
     class Net(nn.Cell):
@@ -213,4 +214,7 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
     net = Net()
     a = (onp.random.random(a_shape) + onp.eye(a_shape[-1])).astype(dtype)
     b = onp.random.random(b_shape).astype(dtype)
+    context.set_context(mode=context.GRAPH_MODE)
     assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
+    context.set_context(mode=context.PYNATIVE_MODE)
+    assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
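Reviewer note on the test data: adding onp.eye to the random matrix keeps every diagonal entry of the triangle used by the solve in [1, 2), so the system stays well conditioned under the epsilon-sized perturbations of the gradient check. A standalone illustration of that property (not part of the test itself):

import numpy as onp

onp.random.seed(0)
a = onp.random.random((10, 10)) + onp.eye(10)
# uniform [0, 1) entries plus the identity => every diagonal entry lies in [1, 2)
assert onp.all(onp.diag(a) >= 1.0)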