diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index a3292a36faf..69c62ec6d03 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1106,7 +1106,7 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co // Process raise ValueError if (py::isinstance(function_ast_node)) { auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); - if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { + if (exception_types_map.find(name_id) != exception_types_map.end()) { return {NewValueNode(name_id)}; } else { MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; @@ -1120,7 +1120,7 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co if (arg_type == AST_SUB_TYPE_NAME) { auto name_id = py::cast(python_adapter::GetPyObjAttr(function_ast_node, "id")); MS_LOG(DEBUG) << "The name of call node is: " << name_id; - if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { + if (exception_types_map.find(name_id) != exception_types_map.end()) { return ParseException(block, args, name_id); } else { MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 985eabce0cd..b350a434f47 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -71,17 +71,25 @@ enum ExceptionType { KeyError, AttributeError, NameError, - AssertionError + AssertionError, + BaseException, + KeyboardInterrupt, + Exception, + StopIteration, + OverflowError, + ZeroDivisionError, + EnvironmentError, + IOError, + OSError, + ImportError, + MemoryError, + UnboundLocalError, + RuntimeError, + NotImplementedError, + IndentationError, + RuntimeWarning, }; -// exception types -const std::vector exception_types = { - "NoExceptionType", "UnknownError", "ArgumentError", "NotSupportError", - "NotExistsError", "AlreadyExistsError", "UnavailableError", "DeviceProcessError", - "AbortedError", "TimeOutError", "ResourceUnavailable", "NoPermissionError", - "IndexError", "ValueError", "TypeError", "KeyError", - "AttributeError", "NameError", "AssertionError"}; - static inline std::map exception_types_map = {{"NoExceptionType", NoExceptionType}, {"UnknownError", UnknownError}, {"ArgumentError", ArgumentError}, @@ -100,7 +108,23 @@ static inline std::map exception_types_map = {{"NoEx {"KeyError", KeyError}, {"AttributeError", AttributeError}, {"NameError", NameError}, - {"AssertionError", AssertionError}}; + {"AssertionError", AssertionError}, + {"BaseException", BaseException}, + {"KeyboardInterrupt", KeyboardInterrupt}, + {"Exception", Exception}, + {"StopIteration", StopIteration}, + {"OverflowError", OverflowError}, + {"ZeroDivisionError", ZeroDivisionError}, + {"EnvironmentError", EnvironmentError}, + {"IOError", IOError}, + {"OSError", OSError}, + {"ImportError", ImportError}, + {"MemoryError", MemoryError}, + {"UnboundLocalError", UnboundLocalError}, + {"RuntimeError", RuntimeError}, + {"NotImplementedError", NotImplementedError}, + {"IndentationError", IndentationError}, + {"RuntimeWarning", RuntimeWarning}}; struct LocationInfo { LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {} diff --git a/tests/st/raise/test_graph_raise.py b/tests/st/raise/test_graph_raise.py index 2cd59deb1d2..65161a9cf54 100644 --- a/tests/st/raise/test_graph_raise.py +++ b/tests/st/raise/test_graph_raise.py @@ -354,14 +354,14 @@ def test_raise_14(): def construct(self): x = Tensor(1) if x == 1: - raise NotImplementedError("The input should not be Tensor(1).") + raise UserWarning("The input should not be Tensor(1).") return x with pytest.raises(RuntimeError) as raise_info_14: net = RaiseNet() res = net() print("res:", res) - assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value) + assert "Unsupported exception type: UserWarning." in str(raise_info_14.value) @pytest.mark.level0