forked from mindspore-Ecosystem/mindspore
Separate test_trsm_grad case.
This commit is contained in:
parent
79e9d91a5c
commit
c7452b1c9a
|
@ -190,10 +190,47 @@ def test_eigh_grad(compute_eigenvectors, lower, shape, data_type):
|
|||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [True, False])
|
||||
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-3, 1e-3), (onp.float64, 1e-4, 1e-7)])
|
||||
def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
|
||||
def test_trsm_grad_pynative(shapes, trans, lower, unit_diagonal, data_type):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for grad implementation of SolveTriangular operator
|
||||
Description: test cases for grad implementation of SolveTriangular operator in PYNATIVE mode.
|
||||
Expectation: the result match gradient checking.
|
||||
"""
|
||||
a_shape, b_shape = shapes
|
||||
onp.random.seed(0)
|
||||
dtype, epsilon, error = data_type
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.mean = ops.ReduceMean()
|
||||
self.sum = ops.ReduceSum()
|
||||
self.trsm = SolveTriangular(lower, unit_diagonal, trans)
|
||||
|
||||
def construct(self, a, b):
|
||||
x = self.trsm(a, b)
|
||||
return self.sum(x) + self.mean(x)
|
||||
|
||||
net = Net()
|
||||
a = (onp.random.random(a_shape) + onp.eye(a_shape[-1])).astype(dtype)
|
||||
b = onp.random.random(b_shape).astype(dtype)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shapes', [((8, 8), (8, 8)), ((8, 8), (8, 2)), ((8, 8), (8,))])
|
||||
@pytest.mark.parametrize('trans', ["N", "T", "C"])
|
||||
@pytest.mark.parametrize('lower', [False, True])
|
||||
@pytest.mark.parametrize('unit_diagonal', [True, False])
|
||||
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-3, 1e-3), (onp.float64, 1e-4, 1e-7)])
|
||||
def test_trsm_grad_graph(shapes, trans, lower, unit_diagonal, data_type):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for grad implementation of SolveTriangular operator in GRAPH mode.
|
||||
Expectation: the result match gradient checking.
|
||||
"""
|
||||
a_shape, b_shape = shapes
|
||||
|
@ -216,5 +253,3 @@ def test_trsm_grad(shapes, trans, lower, unit_diagonal, data_type):
|
|||
b = onp.random.random(b_shape).astype(dtype)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
assert gradient_check([Tensor(a), Tensor(b)], net, epsilon) < error
|
||||
|
|
Loading…
Reference in New Issue