forked from mindspore-Ecosystem/mindspore
matrix_power:remove the test case for n < 0
This commit is contained in:
parent
4a8d98014d
commit
1dcab90aa5
|
@ -44,7 +44,7 @@ def test_matrix_power(mode):
|
||||||
net_matrix_power = NetMatrixPower()
|
net_matrix_power = NetMatrixPower()
|
||||||
|
|
||||||
for arr in arrs:
|
for arr in arrs:
|
||||||
for n in range(-2, 4):
|
for n in range(0, 4):
|
||||||
expect_out = np.linalg.matrix_power(arr, n)
|
expect_out = np.linalg.matrix_power(arr, n)
|
||||||
out = net_matrix_power(ms.Tensor(arr), n)
|
out = net_matrix_power(ms.Tensor(arr), n)
|
||||||
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
||||||
|
|
|
@ -43,7 +43,7 @@ def test_matrix_power(mode):
|
||||||
net_matrix_power = NetMatrixPower()
|
net_matrix_power = NetMatrixPower()
|
||||||
|
|
||||||
for arr in arrs:
|
for arr in arrs:
|
||||||
for n in range(-2, 4):
|
for n in range(0, 4):
|
||||||
expect_out = np.linalg.matrix_power(arr, n)
|
expect_out = np.linalg.matrix_power(arr, n)
|
||||||
out = net_matrix_power(ms.Tensor(arr), n)
|
out = net_matrix_power(ms.Tensor(arr), n)
|
||||||
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
assert np.allclose(out.asnumpy(), expect_out, rtol=1e-4, atol=1e-4)
|
||||||
|
|
Loading…
Reference in New Issue