!33632 [ME] Add some raise test cases.

Merge pull request !33632 from Margaret_wangrui/raise_testcase
This commit is contained in:
i-robot 2022-04-27 11:00:15 +00:00 committed by Gitee
commit 253eafd62a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 133 additions and 0 deletions

View File

@ -1019,6 +1019,8 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
return {NewValueNode(name_id)};
} else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
}
}
@ -1031,6 +1033,8 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
MS_LOG(DEBUG) << "The name of call node is: " << name_id;
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
return ParseException(block, args, name_id);
} else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
}
}
return {};

View File

@ -337,3 +337,132 @@ def test_raise_13():
res = net()
print("res:", res)
assert "The input should not be Tensor(1)." in str(raise_info_13.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_14():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = Tensor(1)
if x == 1:
raise NotImplementedError("The input should not be Tensor(1).")
return x
with pytest.raises(RuntimeError) as raise_info_14:
net = RaiseNet()
res = net()
print("res:", res)
assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_15():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = 5
y = [1, 2, 3, 4]
if x > len(y):
raise IndexError("The list index out of range.")
return y[x]
with pytest.raises(IndexError) as raise_info_15:
net = RaiseNet()
res = net()
print("res:", res)
assert "The list index out of range." in str(raise_info_15.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_16():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
x = [1, 2, 3, 4]
if isinstance(x, list):
raise TypeError("The input should not be list.")
return x
with pytest.raises(TypeError) as raise_info_16:
net = RaiseNet()
res = net()
print("res:", res)
assert "The input should not be list." in str(raise_info_16.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_17():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self):
name = "name_a"
if name == "name_a":
raise NameError("The name should not be name_a.")
return self.param_a
with pytest.raises(NameError) as raise_info_17:
net = RaiseNet()
res = net()
print("res:", res)
assert "The name should not be name_a." in str(raise_info_17.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_18():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def __init__(self):
super(RaiseNet, self).__init__()
self.input = Tensor(1)
def construct(self):
if self.input == 1:
raise AssertionError("The input should not be 1.")
return self.param_a
with pytest.raises(AssertionError) as raise_info_18:
net = RaiseNet()
res = net()
print("res:", res)
assert "The input should not be 1." in str(raise_info_18.value)