forked from mindspore-Ecosystem/mindspore
commit
f5010ec8b8
|
@ -81,7 +81,7 @@ class CheckAllFormatsSame : public Validator {
|
|||
|
||||
class CheckAttr : public Validator {
|
||||
public:
|
||||
CheckAttr(const std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
|
||||
CheckAttr(std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
|
||||
virtual ~CheckAttr() = default;
|
||||
bool Check(const OpDesc &e) override {
|
||||
for (auto &a : attrs_) {
|
||||
|
|
|
@ -101,7 +101,8 @@ const auto NotTransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *
|
|||
};
|
||||
|
||||
bool IsAkgMatMul(size_t K, size_t M, size_t N) {
|
||||
if (K > MAX_PER_DIM_SHAPE || static_cast<int64_t>(M * N * K) >= MAX_ALL_SHAPE) {
|
||||
if (K > MAX_PER_DIM_SHAPE ||
|
||||
(static_cast<int64_t>(M) * static_cast<int64_t>(N) * static_cast<int64_t>(K)) >= MAX_ALL_SHAPE) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -39,7 +39,7 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
|
||||
}
|
||||
file_path_ = real_path.value();
|
||||
module_name_ = exec_info.substr(pos1 + 1, pos2 - pos1 - 1);
|
||||
module_name_ = exec_info.substr(pos1 + 1, (pos2 - pos1) - 1);
|
||||
func_name_ = exec_info.substr(pos2 + 1);
|
||||
|
||||
num_input_ = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -52,11 +52,11 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
for (size_t i = 0; i < num_input_; i++) {
|
||||
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> in_shape_tmp;
|
||||
std::for_each(in_shape.begin(), in_shape.end(),
|
||||
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
|
||||
shape_list_.emplace_back(in_shape_tmp);
|
||||
ndims_.push_back(SizeToInt(in_shape_tmp.size()));
|
||||
type_list_.emplace_back(TypeIdToString(input_type_list[i], true));
|
||||
(void)std::for_each(in_shape.begin(), in_shape.end(),
|
||||
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
|
||||
ndims_.push_back(in_shape_tmp.size());
|
||||
shape_list_.push_back(in_shape_tmp);
|
||||
type_list_.push_back(TypeIdToString(input_type_list[i], true));
|
||||
}
|
||||
|
||||
num_output_ = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
|
@ -69,11 +69,11 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
for (size_t i = 0; i < num_output_; i++) {
|
||||
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
|
||||
std::vector<int64_t> out_shape_tmp;
|
||||
std::for_each(out_shape.begin(), out_shape.end(),
|
||||
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
|
||||
shape_list_.emplace_back(out_shape_tmp);
|
||||
ndims_.push_back(SizeToInt(out_shape_tmp.size()));
|
||||
type_list_.emplace_back(TypeIdToString(output_type_list[i], true));
|
||||
(void)std::for_each(out_shape.begin(), out_shape.end(),
|
||||
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
|
||||
ndims_.push_back(out_shape_tmp.size());
|
||||
shape_list_.push_back(out_shape_tmp);
|
||||
type_list_.push_back(TypeIdToString(output_type_list[i], true));
|
||||
}
|
||||
|
||||
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
|
||||
|
@ -91,9 +91,8 @@ bool CustomJULIACpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, cons
|
|||
for (size_t i = 0; i < num_output_; i++) {
|
||||
params.push_back(GetDeviceAddress<void>(outputs, i));
|
||||
}
|
||||
int nparam = SizeToInt(params.size());
|
||||
size_t nparam = params.size();
|
||||
JuliaAPI *julia = JuliaAPI::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(julia);
|
||||
if (!julia->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Julia kernel[" << file_path_ << ":" << module_name_ << ":" << func_name_ << "] init failed.";
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ class CustomJULIACpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
protected:
|
||||
std::vector<std::vector<int64_t>> shape_list_;
|
||||
std::vector<int> ndims_;
|
||||
std::vector<size_t> ndims_;
|
||||
std::vector<std::string> type_list_;
|
||||
|
||||
std::vector<int64_t *> shapes_;
|
||||
|
|
|
@ -114,7 +114,6 @@ class JuliaAPI {
|
|||
handle_ = dlopen(kLibJulia, RTLD_LAZY | RTLD_LOCAL);
|
||||
if (!handle_) {
|
||||
MS_LOG(EXCEPTION) << dlerror();
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
return false;
|
||||
|
@ -133,8 +132,8 @@ class JuliaAPI {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Run(const std::string &file, const std::string &module, const std::string &func, int nparam,
|
||||
const std::vector<void *> ¶m, const std::vector<int> &ndims, const std::vector<int64_t *> &shapes,
|
||||
bool Run(const std::string &file, const std::string &module, const std::string &func, size_t nparam,
|
||||
const std::vector<void *> ¶m, const std::vector<size_t> &ndims, const std::vector<int64_t *> &shapes,
|
||||
const std::vector<const char *> &dtypes) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto NotRunning = [this]() { return !this->running_; };
|
||||
|
@ -162,7 +161,30 @@ class JuliaAPI {
|
|||
}
|
||||
|
||||
private:
|
||||
JuliaAPI() {
|
||||
JuliaAPI() { clear(); }
|
||||
~JuliaAPI() {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
if (handle_ != nullptr) {
|
||||
// ready to break the loop in the julia thread
|
||||
stop_ = true;
|
||||
// notify loop continue, which will stop the loop, then julia thread finished.
|
||||
c_.notify_one();
|
||||
// join the julia thread
|
||||
if (t_.joinable()) {
|
||||
try {
|
||||
t_.join();
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Try to join the julia thread failed! Error message is " << e.what();
|
||||
}
|
||||
}
|
||||
// close the handle of julia shared library
|
||||
(void)dlclose(handle_);
|
||||
}
|
||||
clear();
|
||||
#endif
|
||||
}
|
||||
|
||||
void clear() {
|
||||
handle_ = nullptr;
|
||||
jl_eval_string_ = nullptr;
|
||||
jl_get_global_ = nullptr;
|
||||
|
@ -177,20 +199,6 @@ class JuliaAPI {
|
|||
jl_ver_major_ = nullptr;
|
||||
jl_ver_minor_ = nullptr;
|
||||
}
|
||||
~JuliaAPI() {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
if (handle_ != nullptr) {
|
||||
// ready to break the loop in the julia thread
|
||||
stop_ = true;
|
||||
// notify loop continue, which will stop the loop, then julia thread finished.
|
||||
c_.notify_one();
|
||||
// join the julia thread
|
||||
t_.join();
|
||||
// close the handle of julia shared library
|
||||
dlclose(handle_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void Loop() {
|
||||
// keep the thread alive, for the julia runtime should run in one thread safely
|
||||
|
@ -267,22 +275,22 @@ class JuliaAPI {
|
|||
// Base.showerror(stderr, ex)
|
||||
std::vector<jl_value_t *> args{reinterpret_cast<jl_value_t *>(Core("stderr")), ex};
|
||||
constexpr size_t args_num = 2;
|
||||
JlEvalString("print(\"====================JULIA ERROR====================\\n\")");
|
||||
JlCall(showerror, &args[0], args_num);
|
||||
JlEvalString("print(\"\\n===================================================\\n\")");
|
||||
(void)JlEvalString("print(\"====================JULIA ERROR====================\\n\")");
|
||||
(void)JlCall(showerror, &args[0], args_num);
|
||||
(void)JlEvalString("print(\"\\n===================================================\\n\")");
|
||||
}
|
||||
|
||||
bool RunJuliaKernel() {
|
||||
if (!jl_file_caches_.count(file_)) {
|
||||
// include julia file
|
||||
JlEvalString("Base.include(Main, \"" + file_ + "\")");
|
||||
(void)JlEvalString("Base.include(Main, \"" + file_ + "\")");
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
jl_file_caches_.insert(file_);
|
||||
(void)jl_file_caches_.insert(file_);
|
||||
}
|
||||
jl_module_t *jmod = nullptr;
|
||||
if (!jl_module_caches_.count(file_ + module_)) {
|
||||
// using julia module
|
||||
JlEvalString("using Main." + module_);
|
||||
(void)JlEvalString("using Main." + module_);
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
jmod = reinterpret_cast<jl_module_t *>(JlEvalString("Main." + module_));
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
|
@ -301,54 +309,56 @@ class JuliaAPI {
|
|||
}
|
||||
// convert kernel inputs to julia type
|
||||
std::vector<jl_value_t *> args(nparam_);
|
||||
for (int i = 0; i < nparam_; i++) {
|
||||
for (size_t i = 0; i < nparam_; i++) {
|
||||
args[i] = reinterpret_cast<jl_value_t *>(GetJuliaArray(params_[i], ndims_[i], shapes_[i], dtypes_[i]));
|
||||
}
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
// call the julia function
|
||||
JlCall(jfunc, &args[0], nparam_);
|
||||
(void)JlCall(jfunc, &args[0], nparam_);
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
JlAtexitHook(0);
|
||||
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
|
||||
return true;
|
||||
}
|
||||
|
||||
jl_value_t *JlEvalString(const std::string &str) { return jl_eval_string_(str.c_str()); }
|
||||
jl_value_t *JlEvalString(const std::string &str) const { return jl_eval_string_(str.c_str()); }
|
||||
|
||||
jl_value_t *JlGetGlobal(jl_module_t *m, jl_sym_t *var) { return jl_get_global_(m, var); }
|
||||
jl_value_t *JlGetGlobal(jl_module_t *m, jl_sym_t *var) const { return jl_get_global_(m, var); }
|
||||
|
||||
jl_sym_t *JlSymbol(const std::string &str) { return jl_symbol_(str.c_str()); }
|
||||
jl_sym_t *JlSymbol(const std::string &str) const { return jl_symbol_(str.c_str()); }
|
||||
|
||||
jl_value_t *JlCall(jl_function_t *f, jl_value_t **args, int32_t nargs) { return jl_call_(f, args, nargs); }
|
||||
jl_value_t *JlCall(jl_function_t *f, jl_value_t **args, size_t nargs) const {
|
||||
return jl_call_(f, args, static_cast<int32_t>(nargs));
|
||||
}
|
||||
|
||||
jl_value_t *JlExceptionOccurred(void) { return jl_exception_occurred_(); }
|
||||
jl_value_t *JlExceptionOccurred() const { return jl_exception_occurred_(); }
|
||||
|
||||
void JlAtexitHook(int status) { return jl_atexit_hook_(status); }
|
||||
void JlAtexitHook(int status) const { return jl_atexit_hook_(status); }
|
||||
|
||||
void JlInit(void) { return jl_init__threading_(); }
|
||||
void JlInit(void) const { return jl_init__threading_(); }
|
||||
|
||||
jl_value_t *JlApplyArrayType(jl_value_t *type, size_t dim) { return jl_apply_array_type_(type, dim); }
|
||||
jl_value_t *JlApplyArrayType(jl_value_t *type, size_t dim) const { return jl_apply_array_type_(type, dim); }
|
||||
|
||||
jl_array_t *JlPtrToArray(jl_value_t *atype, void *data, jl_value_t *dims, int own_buffer) {
|
||||
jl_array_t *JlPtrToArray(jl_value_t *atype, void *data, jl_value_t *dims, int own_buffer) const {
|
||||
return jl_ptr_to_array_(atype, data, dims, own_buffer);
|
||||
}
|
||||
|
||||
std::string JlTypeOfStr(jl_value_t *v) { return jl_typeof_str_(v); }
|
||||
std::string JlTypeOfStr(jl_value_t *v) const { return jl_typeof_str_(v); }
|
||||
|
||||
int JlVerMajor() { return jl_ver_major_(); }
|
||||
int JlVerMajor() const { return jl_ver_major_(); }
|
||||
|
||||
int JlVerMinor() { return jl_ver_minor_(); }
|
||||
int JlVerMinor() const { return jl_ver_minor_(); }
|
||||
|
||||
jl_function_t *JlGetFunction(jl_module_t *m, const std::string &name) {
|
||||
jl_function_t *JlGetFunction(jl_module_t *m, const std::string &name) const {
|
||||
return reinterpret_cast<jl_function_t *>(JlGetGlobal(m, JlSymbol(name)));
|
||||
}
|
||||
|
||||
jl_value_t *Core(const std::string &name) {
|
||||
jl_value_t *Core(const std::string &name) const {
|
||||
jl_module_t *jl_core_module = reinterpret_cast<jl_module_t *>(JlEvalString("Main.Core"));
|
||||
return JlGetGlobal(jl_core_module, JlSymbol(name.c_str()));
|
||||
}
|
||||
|
||||
jl_datatype_t *GetType(const std::string &dtypes) {
|
||||
jl_datatype_t *GetType(const std::string &dtypes) const {
|
||||
jl_datatype_t *type = reinterpret_cast<jl_datatype_t *>(Core("Float32"));
|
||||
std::unordered_map<std::string, std::string> m{
|
||||
{"float16", "Float16"}, {"float32", "Float32"}, {"float64", "Float64"}, {"int8", "Int8"},
|
||||
|
@ -360,9 +370,9 @@ class JuliaAPI {
|
|||
return type;
|
||||
}
|
||||
|
||||
jl_array_t *GetJuliaArray(void *params, int ndims, int64_t *shapes, const std::string &dtypes) {
|
||||
jl_array_t *GetJuliaArray(void *params, size_t ndims, const int64_t *shapes, const std::string &dtypes) {
|
||||
std::string shape_str = "(";
|
||||
for (int j = 0; j < ndims; j++) {
|
||||
for (size_t j = 0; j < ndims; j++) {
|
||||
shape_str += std::to_string(shapes[j]);
|
||||
shape_str += ",";
|
||||
}
|
||||
|
@ -373,7 +383,6 @@ class JuliaAPI {
|
|||
return JlPtrToArray(array_type, params, shape, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
// the thread which used to call julia func, will be created by first julia kernel,
|
||||
// and will always exist until the JuliaAPI instance gone.
|
||||
std::thread t_;
|
||||
|
@ -391,11 +400,11 @@ class JuliaAPI {
|
|||
|
||||
// julia kernel's inputs
|
||||
std::vector<void *> params_;
|
||||
int nparam_{0};
|
||||
size_t nparam_{0};
|
||||
std::string file_;
|
||||
std::string module_;
|
||||
std::string func_;
|
||||
std::vector<int> ndims_;
|
||||
std::vector<size_t> ndims_;
|
||||
std::vector<int64_t *> shapes_;
|
||||
std::vector<const char *> dtypes_;
|
||||
|
||||
|
|
Loading…
Reference in New Issue