diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 3588754dae1..46682532d4e 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -18,7 +18,6 @@ #include #include -#include "pybind11/pybind11.h" #include "debug/trace.h" // namespace to support utils module definition @@ -219,16 +218,10 @@ void LogWriter::operator^(const LogStream &stream) const { trace::TraceGraphEval(); trace::GetEvalStackInfo(oss); - if (exception_type_ == IndexError) { - throw pybind11::index_error(oss.str()); + if (exception_handler_ != nullptr) { + exception_handler_(exception_type_, oss.str()); } - if (exception_type_ == ValueError) { - throw pybind11::value_error(oss.str()); - } - if (exception_type_ == TypeError) { - throw pybind11::type_error(oss.str()); - } - pybind11::pybind11_fail(oss.str()); + throw std::runtime_error(oss.str()); } static std::string GetEnv(const std::string &envvar) { diff --git a/mindspore/ccsrc/utils/log_adapter.h b/mindspore/ccsrc/utils/log_adapter.h index dfd463ee1d7..71dbf815e39 100644 --- a/mindspore/ccsrc/utils/log_adapter.h +++ b/mindspore/ccsrc/utils/log_adapter.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "./overload.h" #include "./securec.h" #ifdef USE_GLOG @@ -133,6 +134,8 @@ extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); class LogWriter { public: + using ExceptionHandler = std::function; + LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, ExceptionType excp_type = NoExceptionType) : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} @@ -141,6 +144,8 @@ class LogWriter { void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); + static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } + private: void OutputLog(const std::ostringstream &msg) const; @@ -148,6 +153,8 @@ class LogWriter { MsLogLevel log_level_; SubModuleId submodule_; ExceptionType exception_type_; + + inline static ExceptionHandler exception_handler_ = nullptr; }; #define MSLOG_IF(level, condition, excp_type) \ diff --git a/mindspore/ccsrc/utils/log_adapter_py.cc b/mindspore/ccsrc/utils/log_adapter_py.cc new file mode 100644 index 00000000000..c4793b960bf --- /dev/null +++ b/mindspore/ccsrc/utils/log_adapter_py.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/log_adapter.h" + +#include +#include "pybind11/pybind11.h" + +namespace py = pybind11; +namespace mindspore { +class PyExceptionInitializer { + public: + PyExceptionInitializer() { mindspore::LogWriter::set_exception_handler(HandleExceptionPy); } + + ~PyExceptionInitializer() = default; + + private: + static void HandleExceptionPy(ExceptionType exception_type, const std::string &str) { + if (exception_type == IndexError) { + throw py::index_error(str); + } + if (exception_type == ValueError) { + throw py::value_error(str); + } + if (exception_type == TypeError) { + throw py::type_error(str); + } + py::pybind11_fail(str); + } +}; + +static PyExceptionInitializer py_exception_initializer; +} // namespace mindspore diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 8ca318300a8..ce852175a68 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -127,11 +127,17 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { try { trace::ClearTraceStack(); engine_->Run(tupleSliceGraphPtr, args_spec_list); - FAIL() << "Excepted exception :Args type is wrong"; + FAIL() << "Excepted exception: Args type is wrong"; } catch (pybind11::type_error const &err) { ASSERT_TRUE(true); + } catch (std::runtime_error const &err) { + if (std::strstr(err.what(), "TypeError") != nullptr) { + ASSERT_TRUE(true); + } else { + FAIL() << "Excepted exception: Args type is wrong, message: " << err.what(); + } } catch (...) { - FAIL() << "Excepted exception :Args type is wrong"; + FAIL() << "Excepted exception: Args type is wrong"; } }