forked from mindspore-Ecosystem/mindspore
[ME] raise only supports some Python standard exception types.
This commit is contained in:
parent
ed67d9d4e8
commit
bc5906e630
|
@ -13,6 +13,8 @@ mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::O
|
|||
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
|
||||
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
|
||||
mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
|
||||
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
|
||||
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
|
||||
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
|
||||
mindspore/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py:__init__
|
||||
|
|
|
@ -1111,7 +1111,9 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
if (exception_types_map.find(name_id) != exception_types_map.end()) {
|
||||
return {NewValueNode(name_id)};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
|
||||
<< ". Raise only support some Python standard exception types: "
|
||||
<< SupportedExceptionsToString();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1125,7 +1127,9 @@ std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, co
|
|||
if (exception_types_map.find(name_id) != exception_types_map.end()) {
|
||||
return ParseException(block, args, name_id);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id << ".";
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
|
||||
<< ". Raise only support some Python standard exception types: "
|
||||
<< SupportedExceptionsToString();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
|
|
|
@ -962,6 +962,48 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
|
|||
} catch (const py::assertion_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::assertion_error(ex);
|
||||
} catch (const py::base_exception &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::base_exception(ex);
|
||||
} catch (const py::keyboard_interrupt &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::keyboard_interrupt(ex);
|
||||
} catch (const py::stop_iteration &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::stop_iteration(ex);
|
||||
} catch (const py::overflow_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::overflow_error(ex);
|
||||
} catch (const py::zero_division_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::zero_division_error(ex);
|
||||
} catch (const py::environment_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::environment_error(ex);
|
||||
} catch (const py::io_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::io_error(ex);
|
||||
} catch (const py::os_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::os_error(ex);
|
||||
} catch (const py::import_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::import_error(ex);
|
||||
} catch (const py::memory_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::memory_error(ex);
|
||||
} catch (const py::unbound_local_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::unbound_local_error(ex);
|
||||
} catch (const py::not_implemented_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::not_implemented_error(ex);
|
||||
} catch (const py::indentation_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::indentation_error(ex);
|
||||
} catch (const py::runtime_warning &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::runtime_warning(ex);
|
||||
} catch (const std::exception &ex) {
|
||||
ReleaseResource(phase);
|
||||
// re-throw this exception to Python interpreter to handle it
|
||||
|
|
|
@ -2053,7 +2053,9 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
std::string exception_type = GetScalarStringValue(args_spec_list[0]);
|
||||
auto iter = exception_types_map.find(exception_type);
|
||||
if (iter == exception_types_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type << ".";
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
|
||||
<< ". Raise only support some Python standard exception types: "
|
||||
<< SupportedExceptionsToString();
|
||||
}
|
||||
ExceptionType type = iter->second;
|
||||
if (args_spec_list.size() == 1) {
|
||||
|
|
|
@ -51,6 +51,48 @@ class PyExceptionInitializer {
|
|||
if (exception_type == AssertionError) {
|
||||
throw py::assertion_error(str);
|
||||
}
|
||||
if (exception_type == BaseException) {
|
||||
throw py::base_exception(str);
|
||||
}
|
||||
if (exception_type == KeyboardInterrupt) {
|
||||
throw py::keyboard_interrupt(str);
|
||||
}
|
||||
if (exception_type == StopIteration) {
|
||||
throw py::stop_iteration(str);
|
||||
}
|
||||
if (exception_type == OverflowError) {
|
||||
throw py::overflow_error(str);
|
||||
}
|
||||
if (exception_type == ZeroDivisionError) {
|
||||
throw py::zero_division_error(str);
|
||||
}
|
||||
if (exception_type == EnvironmentError) {
|
||||
throw py::environment_error(str);
|
||||
}
|
||||
if (exception_type == IOError) {
|
||||
throw py::io_error(str);
|
||||
}
|
||||
if (exception_type == OSError) {
|
||||
throw py::os_error(str);
|
||||
}
|
||||
if (exception_type == ImportError) {
|
||||
throw py::import_error(str);
|
||||
}
|
||||
if (exception_type == MemoryError) {
|
||||
throw py::memory_error(str);
|
||||
}
|
||||
if (exception_type == UnboundLocalError) {
|
||||
throw py::unbound_local_error(str);
|
||||
}
|
||||
if (exception_type == NotImplementedError) {
|
||||
throw py::not_implemented_error(str);
|
||||
}
|
||||
if (exception_type == IndentationError) {
|
||||
throw py::indentation_error(str);
|
||||
}
|
||||
if (exception_type == RuntimeWarning) {
|
||||
throw py::runtime_warning(str);
|
||||
}
|
||||
py::pybind11_fail(str);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -20,6 +20,19 @@ namespace pybind11 {
|
|||
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(name_error, PyExc_NameError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(assertion_error, PyExc_AssertionError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(base_exception, PyExc_BaseException)
|
||||
PYBIND11_RUNTIME_EXCEPTION(keyboard_interrupt, PyExc_KeyboardInterrupt)
|
||||
PYBIND11_RUNTIME_EXCEPTION(overflow_error, PyExc_OverflowError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(zero_division_error, PyExc_ZeroDivisionError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(environment_error, PyExc_EnvironmentError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(io_error, PyExc_IOError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(os_error, PyExc_OSError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(memory_error, PyExc_MemoryError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(unbound_local_error, PyExc_UnboundLocalError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(not_implemented_error, PyExc_NotImplementedError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(indentation_error, PyExc_IndentationError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(runtime_warning, PyExc_RuntimeWarning)
|
||||
} // namespace pybind11
|
||||
|
||||
#endif // PYBIND_API_PYBIND_PATCH_H_
|
||||
|
|
|
@ -58,13 +58,8 @@ enum ExceptionType {
|
|||
ArgumentError,
|
||||
NotSupportError,
|
||||
NotExistsError,
|
||||
AlreadyExistsError,
|
||||
UnavailableError,
|
||||
DeviceProcessError,
|
||||
AbortedError,
|
||||
TimeOutError,
|
||||
ResourceUnavailable,
|
||||
NoPermissionError,
|
||||
IndexError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
|
@ -90,19 +85,7 @@ enum ExceptionType {
|
|||
RuntimeWarning,
|
||||
};
|
||||
|
||||
static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoExceptionType", NoExceptionType},
|
||||
{"UnknownError", UnknownError},
|
||||
{"ArgumentError", ArgumentError},
|
||||
{"NotSupportError", NotSupportError},
|
||||
{"NotExistsError", NotExistsError},
|
||||
{"AlreadyExistsError", AlreadyExistsError},
|
||||
{"UnavailableError", UnavailableError},
|
||||
{"DeviceProcessError", DeviceProcessError},
|
||||
{"AbortedError", AbortedError},
|
||||
{"TimeOutError", TimeOutError},
|
||||
{"ResourceUnavailable", ResourceUnavailable},
|
||||
{"NoPermissionError", NoPermissionError},
|
||||
{"IndexError", IndexError},
|
||||
static inline std::map<std::string, ExceptionType> exception_types_map = {{"IndexError", IndexError},
|
||||
{"ValueError", ValueError},
|
||||
{"TypeError", TypeError},
|
||||
{"KeyError", KeyError},
|
||||
|
@ -126,6 +109,20 @@ static inline std::map<std::string, ExceptionType> exception_types_map = {{"NoEx
|
|||
{"IndentationError", IndentationError},
|
||||
{"RuntimeWarning", RuntimeWarning}};
|
||||
|
||||
static inline std::string SupportedExceptionsToString() {
|
||||
std::ostringstream oss;
|
||||
size_t index = 0;
|
||||
for (auto iter = exception_types_map.begin(); iter != exception_types_map.end(); ++iter) {
|
||||
oss << iter->first;
|
||||
if (index != exception_types_map.size() - 1) {
|
||||
oss << ", ";
|
||||
}
|
||||
++index;
|
||||
}
|
||||
oss << ". ";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
struct LocationInfo {
|
||||
LocationInfo(const char *file, int line, const char *func) : file_(file), line_(line), func_(func) {}
|
||||
~LocationInfo() = default;
|
||||
|
|
|
@ -466,3 +466,49 @@ def test_raise_18():
|
|||
res = net()
|
||||
print("res:", res)
|
||||
assert "The input should not be 1." in str(raise_info_18.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_19():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
if x == 1:
|
||||
raise NotSupportError(f"The input should not be 1.")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unsupported exception type: NotSupportError."):
|
||||
net = RaiseNet()
|
||||
res = net(1)
|
||||
print("res:", res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_20():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
if x == 1:
|
||||
raise RuntimeWarning(f"The input should not be 1.")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeWarning, match="The input should not be 1."):
|
||||
net = RaiseNet()
|
||||
res = net(1)
|
||||
print("res:", res)
|
||||
|
|
Loading…
Reference in New Issue