From c0ac64aaaec62a804cb53e235299cf128616192a Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Thu, 26 May 2022 14:52:43 +0800 Subject: [PATCH] [ME] raise only supports some Python standard exception types. --- .jenkins/check/config/whitelizard.txt | 2 + mindspore/ccsrc/pipeline/jit/parse/parse.cc | 8 ++- mindspore/ccsrc/pipeline/jit/pipeline.cc | 39 +++++++++++ .../pipeline/jit/static_analysis/prim.cc | 4 +- .../ccsrc/pybind_api/ir/log_adapter_py.cc | 39 +++++++++++ mindspore/ccsrc/pybind_api/pybind_patch.h | 12 ++++ mindspore/core/utils/log_adapter.h | 34 ++++----- tests/st/raise/test_graph_raise.py | 69 +++++++++++++++++++ 8 files changed, 185 insertions(+), 22 deletions(-) diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 373f6d85cab..d1db90e6dcc 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -13,6 +13,8 @@ mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::O mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE +mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile +mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode mindspore/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py:__init__ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 39ce0449268..913cf280ce9 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1111,7 +1111,9 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co if (exception_types_map.find(name_id) != exception_types_map.end()) { return {NewValueNode(name_id)}; } else { - MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; + MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id + << ". Raise only support some Python standard exception types: " + << SupportedExceptionsToString(); } } @@ -1125,7 +1127,9 @@ std::vector Parser::ParseRaiseCall(const FunctionBlockPtr &block, co if (exception_types_map.find(name_id) != exception_types_map.end()) { return ParseException(block, args, name_id); } else { - MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << "."; + MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id + << ". Raise only support some Python standard exception types: " + << SupportedExceptionsToString(); } } return {}; diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 8799a509d53..04144278843 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -962,6 +962,45 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg } catch (const py::assertion_error &ex) { ReleaseResource(phase); throw py::assertion_error(ex); + } catch (const py::base_exception &ex) { + ReleaseResource(phase); + throw py::base_exception(ex); + } catch (const py::keyboard_interrupt &ex) { + ReleaseResource(phase); + throw py::keyboard_interrupt(ex); + } catch (const py::stop_iteration &ex) { + ReleaseResource(phase); + throw py::stop_iteration(ex); + } catch (const py::overflow_error &ex) { + ReleaseResource(phase); + throw py::overflow_error(ex); + } catch (const py::zero_division_error &ex) { + ReleaseResource(phase); + throw py::zero_division_error(ex); + } catch (const py::environment_error &ex) { + ReleaseResource(phase); + throw py::environment_error(ex); + } catch (const py::io_error &ex) { + ReleaseResource(phase); + throw py::io_error(ex); + } catch (const py::os_error &ex) { + ReleaseResource(phase); + throw py::os_error(ex); + } catch (const py::memory_error &ex) { + ReleaseResource(phase); + throw py::memory_error(ex); + } catch (const py::unbound_local_error &ex) { + ReleaseResource(phase); + throw py::unbound_local_error(ex); + } catch (const py::not_implemented_error &ex) { + ReleaseResource(phase); + throw py::not_implemented_error(ex); + } catch (const py::indentation_error &ex) { + ReleaseResource(phase); + throw py::indentation_error(ex); + } catch (const py::runtime_warning &ex) { + ReleaseResource(phase); + throw py::runtime_warning(ex); } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index bfefee182d8..181fa3770a9 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -2053,7 +2053,9 @@ class RaiseEvaluator : public TransitionPrimEvaluator { std::string exception_type = GetScalarStringValue(args_spec_list[0]); auto iter = exception_types_map.find(exception_type); if (iter == exception_types_map.end()) { - MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type << "."; + MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type + << ". Raise only support some Python standard exception types: " + << SupportedExceptionsToString(); } ExceptionType type = iter->second; if (args_spec_list.size() == 1) { diff --git a/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc b/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc index c93622c48ae..49da189af1e 100644 --- a/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc @@ -51,6 +51,45 @@ class PyExceptionInitializer { if (exception_type == AssertionError) { throw py::assertion_error(str); } + if (exception_type == BaseException) { + throw py::base_exception(str); + } + if (exception_type == KeyboardInterrupt) { + throw py::keyboard_interrupt(str); + } + if (exception_type == StopIteration) { + throw py::stop_iteration(str); + } + if (exception_type == OverflowError) { + throw py::overflow_error(str); + } + if (exception_type == ZeroDivisionError) { + throw py::zero_division_error(str); + } + if (exception_type == EnvironmentError) { + throw py::environment_error(str); + } + if (exception_type == IOError) { + throw py::io_error(str); + } + if (exception_type == OSError) { + throw py::os_error(str); + } + if (exception_type == MemoryError) { + throw py::memory_error(str); + } + if (exception_type == UnboundLocalError) { + throw py::unbound_local_error(str); + } + if (exception_type == NotImplementedError) { + throw py::not_implemented_error(str); + } + if (exception_type == IndentationError) { + throw py::indentation_error(str); + } + if (exception_type == RuntimeWarning) { + throw py::runtime_warning(str); + } py::pybind11_fail(str); } }; diff --git a/mindspore/ccsrc/pybind_api/pybind_patch.h b/mindspore/ccsrc/pybind_api/pybind_patch.h index 32fc10eb993..11424fc5c9c 100644 --- a/mindspore/ccsrc/pybind_api/pybind_patch.h +++ b/mindspore/ccsrc/pybind_api/pybind_patch.h @@ -20,6 +20,18 @@ namespace pybind11 { PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) PYBIND11_RUNTIME_EXCEPTION(name_error, PyExc_NameError) PYBIND11_RUNTIME_EXCEPTION(assertion_error, PyExc_AssertionError) +PYBIND11_RUNTIME_EXCEPTION(base_exception, PyExc_BaseException) +PYBIND11_RUNTIME_EXCEPTION(keyboard_interrupt, PyExc_KeyboardInterrupt) +PYBIND11_RUNTIME_EXCEPTION(overflow_error, PyExc_OverflowError) +PYBIND11_RUNTIME_EXCEPTION(zero_division_error, PyExc_ZeroDivisionError) +PYBIND11_RUNTIME_EXCEPTION(environment_error, PyExc_EnvironmentError) +PYBIND11_RUNTIME_EXCEPTION(io_error, PyExc_IOError) +PYBIND11_RUNTIME_EXCEPTION(os_error, PyExc_OSError) +PYBIND11_RUNTIME_EXCEPTION(memory_error, PyExc_MemoryError) +PYBIND11_RUNTIME_EXCEPTION(unbound_local_error, PyExc_UnboundLocalError) +PYBIND11_RUNTIME_EXCEPTION(not_implemented_error, PyExc_NotImplementedError) +PYBIND11_RUNTIME_EXCEPTION(indentation_error, PyExc_IndentationError) +PYBIND11_RUNTIME_EXCEPTION(runtime_warning, PyExc_RuntimeWarning) } // namespace pybind11 #endif // PYBIND_API_PYBIND_PATCH_H_ diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index b350a434f47..1c5b479ac4b 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -58,13 +58,8 @@ enum ExceptionType { ArgumentError, NotSupportError, NotExistsError, - AlreadyExistsError, - UnavailableError, DeviceProcessError, AbortedError, - TimeOutError, - ResourceUnavailable, - NoPermissionError, IndexError, ValueError, TypeError, @@ -90,19 +85,7 @@ enum ExceptionType { RuntimeWarning, }; -static inline std::map exception_types_map = {{"NoExceptionType", NoExceptionType}, - {"UnknownError", UnknownError}, - {"ArgumentError", ArgumentError}, - {"NotSupportError", NotSupportError}, - {"NotExistsError", NotExistsError}, - {"AlreadyExistsError", AlreadyExistsError}, - {"UnavailableError", UnavailableError}, - {"DeviceProcessError", DeviceProcessError}, - {"AbortedError", AbortedError}, - {"TimeOutError", TimeOutError}, - {"ResourceUnavailable", ResourceUnavailable}, - {"NoPermissionError", NoPermissionError}, - {"IndexError", IndexError}, +static inline std::map exception_types_map = {{"IndexError", IndexError}, {"ValueError", ValueError}, {"TypeError", TypeError}, {"KeyError", KeyError}, @@ -118,7 +101,6 @@ static inline std::map exception_types_map = {{"NoEx {"EnvironmentError", EnvironmentError}, {"IOError", IOError}, {"OSError", OSError}, - {"ImportError", ImportError}, {"MemoryError", MemoryError}, {"UnboundLocalError", UnboundLocalError}, {"RuntimeError", RuntimeError}, @@ -126,6 +108,20 @@ static inline std::map exception_types_map = {{"NoEx {"IndentationError", IndentationError}, {"RuntimeWarning", RuntimeWarning}}; +static inline std::string SupportedExceptionsToString() { + std::ostringstream oss; + size_t index = 0; + for (auto iter = exception_types_map.begin(); iter != exception_types_map.end(); ++iter) { + oss << iter->first; + if (index != exception_types_map.size() - 1) { + oss << ", "; + } + ++index; + } + oss << ". "; + return oss.str(); +} + struct LocationInfo { LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {} ~LocationInfo() = default; diff --git a/tests/st/raise/test_graph_raise.py b/tests/st/raise/test_graph_raise.py index 65161a9cf54..32cdeff39bf 100644 --- a/tests/st/raise/test_graph_raise.py +++ b/tests/st/raise/test_graph_raise.py @@ -466,3 +466,72 @@ def test_raise_18(): res = net() print("res:", res) assert "The input should not be 1." in str(raise_info_18.value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_19(): + """ + Feature: graph raise. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self, x): + if x == 1: + raise NotSupportError(f"The input should not be 1.") + return x + + with pytest.raises(RuntimeError, match="Unsupported exception type: NotSupportError."): + net = RaiseNet() + res = net(1) + print("res:", res) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_20(): + """ + Feature: graph raise. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self, x): + if x == 1: + raise RuntimeWarning(f"The input should not be 1.") + return x + + with pytest.raises(RuntimeWarning, match="The input should not be 1."): + net = RaiseNet() + res = net(1) + print("res:", res) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_raise_21(): + """ + Feature: graph raise. + Description: Test raise. + Expectation: No exception. + """ + class RaiseNet(nn.Cell): + def construct(self, x): + if x == 1: + raise ImportError(f"The input should not be 1.") + return x + + with pytest.raises(RuntimeError, match="Unsupported exception type: ImportError."): + net = RaiseNet() + res = net(1) + print("res:", res)