!33615 fix log_matrix_determinant accuracy for gpu backend.

Merge pull request !33615 from zhuzhongrui/pub_master2
This commit is contained in:
i-robot 2022-04-27 05:55:57 +00:00 committed by Gitee
commit 71d55684e1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 4 deletions

View File

@ -98,12 +98,14 @@ def test_log_matrix_determinant(data_shape, data_type):
context.set_context(mode=context.GRAPH_MODE)
input_x = np.random.random(data_shape).astype(data_type)
error = 1e-6
if data_type == np.float32:
error = 1e-4
benchmark_output = log_matrix_determinant_np_benchmark(input_x)
log_matrix_determinant = LogMatrixDeterminantNet()
output = log_matrix_determinant(Tensor(input_x))
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
context.set_context(mode=context.PYNATIVE_MODE)
output = log_matrix_determinant(Tensor(input_x))
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)