diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 47d288612f0..e8b14d6dc86 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1019,6 +1019,8 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { return {NewValueNode(name_id)}; + } else { + MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; } } @@ -1031,6 +1033,8 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co MS_LOG(DEBUG) << "The name of call node is: " << name_id; if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { return ParseException(block, args, name_id); + } else { + MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; } } return {}; diff --git a/tests/st/raise/test_graph_raise.py b/tests/st/raise/test_graph_raise.py index fe04615a540..2cd59deb1d2 100644 --- a/tests/st/raise/test_graph_raise.py +++ b/tests/st/raise/test_graph_raise.py @@ -337,3 +337,132 @@ def test_raise_13(): res = net() print("res:", res) assert "The input should not be Tensor(1)." in str(raise_info_13.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_14(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = Tensor(1) + if x == 1: + raise NotImplementedError("The input should not be Tensor(1).") + return x + + with pytest.raises(RuntimeError) as raise_info_14: + net = RaiseNet() + res = net() + print("res:", res) + assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_15(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = 5 + y = [1, 2, 3, 4] + if x > len(y): + raise IndexError("The list index out of range.") + return y[x] + + with pytest.raises(IndexError) as raise_info_15: + net = RaiseNet() + res = net() + print("res:", res) + assert "The list index out of range." in str(raise_info_15.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_16(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + x = [1, 2, 3, 4] + if isinstance(x, list): + raise TypeError("The input should not be list.") + return x + + with pytest.raises(TypeError) as raise_info_16: + net = RaiseNet() + res = net() + print("res:", res) + assert "The input should not be list." in str(raise_info_16.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_17(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self): + name = "name_a" + if name == "name_a": + raise NameError("The name should not be name_a.") + return self.param_a + + with pytest.raises(NameError) as raise_info_17: + net = RaiseNet() + res = net() + print("res:", res) + assert "The name should not be name_a." in str(raise_info_17.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_18(): + """ + Feature: graph raise by JIT Fallback. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def __init__(self): + super(RaiseNet, self).__init__() + self.input = Tensor(1) + + def construct(self): + if self.input == 1: + raise AssertionError("The input should not be 1.") + return self.param_a + + with pytest.raises(AssertionError) as raise_info_18: + net = RaiseNet() + res = net() + print("res:", res) + assert "The input should not be 1." in str(raise_info_18.value)