Separate test_trsm_grad case.

This commit is contained in:
hezhenhao1 2022-05-09 14:49:34 +08:00
parent 79e9d91a5c
commit c7452b1c9a
1 changed files with 39 additions and 4 deletions

View File

@ -190,10 +190,47 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [True, False])
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-3, 1e-3), (onp.float64, 1e-4, 1e-7)])
def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
def test_trsm_grad_pynative(shapes, trans, lower, unit_diagonal, data_type):
"""
Feature: ALL TO ALL
Description: test cases for grad implementation of SolveTriangular operator
Description: test cases for grad implementation of SolveTriangular operator in PYNATIVE mode.
Expectation: the result match gradient checking.
"""
a_shape, b_shape = shapes
onp.random.seed(0)
dtype, epsilon, error = data_type
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.mean = ops.ReduceMean()
self.sum = ops.ReduceSum()
self.trsm = SolveTriangular(lower, unit_diagonal, trans)
def construct(self, a, b):
x = self.trsm(a, b)
return self.sum(x) + self.mean(x)
net = Net()
a = (onp.random.random(a_shape) + onp.eye(a_shape[-1])).astype(dtype)
b = onp.random.random(b_shape).astype(dtype)
context.set_context(mode=context.PYNATIVE_MODE)
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shapes', [((8, 8), (8, 8)), ((8, 8), (8, 2)), ((8, 8), (8,))])
@pytest.mark.parametrize('trans', ["N", "T", "C"])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [True, False])
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-3, 1e-3), (onp.float64, 1e-4, 1e-7)])
def test_trsm_grad_graph(shapes, trans, lower, unit_diagonal, data_type):
"""
Feature: ALL TO ALL
Description: test cases for grad implementation of SolveTriangular operator in GRAPH mode.
Expectation: the result match gradient checking.
"""
a_shape, b_shape = shapes
@ -216,5 +253,3 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
b = onp.random.random(b_shape).astype(dtype)
context.set_context(mode=context.GRAPH_MODE)
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
context.set_context(mode=context.PYNATIVE_MODE)
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error