forked from mindspore-Ecosystem/mindspore
!29917 add julia custom op error log
Merge pull request !29917 from r1chardf1d0/master
This commit is contained in:
commit
1168b59baa
|
@ -34,9 +34,9 @@ constexpr auto kLibJulia = "libjulia.so";
|
|||
typedef struct _jl_value_t jl_value_t;
|
||||
typedef jl_value_t jl_function_t;
|
||||
typedef struct _jl_module_t jl_module_t;
|
||||
typedef struct _jl_sym_t jl_sym_t;
|
||||
typedef struct _jl_datatype_t jl_datatype_t;
|
||||
typedef struct _jl_array_t jl_array_t;
|
||||
typedef struct _jl_sym_t jl_sym_t;
|
||||
|
||||
#define GET_HOOK(func, rt, ...) GET_HOOK_INNER(func, _, rt, __VA_ARGS__)
|
||||
#define GET_HOOK_INNER(func, _, rt, ...) \
|
||||
|
@ -165,8 +165,6 @@ class JuliaAPI {
|
|||
jl_apply_array_type_ = nullptr;
|
||||
jl_ptr_to_array_ = nullptr;
|
||||
jl_typeof_str_ = nullptr;
|
||||
jl_stderr_obj_ = nullptr;
|
||||
jl_current_exception_ = nullptr;
|
||||
jl_ver_major_ = nullptr;
|
||||
jl_ver_minor_ = nullptr;
|
||||
}
|
||||
|
@ -238,22 +236,49 @@ class JuliaAPI {
|
|||
GET_HOOK(jl_apply_array_type, jl_value_t *, jl_value_t *, size_t);
|
||||
GET_HOOK(jl_ptr_to_array, jl_array_t *, jl_value_t *, void *, jl_value_t *, int);
|
||||
GET_HOOK(jl_typeof_str, const char *, jl_value_t *);
|
||||
GET_HOOK(jl_stderr_obj, jl_value_t *, void);
|
||||
GET_HOOK(jl_current_exception, jl_value_t *, void);
|
||||
#else
|
||||
suc = false;
|
||||
#endif
|
||||
return suc;
|
||||
}
|
||||
|
||||
void ErrorMsg(jl_value_t *ex) {
|
||||
auto errtype = JlTypeOfStr(ex);
|
||||
MS_LOG(ERROR) << "Got a julia error! Err type: " << errtype;
|
||||
jl_module_t *base = reinterpret_cast<jl_module_t *>(JlEvalString("Main.Base"));
|
||||
if (!base) {
|
||||
MS_LOG(ERROR) << "Could not load julia module base.";
|
||||
return;
|
||||
}
|
||||
jl_function_t *showerror = JlGetFunction(base, "showerror");
|
||||
if (!showerror) {
|
||||
MS_LOG(ERROR) << "Could not load julia function showerror.";
|
||||
return;
|
||||
}
|
||||
// Base.showerror(stderr, ex)
|
||||
std::vector<jl_value_t *> args{reinterpret_cast<jl_value_t *>(Core("stderr")), ex};
|
||||
constexpr size_t args_num = 2;
|
||||
JlEvalString("print(\"\\n====================JULIA ERROR====================\\n\")");
|
||||
JlCall(showerror, &args[0], args_num);
|
||||
JlEvalString("print(\"\\n===================================================\\n\")");
|
||||
}
|
||||
|
||||
int RunJuliaKernel() {
|
||||
// include julia file
|
||||
JlEvalString("Base.include(Main, \"" + file_ + "\")");
|
||||
// using julia module
|
||||
JlEvalString("using Main." + module_);
|
||||
jl_module_t *jmod = reinterpret_cast<jl_module_t *>(JlEvalString("Main." + module_));
|
||||
if (!jmod) {
|
||||
MS_LOG(ERROR) << "Could not load julia module: " << module_;
|
||||
return -1;
|
||||
}
|
||||
// get julia function from module
|
||||
jl_function_t *jfunc = JlGetFunction(jmod, func_);
|
||||
if (!jfunc) {
|
||||
MS_LOG(ERROR) << "Could not load julia function: " << func_;
|
||||
return -1;
|
||||
}
|
||||
// convert kernel inputs to julia type
|
||||
std::vector<jl_value_t *> args(nparam_);
|
||||
for (int i = 0; i < nparam_; i++) {
|
||||
|
@ -261,19 +286,9 @@ class JuliaAPI {
|
|||
}
|
||||
// call the julia function
|
||||
JlCall(jfunc, &args[0], nparam_);
|
||||
if (JlExceptionOccurred()) {
|
||||
MS_LOG(EXCEPTION) << JlTypeOfStr(JlExceptionOccurred());
|
||||
auto errs = JlStdErrObj();
|
||||
if (errs) {
|
||||
JlEvalString("using Main.Base");
|
||||
auto base = reinterpret_cast<jl_module_t *>(JlEvalString("Main.Base"));
|
||||
auto show = JlGetFunction(base, "show");
|
||||
if (show) {
|
||||
std::vector<jl_value_t *> err_args{errs, JlCurrentException()};
|
||||
constexpr int arg_num = 2;
|
||||
JlCall(show, &err_args[0], arg_num);
|
||||
}
|
||||
}
|
||||
auto ex = JlExceptionOccurred();
|
||||
if (ex) {
|
||||
ErrorMsg(ex);
|
||||
return -1;
|
||||
}
|
||||
JlAtexitHook(0);
|
||||
|
@ -300,11 +315,7 @@ class JuliaAPI {
|
|||
return jl_ptr_to_array_(atype, data, dims, own_buffer);
|
||||
}
|
||||
|
||||
const char *JlTypeOfStr(jl_value_t *v) { return jl_typeof_str_(v); }
|
||||
|
||||
jl_value_t *JlStdErrObj() { return jl_stderr_obj_(); }
|
||||
|
||||
jl_value_t *JlCurrentException() { return jl_current_exception_(); }
|
||||
std::string JlTypeOfStr(jl_value_t *v) { return jl_typeof_str_(v); }
|
||||
|
||||
int JlVerMajor() { return jl_ver_major_(); }
|
||||
|
||||
|
@ -382,8 +393,6 @@ class JuliaAPI {
|
|||
jl_value_t *(*jl_apply_array_type_)(jl_value_t *, size_t);
|
||||
jl_array_t *(*jl_ptr_to_array_)(jl_value_t *, void *, jl_value_t *, int);
|
||||
const char *(*jl_typeof_str_)(jl_value_t *);
|
||||
jl_value_t *(*jl_stderr_obj_)(void);
|
||||
jl_value_t *(*jl_current_exception_)(void);
|
||||
int (*jl_ver_major_)(void);
|
||||
int (*jl_ver_minor_)(void);
|
||||
};
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import platform
|
||||
import pytest
|
||||
|
||||
|
||||
|
@ -11,6 +12,9 @@ def test_julia():
|
|||
Description: run julia_cases
|
||||
Expectation: res == 0
|
||||
"""
|
||||
system = platform.system()
|
||||
if system != 'Linux':
|
||||
pass
|
||||
res = os.system('sh julia_run.sh')
|
||||
if res != 0:
|
||||
assert False, 'julia test fail'
|
||||
|
|
Loading…
Reference in New Issue