forked from mindspore-Ecosystem/mindspore
!33730 [ME] Add some exception types.
Merge pull request !33730 from Margaret_wangrui/raise_type
This commit is contained in:
commit
8d8866fd0d
|
@ -1106,7 +1106,7 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
// Process raise ValueError
|
||||
if (py::isinstance<py::none>(function_ast_node)) {
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
|
||||
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
|
||||
if (exception_types_map.find(name_id) != exception_types_map.end()) {
|
||||
return {NewValueNode(name_id)};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
|
@ -1120,7 +1120,7 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
if (arg_type == AST_SUB_TYPE_NAME) {
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
|
||||
MS_LOG(DEBUG) << "The name of call node is: " << name_id;
|
||||
if (std::find(exception_types.begin(), exception_types.end(), name_id) != exception_types.end()) {
|
||||
if (exception_types_map.find(name_id) != exception_types_map.end()) {
|
||||
return ParseException(block, args, name_id);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
|
|
|
@ -71,17 +71,25 @@ enum ExceptionType {
|
|||
KeyError,
|
||||
AttributeError,
|
||||
NameError,
|
||||
AssertionError
|
||||
AssertionError,
|
||||
BaseException,
|
||||
KeyboardInterrupt,
|
||||
Exception,
|
||||
StopIteration,
|
||||
OverflowError,
|
||||
ZeroDivisionError,
|
||||
EnvironmentError,
|
||||
IOError,
|
||||
OSError,
|
||||
ImportError,
|
||||
MemoryError,
|
||||
UnboundLocalError,
|
||||
RuntimeError,
|
||||
NotImplementedError,
|
||||
IndentationError,
|
||||
RuntimeWarning,
|
||||
};
|
||||
|
||||
// exception types
|
||||
const std::vector<std::string> exception_types = {
|
||||
"NoExceptionType", "UnknownError", "ArgumentError", "NotSupportError",
|
||||
"NotExistsError", "AlreadyExistsError", "UnavailableError", "DeviceProcessError",
|
||||
"AbortedError", "TimeOutError", "ResourceUnavailable", "NoPermissionError",
|
||||
"IndexError", "ValueError", "TypeError", "KeyError",
|
||||
"AttributeError", "NameError", "AssertionError"};
|
||||
|
||||
static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoExceptionType", NoExceptionType},
|
||||
{"UnknownError", UnknownError},
|
||||
{"ArgumentError", ArgumentError},
|
||||
|
@ -100,7 +108,23 @@ static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoEx
|
|||
{"KeyError", KeyError},
|
||||
{"AttributeError", AttributeError},
|
||||
{"NameError", NameError},
|
||||
{"AssertionError", AssertionError}};
|
||||
{"AssertionError", AssertionError},
|
||||
{"BaseException", BaseException},
|
||||
{"KeyboardInterrupt", KeyboardInterrupt},
|
||||
{"Exception", Exception},
|
||||
{"StopIteration", StopIteration},
|
||||
{"OverflowError", OverflowError},
|
||||
{"ZeroDivisionError", ZeroDivisionError},
|
||||
{"EnvironmentError", EnvironmentError},
|
||||
{"IOError", IOError},
|
||||
{"OSError", OSError},
|
||||
{"ImportError", ImportError},
|
||||
{"MemoryError", MemoryError},
|
||||
{"UnboundLocalError", UnboundLocalError},
|
||||
{"RuntimeError", RuntimeError},
|
||||
{"NotImplementedError", NotImplementedError},
|
||||
{"IndentationError", IndentationError},
|
||||
{"RuntimeWarning", RuntimeWarning}};
|
||||
|
||||
struct LocationInfo {
|
||||
LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {}
|
||||
|
|
|
@ -354,14 +354,14 @@ def test_raise_14():
|
|||
def construct(self):
|
||||
x = Tensor(1)
|
||||
if x == 1:
|
||||
raise NotImplementedError("The input should not be Tensor(1).")
|
||||
raise UserWarning("The input should not be Tensor(1).")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_14:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Unsupported exception type: NotImplementedError." in str(raise_info_14.value)
|
||||
assert "Unsupported exception type: UserWarning." in str(raise_info_14.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue