fix log_matrix_determinant accuracy for gpu backend.

This commit is contained in:
z00512249 2022-04-27 10:27:43 +08:00
parent f49dd1f0cd
commit 653efcadf6
1 changed files with 6 additions and 4 deletions

View File

@ -98,12 +98,14 @@ def test_log_matrix_determinant(data_shape, data_type):
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
input_x = np.random.random(data_shape).astype(data_type) input_x = np.random.random(data_shape).astype(data_type)
error = 1e-6 error = 1e-6
if data_type == np.float32:
error = 1e-4
benchmark_output = log_matrix_determinant_np_benchmark(input_x) benchmark_output = log_matrix_determinant_np_benchmark(input_x)
log_matrix_determinant = LogMatrixDeterminantNet() log_matrix_determinant = LogMatrixDeterminantNet()
output = log_matrix_determinant(Tensor(input_x)) output = log_matrix_determinant(Tensor(input_x))
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error) np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error) np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
output = log_matrix_determinant(Tensor(input_x)) output = log_matrix_determinant(Tensor(input_x))
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error) np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error) np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)