!33632 [ME] Add some raise test cases.
Merge pull request !33632 from Margaret_wangrui/raise_testcase
This commit is contained in:
commit
253eafd62a
|
@ -1019,6 +1019,8 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
|
||||
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
|
||||
return {NewValueNode(name_id)};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1031,6 +1033,8 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
MS_LOG(DEBUG) << "The name of call node is: " << name_id;
|
||||
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
|
||||
return ParseException(block, args, name_id);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
}
|
||||
}
|
||||
return {};
|
||||
|
|
|
@ -337,3 +337,132 @@ def test_raise_13():
|
|||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input should not be Tensor(1)." in str(raise_info_13.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_14():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor(1)
|
||||
if x == 1:
|
||||
raise NotImplementedError("The input should not be Tensor(1).")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_14:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_15():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = 5
|
||||
y = [1, 2, 3, 4]
|
||||
if x > len(y):
|
||||
raise IndexError("The list index out of range.")
|
||||
return y[x]
|
||||
|
||||
with pytest.raises(IndexError) as raise_info_15:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The list index out of range." in str(raise_info_15.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_16():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = [1, 2, 3, 4]
|
||||
if isinstance(x, list):
|
||||
raise TypeError("The input should not be list.")
|
||||
return x
|
||||
|
||||
with pytest.raises(TypeError) as raise_info_16:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input should not be list." in str(raise_info_16.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_17():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
name = "name_a"
|
||||
if name == "name_a":
|
||||
raise NameError("The name should not be name_a.")
|
||||
return self.param_a
|
||||
|
||||
with pytest.raises(NameError) as raise_info_17:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The name should not be name_a." in str(raise_info_17.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_18():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RaiseNet, self).__init__()
|
||||
self.input = Tensor(1)
|
||||
|
||||
def construct(self):
|
||||
if self.input == 1:
|
||||
raise AssertionError("The input should not be 1.")
|
||||
return self.param_a
|
||||
|
||||
with pytest.raises(AssertionError) as raise_info_18:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input should not be 1." in str(raise_info_18.value)
|
||||
|
|
Loading…
Reference in New Issue