forked from mindspore-Ecosystem/mindspore
fix log_matrix_determinant accuracy for gpu backend.
This commit is contained in:
parent
f49dd1f0cd
commit
653efcadf6
|
@ -98,12 +98,14 @@ def test_log_matrix_determinant(data_shape, data_type):
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
input_x = np.random.random(data_shape).astype(data_type)
|
input_x = np.random.random(data_shape).astype(data_type)
|
||||||
error = 1e-6
|
error = 1e-6
|
||||||
|
if data_type == np.float32:
|
||||||
|
error = 1e-4
|
||||||
benchmark_output = log_matrix_determinant_np_benchmark(input_x)
|
benchmark_output = log_matrix_determinant_np_benchmark(input_x)
|
||||||
log_matrix_determinant = LogMatrixDeterminantNet()
|
log_matrix_determinant = LogMatrixDeterminantNet()
|
||||||
output = log_matrix_determinant(Tensor(input_x))
|
output = log_matrix_determinant(Tensor(input_x))
|
||||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
|
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
|
||||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
|
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
output = log_matrix_determinant(Tensor(input_x))
|
output = log_matrix_determinant(Tensor(input_x))
|
||||||
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error)
|
np.testing.assert_allclose(output[0].asnumpy(), benchmark_output[0], rtol=error, atol=error)
|
||||||
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error)
|
np.testing.assert_allclose(output[1].asnumpy(), benchmark_output[1], rtol=error, atol=error)
|
||||||
|
|
Loading…
Reference in New Issue