Add test case for hypermap

This commit is contained in:
l00591931 2022-03-01 15:13:43 +08:00 committed by LiangZhibo
parent 9a5268f2be
commit ee85ce748b
1 changed files with 91 additions and 2 deletions

View File

@ -32,6 +32,11 @@ double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")
def double_elements_fg_for_tensor_tuple(x, y):
return P.Tile()(x, y)
@double_elements_fg.register("Tensor", "List")
def double_elements_fg_for_tensor_list(x, y):
return x + y[0]
class HyperMapNet(nn.Cell):
def __init__(self, fg):
super(HyperMapNet, self).__init__()
@ -47,7 +52,7 @@ class HyperMapNet(nn.Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_element_hypermap():
def test_single_element_hypermap_with_tensor_input():
"""
Feature: HyperMap
Description: Test whether the HyperMap with single tensor input can run successfully.
@ -70,7 +75,7 @@ def test_single_element_hypermap():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap():
def test_double_elements_hypermap_tensor_tuple_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with tensor and tuple inputs can run successfully.
@ -88,3 +93,87 @@ def test_double_elements_hypermap():
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_list_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with tensor and list inputs can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ([1, 2], [2, 1])
common_map = HyperMapNet(double_elements_fg)
output = common_map((x, y))
expect_output_1 = np.array([2.0, 3.0, 4.0])
expect_output_2 = np.array([6.0, 7.0, 8.0])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_doubel_elements_hypermap_correct_mix_inputs():
"""
Feature: HyperMap
Description: Test whether the HyperMap with mix correct inputs (Tensor + Tuple and Tensor + List)
can run successfully.
Expectation: success.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ((1, 2), [2, 1])
common_map = HyperMapNet(double_elements_fg)
output = common_map((x, y))
expect_output_1 = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
expect_output_2 = np.array([6.0, 7.0, 8.0])
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], Tensor)
assert isinstance(output[1], Tensor)
assert np.allclose(output[0].asnumpy(), expect_output_1)
assert np.allclose(output[1].asnumpy(), expect_output_2)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inputs_length_mismatch():
"""
Feature: HyperMap
Description: When the inputs to hypermap have different length, error will be raised.
Expectation: error.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = ((1, 2), (2, 1), (5, 6))
common_map = HyperMapNet(double_elements_fg)
with pytest.raises(Exception, match="The length of tuples in HyperMap must be the same"):
common_map((x, y))
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inconsistent_inputs():
"""
Feature: HyperMap
Description: When the inputs to hypermap is inconsistent, error will be raised.
Expectation: error.
"""
x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
y = [(1, 2), (2, 1)]
common_map = HyperMapNet(double_elements_fg)
with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
common_map((x, y))