!46827 Alpha fix bug of value error cannot raise

Merge pull request !46827 from chenfei_mindspore/alpha-runtime-err-fix
This commit is contained in:
i-robot 2022-12-16 08:19:32 +00:00 committed by Gitee
commit 20fda5bb5c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 65 additions and 21 deletions

View File

@ -19,7 +19,7 @@ mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::ResolveObjectToNode
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.h:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc:aicpu::ARMDropOutGenMaskKernel

View File

@ -505,7 +505,7 @@ void GetTraceStackInfo(std::ostringstream &oss, bool add_title) {
// Register trace provider to LogWriter.
struct TraceProviderRegister {
TraceProviderRegister() noexcept { LogWriter::set_trace_provider(GetTraceStackInfo); }
TraceProviderRegister() noexcept { LogWriter::SetTraceProvider(GetTraceStackInfo); }
~TraceProviderRegister() = default;
} trace_provider_register;

View File

@ -74,6 +74,7 @@
#include "kernel/akg/akg_kernel_build_manager.h"
#include "kernel/graph_kernel_info.h"
#include "include/backend/data_queue/data_queue_mgr.h"
#include "pybind_api/ir/log_adapter_py.h"
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif

View File

@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PYBINDAPI_IR_LOGADAPTERPY_H_
#define MINDSPORE_CCSRC_PYBINDAPI_IR_LOGADAPTERPY_H_
#include "utils/log_adapter.h"
@ -24,7 +26,10 @@ namespace py = pybind11;
namespace mindspore {
class PyExceptionInitializer {
public:
PyExceptionInitializer() { mindspore::LogWriter::set_exception_handler(HandleExceptionPy); }
PyExceptionInitializer() {
MS_LOG(INFO) << "Set exception handler";
mindspore::LogWriter::SetExceptionHandler(HandleExceptionPy);
}
~PyExceptionInitializer() = default;
@ -96,3 +101,4 @@ class PyExceptionInitializer {
static PyExceptionInitializer py_exception_initializer;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PYBINDAPI_IR_LOGADAPTERPY_H_

View File

@ -167,20 +167,40 @@ static int GetSlogLevel(MsLogLevel level) {
}
#endif
void LogWriter::set_exception_handler(const ExceptionHandler &exception_handler) {
exception_handler_ = exception_handler;
LogWriter::ExceptionHandler &LogWriter::exception_handler() {
static LogWriter::ExceptionHandler g_exception_handler = nullptr;
return g_exception_handler;
}
void LogWriter::set_trace_provider(const TraceProvider &trace_provider) {
if (trace_provider_ == nullptr) {
trace_provider_ = trace_provider;
} else {
LogWriter::TraceProvider &LogWriter::trace_provider() {
static LogWriter::TraceProvider g_trace_provider = nullptr;
return g_trace_provider;
}
const LogWriter::ExceptionHandler &LogWriter::GetExceptionHandler() {
const auto &exception_handler_tmp = exception_handler();
return exception_handler_tmp;
}
void LogWriter::SetExceptionHandler(const LogWriter::ExceptionHandler &new_exception_handler) {
auto &exception_handler_tmp = exception_handler();
exception_handler_tmp = new_exception_handler;
}
const LogWriter::TraceProvider &LogWriter::GetTraceProvider() {
const auto &trace_provider_tmp = trace_provider();
return trace_provider_tmp;
}
void LogWriter::SetTraceProvider(const LogWriter::TraceProvider &new_trace_provider) {
auto &trace_provider_tmp = trace_provider();
if (trace_provider_tmp != nullptr) {
MS_LOG(INFO) << "trace provider has been set, skip.";
return;
}
trace_provider_tmp = new_trace_provider;
}
LogWriter::TraceProvider LogWriter::trace_provider() { return trace_provider_; }
static inline std::string GetEnv(const std::string &envvar) {
const char *value = std::getenv(envvar.c_str());
@ -390,16 +410,18 @@ void LogWriter::operator^(const LogStream &stream) const {
if (this_thread_max_log_level >= MsLogLevel::kException) {
RemoveLabelBeforeOutputLog(msg);
}
if (trace_provider_ != nullptr) {
trace_provider_(oss, true);
const auto &trace_provider = GetTraceProvider();
if (trace_provider != nullptr) {
trace_provider(oss, true);
}
running = false;
}
DisplayDevExceptionMessage(oss, dmsg, location_);
if (exception_handler_ != nullptr) {
exception_handler_(exception_type_, oss.str());
const auto &exception_handler = GetExceptionHandler();
if (exception_handler != nullptr) {
exception_handler(exception_type_, oss.str());
}
throw std::runtime_error(oss.str());
}

View File

@ -268,21 +268,36 @@ class MS_CORE_API LogWriter {
void operator^(const LogStream &stream) const NO_RETURN;
#endif
static void set_exception_handler(const ExceptionHandler &exception_handler);
static void set_trace_provider(const TraceProvider &trace_provider);
static TraceProvider trace_provider();
/// \brief Get the function pointer of converting exception types in c++.
///
/// \return A pointer of the function.
static const ExceptionHandler &GetExceptionHandler();
/// \brief Set the function pointer of converting exception types in c++.
///
/// \param[in] A function pointer of converting exception types in c++.
static void SetExceptionHandler(const ExceptionHandler &new_exception_handler);
/// \brief Get the function pointer of printing trace stacks.
///
/// \return A pointer of the function.
static const TraceProvider &GetTraceProvider();
/// \brief Set the function pointer of printing trace stacks.
///
/// \param[in] A function pointer of printing trace stacks.
static void SetTraceProvider(const TraceProvider &new_trace_provider);
private:
void OutputLog(const std::ostringstream &msg) const;
void RemoveLabelBeforeOutputLog(const std::ostringstream &msg) const;
static ExceptionHandler &exception_handler();
static TraceProvider &trace_provider();
LocationInfo location_;
MsLogLevel log_level_;
SubModuleId submodule_;
ExceptionType exception_type_;
inline static ExceptionHandler exception_handler_ = nullptr;
inline static TraceProvider trace_provider_ = nullptr;
};
#define MSLOG_IF(level, condition, excp_type) \