forked from mindspore-Ecosystem/mindspore
fix log_matrix_determinant accuracy for gpu backend.
This commit is contained in:
parent
f49dd1f0cd
commit
653efcadf6
|
@ -98,12 +98,14 @@ def test_log_matrix_determinant(data_shape, data_type):
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
input_x = np.random.random(data_shape).astype(data_type)
|
||||
error = 1e-6
|
||||
if data_type == np.float32:
|
||||
error = 1e-4
|
||||
benchmark_output = log_matrix_determinant_np_benchmark(input_x)
|
||||
log_matrix_determinant = LogMatrixDeterminantNet()
|
||||
output = log_matrix_determinant(Tensor(input_x))
|
||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
|
||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
|
||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
|
||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
output = log_matrix_determinant(Tensor(input_x))
|
||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
|
||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
|
||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
|
||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
|
||||
|
|
Loading…
Reference in New Issue