From 653efcadf6547979529668ee3519a5c8b7d43d6b Mon Sep 17 00:00:00 2001 From: z00512249 Date: Wed, 27 Apr 2022 10:27:43 +0800 Subject: [PATCH] fix log_matrix_determinant accuracy for gpu backend. --- tests/st/ops/gpu/test_matrix_determinant_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/st/ops/gpu/test_matrix_determinant_op.py b/tests/st/ops/gpu/test_matrix_determinant_op.py index 4de314c91c2..ab1e49edb81 100644 --- a/tests/st/ops/gpu/test_matrix_determinant_op.py +++ b/tests/st/ops/gpu/test_matrix_determinant_op.py @@ -98,12 +98,14 @@ def test_log_matrix_determinant(data_shape, data_type): context.set_context(mode=context.GRAPH_MODE) input_x = np.random.random(data_shape).astype(data_type) error = 1e-6 + if data_type == np.float32: + error = 1e-4 benchmark_output = log_matrix_determinant_np_benchmark(input_x) log_matrix_determinant = LogMatrixDeterminantNet() output = log_matrix_determinant(Tensor(input_x)) - np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error) - np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error) + np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error) + np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error) context.set_context(mode=context.PYNATIVE_MODE) output = log_matrix_determinant(Tensor(input_x)) - np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error) - np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error) + np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error) + np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)