forked from mindspore-Ecosystem/mindspore
add gpu cholesky_grad st case
This commit is contained in:
parent
9390d67adf
commit
1e5d22a7fa
|
@ -114,7 +114,6 @@ class Cholesky(PrimitiveWithInfer):
|
|||
super().__init__("Cholesky")
|
||||
self.init_prim_io_names(inputs=['a'], outputs=['l'])
|
||||
self.clean = validator.check_value_type("clean", clean, [bool], self.name)
|
||||
self.clean = clean
|
||||
self.add_prim_attr('clean', self.clean)
|
||||
|
||||
def __infer__(self, a):
|
||||
|
|
|
@ -24,6 +24,7 @@ from tests.st.scipy_st.utils import create_random_rank_matrix, create_sym_pos_ma
|
|||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('shape', [(8, 8)])
|
||||
@pytest.mark.parametrize('data_type', [(onp.float32, 1e-2, 1e-3), (onp.float64, 1e-4, 1e-7)])
|
||||
|
@ -41,7 +42,7 @@ def test_cholesky_grad(shape, data_type):
|
|||
def __init__(self):
|
||||
super(CholeskyNet, self).__init__()
|
||||
self.mean = ops.ReduceMean()
|
||||
# args clean not supports grad right now, just default to clean.
|
||||
# Input arg clean not supports grad right now, just default clean to True.
|
||||
self.cholesky = Cholesky(clean=True)
|
||||
|
||||
def construct(self, a):
|
||||
|
|
Loading…
Reference in New Issue