[ME] raise only supports some Python standard exception types.

This commit is contained in:
Margaret_wangrui 2022-05-26 14:52:43 +08:00
parent e960f5beb7
commit c0ac64aaae
8 changed files with 185 additions and 22 deletions
.jenkins/check/config
mindspore
ccsrc
pipeline/jit
parse
pipeline.cc
static_analysis
pybind_api
core/utils
tests/st/raise

View File

@ -13,6 +13,8 @@ mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::O
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
mindspore/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py:__init__

View File

@ -1111,7 +1111,9 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
if (exception_types_map.find(name_id) != exception_types_map.end()) {
return {NewValueNode(name_id)};
} else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
<< ". Raise only support some Python standard exception types: "
<< SupportedExceptionsToString();
}
}
@ -1125,7 +1127,9 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
if (exception_types_map.find(name_id) != exception_types_map.end()) {
return ParseException(block, args, name_id);
} else {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
<< ". Raise only support some Python standard exception types: "
<< SupportedExceptionsToString();
}
}
return {};

View File

@ -962,6 +962,45 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
} catch (const py::assertion_error &ex) {
ReleaseResource(phase);
throw py::assertion_error(ex);
} catch (const py::base_exception &ex) {
ReleaseResource(phase);
throw py::base_exception(ex);
} catch (const py::keyboard_interrupt &ex) {
ReleaseResource(phase);
throw py::keyboard_interrupt(ex);
} catch (const py::stop_iteration &ex) {
ReleaseResource(phase);
throw py::stop_iteration(ex);
} catch (const py::overflow_error &ex) {
ReleaseResource(phase);
throw py::overflow_error(ex);
} catch (const py::zero_division_error &ex) {
ReleaseResource(phase);
throw py::zero_division_error(ex);
} catch (const py::environment_error &ex) {
ReleaseResource(phase);
throw py::environment_error(ex);
} catch (const py::io_error &ex) {
ReleaseResource(phase);
throw py::io_error(ex);
} catch (const py::os_error &ex) {
ReleaseResource(phase);
throw py::os_error(ex);
} catch (const py::memory_error &ex) {
ReleaseResource(phase);
throw py::memory_error(ex);
} catch (const py::unbound_local_error &ex) {
ReleaseResource(phase);
throw py::unbound_local_error(ex);
} catch (const py::not_implemented_error &ex) {
ReleaseResource(phase);
throw py::not_implemented_error(ex);
} catch (const py::indentation_error &ex) {
ReleaseResource(phase);
throw py::indentation_error(ex);
} catch (const py::runtime_warning &ex) {
ReleaseResource(phase);
throw py::runtime_warning(ex);
} catch (const std::exception &ex) {
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it

View File

@ -2053,7 +2053,9 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
std::string exception_type = GetScalarStringValue(args_spec_list[0]);
auto iter = exception_types_map.find(exception_type);
if (iter == exception_types_map.end()) {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type << ".";
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
<< ". Raise only support some Python standard exception types: "
<< SupportedExceptionsToString();
}
ExceptionType type = iter->second;
if (args_spec_list.size() == 1) {

View File

@ -51,6 +51,45 @@ class PyExceptionInitializer {
if (exception_type == AssertionError) {
throw py::assertion_error(str);
}
if (exception_type == BaseException) {
throw py::base_exception(str);
}
if (exception_type == KeyboardInterrupt) {
throw py::keyboard_interrupt(str);
}
if (exception_type == StopIteration) {
throw py::stop_iteration(str);
}
if (exception_type == OverflowError) {
throw py::overflow_error(str);
}
if (exception_type == ZeroDivisionError) {
throw py::zero_division_error(str);
}
if (exception_type == EnvironmentError) {
throw py::environment_error(str);
}
if (exception_type == IOError) {
throw py::io_error(str);
}
if (exception_type == OSError) {
throw py::os_error(str);
}
if (exception_type == MemoryError) {
throw py::memory_error(str);
}
if (exception_type == UnboundLocalError) {
throw py::unbound_local_error(str);
}
if (exception_type == NotImplementedError) {
throw py::not_implemented_error(str);
}
if (exception_type == IndentationError) {
throw py::indentation_error(str);
}
if (exception_type == RuntimeWarning) {
throw py::runtime_warning(str);
}
py::pybind11_fail(str);
}
};

View File

@ -20,6 +20,18 @@ namespace pybind11 {
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
PYBIND11_RUNTIME_EXCEPTION(name_error, PyExc_NameError)
PYBIND11_RUNTIME_EXCEPTION(assertion_error, PyExc_AssertionError)
PYBIND11_RUNTIME_EXCEPTION(base_exception, PyExc_BaseException)
PYBIND11_RUNTIME_EXCEPTION(keyboard_interrupt, PyExc_KeyboardInterrupt)
PYBIND11_RUNTIME_EXCEPTION(overflow_error, PyExc_OverflowError)
PYBIND11_RUNTIME_EXCEPTION(zero_division_error, PyExc_ZeroDivisionError)
PYBIND11_RUNTIME_EXCEPTION(environment_error, PyExc_EnvironmentError)
PYBIND11_RUNTIME_EXCEPTION(io_error, PyExc_IOError)
PYBIND11_RUNTIME_EXCEPTION(os_error, PyExc_OSError)
PYBIND11_RUNTIME_EXCEPTION(memory_error, PyExc_MemoryError)
PYBIND11_RUNTIME_EXCEPTION(unbound_local_error, PyExc_UnboundLocalError)
PYBIND11_RUNTIME_EXCEPTION(not_implemented_error, PyExc_NotImplementedError)
PYBIND11_RUNTIME_EXCEPTION(indentation_error, PyExc_IndentationError)
PYBIND11_RUNTIME_EXCEPTION(runtime_warning, PyExc_RuntimeWarning)
} // namespace pybind11
#endif // PYBIND_API_PYBIND_PATCH_H_

View File

@ -58,13 +58,8 @@ enum ExceptionType {
ArgumentError,
NotSupportError,
NotExistsError,
AlreadyExistsError,
UnavailableError,
DeviceProcessError,
AbortedError,
TimeOutError,
ResourceUnavailable,
NoPermissionError,
IndexError,
ValueError,
TypeError,
@ -90,19 +85,7 @@ enum ExceptionType {
RuntimeWarning,
};
static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoExceptionType", NoExceptionType},
{"UnknownError", UnknownError},
{"ArgumentError", ArgumentError},
{"NotSupportError", NotSupportError},
{"NotExistsError", NotExistsError},
{"AlreadyExistsError", AlreadyExistsError},
{"UnavailableError", UnavailableError},
{"DeviceProcessError", DeviceProcessError},
{"AbortedError", AbortedError},
{"TimeOutError", TimeOutError},
{"ResourceUnavailable", ResourceUnavailable},
{"NoPermissionError", NoPermissionError},
{"IndexError", IndexError},
static inline std::map<std::string, ExceptionType> exception_types_map = {{"IndexError", IndexError},
{"ValueError", ValueError},
{"TypeError", TypeError},
{"KeyError", KeyError},
@ -118,7 +101,6 @@ static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoEx
{"EnvironmentError", EnvironmentError},
{"IOError", IOError},
{"OSError", OSError},
{"ImportError", ImportError},
{"MemoryError", MemoryError},
{"UnboundLocalError", UnboundLocalError},
{"RuntimeError", RuntimeError},
@ -126,6 +108,20 @@ static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoEx
{"IndentationError", IndentationError},
{"RuntimeWarning", RuntimeWarning}};
static inline std::string SupportedExceptionsToString() {
std::ostringstream oss;
size_t index = 0;
for (auto iter = exception_types_map.begin(); iter != exception_types_map.end(); ++iter) {
oss << iter->first;
if (index != exception_types_map.size() - 1) {
oss << ", ";
}
++index;
}
oss << ". ";
return oss.str();
}
struct LocationInfo {
LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {}
~LocationInfo() = default;

View File

@ -466,3 +466,72 @@ def test_raise_18():
res = net()
print("res:", res)
assert "The input should not be 1." in str(raise_info_18.value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_19():
"""
Feature: graph raise.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x):
if x == 1:
raise NotSupportError(f"The input should not be 1.")
return x
with pytest.raises(RuntimeError, match="Unsupported exception type: NotSupportError."):
net = RaiseNet()
res = net(1)
print("res:", res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_20():
"""
Feature: graph raise.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x):
if x == 1:
raise RuntimeWarning(f"The input should not be 1.")
return x
with pytest.raises(RuntimeWarning, match="The input should not be 1."):
net = RaiseNet()
res = net(1)
print("res:", res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_raise_21():
"""
Feature: graph raise.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x):
if x == 1:
raise ImportError(f"The input should not be 1.")
return x
with pytest.raises(RuntimeError, match="Unsupported exception type: ImportError."):
net = RaiseNet()
res = net(1)
print("res:", res)