This commit is contained in:
zong-shuai 2022-02-25 14:12:07 +08:00
parent 4a49f43128
commit 5547c8620c
1 changed files with 14 additions and 15 deletions

View File

@ -46,7 +46,7 @@ def einsum_test_cases(nptype, loss):
test_cases = [["abcd->dacb", [[2, 3, 1, 1]]],
["ijk->ik", [[1, 2, 3]]],
["ij,ij->ij", [[2, 3], [2, 3]]],
["ij,kl->ijkl", [[1, 2], [3, 4]]],
["ij,kl->ijkl", [[3, 2], [2, 3]]],
["ij,jk->ik", [[3, 2], [2, 3]]]
]
for cur_case in test_cases:
@ -55,29 +55,17 @@ def einsum_test_cases(nptype, loss):
ms_data = []
np_data = []
for cur_shape in shapes:
cur_data = np.random.randn(*cur_shape).astype(np.float64)
cur_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(cur_shape).astype(np.float64)
ms_data.append(Tensor(cur_data.astype(nptype)))
np_data.append(cur_data)
net = Einsum(equation)
ms_out = net(*ms_data)
np_out = np.einsum(equation, *np_data)
assert np.allclose(ms_out.asnumpy(), np_out.astype(nptype), loss, loss)
np_dout = np.random.randn(*np_out.shape).astype(nptype)
grad_net = EinsumGrad(equation)
ms_dx = grad_net(*ms_data, Tensor(np_dout))
ms_dx = grad_net(*ms_data, Tensor(np_out.astype(nptype)))
print(ms_dx)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_einsum_graph_float32():
"""
Feature: test transpose/ reduce_sum/dot/mul/transpose_with_ell/batchmatmul
Description: test the accuracy and precision of the preceding test cases in float32 types
Expectation: the diff between the result and the operator of np.einsum is within the loss range
"""
einsum_test_cases(np.float32, 1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -89,6 +77,17 @@ def test_einsum_graph_float16():
"""
einsum_test_cases(np.float16, 1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_einsum_graph_float32():
"""
Feature: test transpose/ reduce_sum/dot/mul/transpose_with_ell/batchmatmul
Description: test the accuracy and precision of the preceding test cases in float32 types
Expectation: the diff between the result and the operator of np.einsum is within the loss range
"""
einsum_test_cases(np.float32, 1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard