!33730 [ME] Add some exception types.

Merge pull request !33730 from Margaret_wangrui/raise_type
This commit is contained in:
i-robot 2022-05-06 00:53:41 +00:00 committed by Gitee
commit 8d8866fd0d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 38 additions and 14 deletions

View File

@ -1106,7 +1106,7 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
// Process raise ValueError // Process raise ValueError
if (py::isinstance<py::none>(function_ast_node)) { if (py::isinstance<py::none>(function_ast_node)) {
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id")); auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { if (exception_types_map.find(name_id) != exception_types_map.end()) {
return {NewValueNode(name_id)}; return {NewValueNode(name_id)};
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
@ -1120,7 +1120,7 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
if (arg_type == AST_SUB_TYPE_NAME) { if (arg_type == AST_SUB_TYPE_NAME) {
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id")); auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
MS_LOG(DEBUG) << "The name of call node is: " << name_id; MS_LOG(DEBUG) << "The name of call node is: " << name_id;
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) { if (exception_types_map.find(name_id) != exception_types_map.end()) {
return ParseException(block, args, name_id); return ParseException(block, args, name_id);
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";

View File

@ -71,17 +71,25 @@ enum ExceptionType {
KeyError, KeyError,
AttributeError, AttributeError,
NameError, NameError,
AssertionError AssertionError,
BaseException,
KeyboardInterrupt,
Exception,
StopIteration,
OverflowError,
ZeroDivisionError,
EnvironmentError,
IOError,
OSError,
ImportError,
MemoryError,
UnboundLocalError,
RuntimeError,
NotImplementedError,
IndentationError,
RuntimeWarning,
}; };
// exception types
const std::vector<std::string> exception_types = {
"NoExceptionType", "UnknownError", "ArgumentError", "NotSupportError",
"NotExistsError", "AlreadyExistsError", "UnavailableError", "DeviceProcessError",
"AbortedError", "TimeOutError", "ResourceUnavailable", "NoPermissionError",
"IndexError", "ValueError", "TypeError", "KeyError",
"AttributeError", "NameError", "AssertionError"};
static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoExceptionType", NoExceptionType}, static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoExceptionType", NoExceptionType},
{"UnknownError", UnknownError}, {"UnknownError", UnknownError},
{"ArgumentError", ArgumentError}, {"ArgumentError", ArgumentError},
@ -100,7 +108,23 @@ static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoEx
{"KeyError", KeyError}, {"KeyError", KeyError},
{"AttributeError", AttributeError}, {"AttributeError", AttributeError},
{"NameError", NameError}, {"NameError", NameError},
{"AssertionError", AssertionError}}; {"AssertionError", AssertionError},
{"BaseException", BaseException},
{"KeyboardInterrupt", KeyboardInterrupt},
{"Exception", Exception},
{"StopIteration", StopIteration},
{"OverflowError", OverflowError},
{"ZeroDivisionError", ZeroDivisionError},
{"EnvironmentError", EnvironmentError},
{"IOError", IOError},
{"OSError", OSError},
{"ImportError", ImportError},
{"MemoryError", MemoryError},
{"UnboundLocalError", UnboundLocalError},
{"RuntimeError", RuntimeError},
{"NotImplementedError", NotImplementedError},
{"IndentationError", IndentationError},
{"RuntimeWarning", RuntimeWarning}};
struct LocationInfo { struct LocationInfo {
LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {} LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {}

View File

@ -354,14 +354,14 @@ def test_raise_14():
def construct(self): def construct(self):
x = Tensor(1) x = Tensor(1)
if x == 1: if x == 1:
raise NotImplementedError("The input should not be Tensor(1).") raise UserWarning("The input should not be Tensor(1).")
return x return x
with pytest.raises(RuntimeError) as raise_info_14: with pytest.raises(RuntimeError) as raise_info_14:
net = RaiseNet() net = RaiseNet()
res = net() res = net()
print("res:", res) print("res:", res)
assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value) assert "Unsupported exception type: UserWarning." in str(raise_info_14.value)
@pytest.mark.level0 @pytest.mark.level0