From a141c8c8e3fe8385ffa4dc6a85b489012dc85e23 Mon Sep 17 00:00:00 2001
From: hezhenhao1
Date: Mon, 28 Feb 2022 17:31:32 +0800
Subject: [PATCH] Fix grads of Eigh and SolveTriangular operators in PYNATIVE mode.

---
 mindspore/python/mindspore/scipy/ops_grad.py | 18 ++++++++----------
 .../{test_ops_grad.py => test_grad.py}       |  8 ++++++--
 2 files changed, 14 insertions(+), 12 deletions(-)
 rename tests/st/scipy_st/{test_ops_grad.py => test_grad.py} (97%)

diff --git a/mindspore/python/mindspore/scipy/ops_grad.py b/mindspore/python/mindspore/scipy/ops_grad.py
index 0fad7a87aa5..9b00d05f3c1 100644
--- a/mindspore/python/mindspore/scipy/ops_grad.py
+++ b/mindspore/python/mindspore/scipy/ops_grad.py
@@ -121,15 +121,14 @@ def get_bprpo_eigh(self):
             _raise_type_error(
                 "For 'Eigh' operation, the data type of input 'a' don't support the complex64 or complex128.")
         if not is_compute_v:
-            w, grad_w = out, dout
             _, v = eigh(a)
-            grad_a = _matmul(v * F.expand_dims(grad_w, -2), _adjoint(v))
+            grad_a = _matmul(v * F.expand_dims(dout, -2), _adjoint(v))
         else:
-            w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
-            vh_gv = _matmul(_adjoint(v), grad_v)
-            f = _compute_f(w)
-            mid_part = _diag(grad_w) + f * vh_gv
-            grad_a = _matmul(v, _matmul(mid_part, _adjoint(v)))
+            vh = _adjoint(out[1])
+            vh_gv = _matmul(vh, dout[1])
+            f = _compute_f(out[0])
+            mid_part = _diag(dout[0]) + f * vh_gv
+            grad_a = _matmul(out[1], _matmul(mid_part, vh))

         # The forward implementation only focus on lower part or upper part,
         # so we only retain the corresponding part.
@@ -154,11 +153,10 @@ def get_bprpo_trsm(self):
     solve_triangular = SolveTriangular(is_lower, is_unit_diagonal, bp_trans)

     def bprop(a, b, out, dout):
-        x, grad_x = out, dout
         row_size = F.shape(a)[-2]
-        grad_b = solve_triangular(a, grad_x)
+        grad_b = solve_triangular(a, dout)
         grad_b_align = F.reshape(grad_b, (row_size, -1))
-        x_align = F.reshape(x, (row_size, -1))
+        x_align = F.reshape(out, (row_size, -1))
         if bp_trans in ["T", "C"]:
             grad_a = _matmul(grad_b_align, _adjoint(x_align))
         else:
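Note on the hunks above: the bprops now index `out` and `dout` directly instead
of unpacking them into local aliases, which the commit message ties to gradient
correctness in PYNATIVE mode. As a sanity check, here is a minimal NumPy sketch
(all names illustrative, not MindSpore's API) of the eigenvalues-only rule
grad_a = V @ diag(dout) @ V^H that the first branch implements, validated
against a central-difference directional derivative:

import numpy as np

def eigh_bprop_eigenvalues_only(a, dout):
    # Mirrors `grad_a = _matmul(v * F.expand_dims(dout, -2), _adjoint(v))`:
    # scaling column k of v by dout[k] avoids materializing diag(dout).
    _, v = np.linalg.eigh(a)
    return (v * dout[None, :]) @ v.conj().T

rng = np.random.default_rng(0)
n = 6
a = rng.standard_normal((n, n))
a = (a + a.T) / 2                      # symmetric input
s = rng.standard_normal((n, n))
s = (s + s.T) / 2                      # symmetric perturbation direction

def loss(x):
    return np.sum(np.linalg.eigvalsh(x) ** 2)   # scalar loss on eigenvalues

w = np.linalg.eigvalsh(a)
grad_a = eigh_bprop_eigenvalues_only(a, 2 * w)  # dL/dw = 2w for sum(w**2)

eps = 1e-6
num = (loss(a + eps * s) - loss(a - eps * s)) / (2 * eps)
assert np.isclose(num, np.sum(grad_a * s), rtol=1e-5)
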
""" onp.random.seed(0) - context.set_context(mode=context.GRAPH_MODE) dtype, epsilon, error = data_type class Net(nn.Cell): @@ -176,6 +175,9 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type): net = Net() a = create_random_rank_matrix(shape, dtype) + context.set_context(mode=context.GRAPH_MODE) + assert gradient_check(Tensor(a), net, epsilon) < error + context.set_context(mode=context.PYNATIVE_MODE) assert gradient_check(Tensor(a), net, epsilon) < error @@ -196,7 +198,6 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type): """ a_shape, b_shape = shapes onp.random.seed(0) - context.set_context(mode=context.GRAPH_MODE) dtype, epsilon, error = data_type class Net(nn.Cell): @@ -213,4 +214,7 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type): net = Net() a = (onp.random.random(a_shape) + onp.eye(a_shape[-1])).astype(dtype) b = onp.random.random(b_shape).astype(dtype) + context.set_context(mode=context.GRAPH_MODE) + assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error + context.set_context(mode=context.PYNATIVE_MODE) assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error