Merge pull request !31510 from r1chardf1d0/clean2
This commit is contained in:
i-robot 2022-03-21 01:53:31 +00:00 committed by Gitee
commit f5010ec8b8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 71 additions and 62 deletions

View File

@ -81,7 +81,7 @@ class CheckAllFormatsSame : public Validator {
class CheckAttr : public Validator {
public:
CheckAttr(const std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
CheckAttr(std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
virtual ~CheckAttr() = default;
bool Check(const OpDesc &e) override {
for (auto &a : attrs_) {

View File

@ -101,7 +101,8 @@ const auto NotTransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *
};
bool IsAkgMatMul(size_t K, size_t M, size_t N) {
if (K > MAX_PER_DIM_SHAPE || static_cast<int64_t>(M * N * K) >= MAX_ALL_SHAPE) {
if (K > MAX_PER_DIM_SHAPE ||
(static_cast<int64_t>(M) * static_cast<int64_t>(N) * static_cast<int64_t>(K)) >= MAX_ALL_SHAPE) {
return false;
}
return true;

View File

@ -39,7 +39,7 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
}
file_path_ = real_path.value();
module_name_ = exec_info.substr(pos1 + 1, pos2 - pos1 - 1);
module_name_ = exec_info.substr(pos1 + 1, (pos2 - pos1) - 1);
func_name_ = exec_info.substr(pos2 + 1);
num_input_ = common::AnfAlgo::GetInputTensorNum(kernel_node);
@ -52,11 +52,11 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
for (size_t i = 0; i < num_input_; i++) {
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
std::vector<int64_t> in_shape_tmp;
std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.emplace_back(in_shape_tmp);
ndims_.push_back(SizeToInt(in_shape_tmp.size()));
type_list_.emplace_back(TypeIdToString(input_type_list[i], true));
(void)std::for_each(in_shape.begin(), in_shape.end(),
[&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
ndims_.push_back(in_shape_tmp.size());
shape_list_.push_back(in_shape_tmp);
type_list_.push_back(TypeIdToString(input_type_list[i], true));
}
num_output_ = common::AnfAlgo::GetOutputTensorNum(kernel_node);
@ -69,11 +69,11 @@ void CustomJULIACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
for (size_t i = 0; i < num_output_; i++) {
std::vector<size_t> out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
std::vector<int64_t> out_shape_tmp;
std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
shape_list_.emplace_back(out_shape_tmp);
ndims_.push_back(SizeToInt(out_shape_tmp.size()));
type_list_.emplace_back(TypeIdToString(output_type_list[i], true));
(void)std::for_each(out_shape.begin(), out_shape.end(),
[&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
ndims_.push_back(out_shape_tmp.size());
shape_list_.push_back(out_shape_tmp);
type_list_.push_back(TypeIdToString(output_type_list[i], true));
}
(void)std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
@ -91,9 +91,8 @@ bool CustomJULIACpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, cons
for (size_t i = 0; i < num_output_; i++) {
params.push_back(GetDeviceAddress<void>(outputs, i));
}
int nparam = SizeToInt(params.size());
size_t nparam = params.size();
JuliaAPI *julia = JuliaAPI::GetInstance();
MS_EXCEPTION_IF_NULL(julia);
if (!julia->Init()) {
MS_LOG(EXCEPTION) << "Julia kernel[" << file_path_ << ":" << module_name_ << ":" << func_name_ << "] init failed.";
}

View File

@ -34,7 +34,7 @@ class CustomJULIACpuKernelMod : public NativeCpuKernelMod {
protected:
std::vector<std::vector<int64_t>> shape_list_;
std::vector<int> ndims_;
std::vector<size_t> ndims_;
std::vector<std::string> type_list_;
std::vector<int64_t *> shapes_;

View File

@ -114,7 +114,6 @@ class JuliaAPI {
handle_ = dlopen(kLibJulia, RTLD_LAZY | RTLD_LOCAL);
if (!handle_) {
MS_LOG(EXCEPTION) << dlerror();
return false;
}
#else
return false;
@ -133,8 +132,8 @@ class JuliaAPI {
return true;
}
bool Run(const std::string &file, const std::string &module, const std::string &func, int nparam,
const std::vector<void *> &param, const std::vector<int> &ndims, const std::vector<int64_t *> &shapes,
bool Run(const std::string &file, const std::string &module, const std::string &func, size_t nparam,
const std::vector<void *> &param, const std::vector<size_t> &ndims, const std::vector<int64_t *> &shapes,
const std::vector<const char *> &dtypes) {
std::unique_lock<std::mutex> lock(mutex_);
auto NotRunning = [this]() { return !this->running_; };
@ -162,7 +161,30 @@ class JuliaAPI {
}
private:
JuliaAPI() {
JuliaAPI() { clear(); }
~JuliaAPI() {
#if !defined(_WIN32) && !defined(_WIN64)
if (handle_ != nullptr) {
// ready to break the loop in the julia thread
stop_ = true;
// notify loop continue, which will stop the loop, then julia thread finished.
c_.notify_one();
// join the julia thread
if (t_.joinable()) {
try {
t_.join();
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Try to join the julia thread failed! Error message is " << e.what();
}
}
// close the handle of julia shared library
(void)dlclose(handle_);
}
clear();
#endif
}
void clear() {
handle_ = nullptr;
jl_eval_string_ = nullptr;
jl_get_global_ = nullptr;
@ -177,20 +199,6 @@ class JuliaAPI {
jl_ver_major_ = nullptr;
jl_ver_minor_ = nullptr;
}
~JuliaAPI() {
#if !defined(_WIN32) && !defined(_WIN64)
if (handle_ != nullptr) {
// ready to break the loop in the julia thread
stop_ = true;
// notify loop continue, which will stop the loop, then julia thread finished.
c_.notify_one();
// join the julia thread
t_.join();
// close the handle of julia shared library
dlclose(handle_);
}
#endif
}
static void Loop() {
// keep the thread alive, for the julia runtime should run in one thread safely
@ -267,22 +275,22 @@ class JuliaAPI {
// Base.showerror(stderr, ex)
std::vector<jl_value_t *> args{reinterpret_cast<jl_value_t *>(Core("stderr")), ex};
constexpr size_t args_num = 2;
JlEvalString("print(\"====================JULIA ERROR====================\\n\")");
JlCall(showerror, &args[0], args_num);
JlEvalString("print(\"\\n===================================================\\n\")");
(void)JlEvalString("print(\"====================JULIA ERROR====================\\n\")");
(void)JlCall(showerror, &args[0], args_num);
(void)JlEvalString("print(\"\\n===================================================\\n\")");
}
bool RunJuliaKernel() {
if (!jl_file_caches_.count(file_)) {
// include julia file
JlEvalString("Base.include(Main, \"" + file_ + "\")");
(void)JlEvalString("Base.include(Main, \"" + file_ + "\")");
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
jl_file_caches_.insert(file_);
(void)jl_file_caches_.insert(file_);
}
jl_module_t *jmod = nullptr;
if (!jl_module_caches_.count(file_ + module_)) {
// using julia module
JlEvalString("using Main." + module_);
(void)JlEvalString("using Main." + module_);
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
jmod = reinterpret_cast<jl_module_t *>(JlEvalString("Main." + module_));
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
@ -301,54 +309,56 @@ class JuliaAPI {
}
// convert kernel inputs to julia type
std::vector<jl_value_t *> args(nparam_);
for (int i = 0; i < nparam_; i++) {
for (size_t i = 0; i < nparam_; i++) {
args[i] = reinterpret_cast<jl_value_t *>(GetJuliaArray(params_[i], ndims_[i], shapes_[i], dtypes_[i]));
}
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
// call the julia function
JlCall(jfunc, &args[0], nparam_);
(void)JlCall(jfunc, &args[0], nparam_);
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
JlAtexitHook(0);
RETURN_FALSE_IF_GET_JULIA_EXCEPTION();
return true;
}
jl_value_t *JlEvalString(const std::string &str) { return jl_eval_string_(str.c_str()); }
jl_value_t *JlEvalString(const std::string &str) const { return jl_eval_string_(str.c_str()); }
jl_value_t *JlGetGlobal(jl_module_t *m, jl_sym_t *var) { return jl_get_global_(m, var); }
jl_value_t *JlGetGlobal(jl_module_t *m, jl_sym_t *var) const { return jl_get_global_(m, var); }
jl_sym_t *JlSymbol(const std::string &str) { return jl_symbol_(str.c_str()); }
jl_sym_t *JlSymbol(const std::string &str) const { return jl_symbol_(str.c_str()); }
jl_value_t *JlCall(jl_function_t *f, jl_value_t **args, int32_t nargs) { return jl_call_(f, args, nargs); }
jl_value_t *JlCall(jl_function_t *f, jl_value_t **args, size_t nargs) const {
return jl_call_(f, args, static_cast<int32_t>(nargs));
}
jl_value_t *JlExceptionOccurred(void) { return jl_exception_occurred_(); }
jl_value_t *JlExceptionOccurred() const { return jl_exception_occurred_(); }
void JlAtexitHook(int status) { return jl_atexit_hook_(status); }
void JlAtexitHook(int status) const { return jl_atexit_hook_(status); }
void JlInit(void) { return jl_init__threading_(); }
void JlInit(void) const { return jl_init__threading_(); }
jl_value_t *JlApplyArrayType(jl_value_t *type, size_t dim) { return jl_apply_array_type_(type, dim); }
jl_value_t *JlApplyArrayType(jl_value_t *type, size_t dim) const { return jl_apply_array_type_(type, dim); }
jl_array_t *JlPtrToArray(jl_value_t *atype, void *data, jl_value_t *dims, int own_buffer) {
jl_array_t *JlPtrToArray(jl_value_t *atype, void *data, jl_value_t *dims, int own_buffer) const {
return jl_ptr_to_array_(atype, data, dims, own_buffer);
}
std::string JlTypeOfStr(jl_value_t *v) { return jl_typeof_str_(v); }
std::string JlTypeOfStr(jl_value_t *v) const { return jl_typeof_str_(v); }
int JlVerMajor() { return jl_ver_major_(); }
int JlVerMajor() const { return jl_ver_major_(); }
int JlVerMinor() { return jl_ver_minor_(); }
int JlVerMinor() const { return jl_ver_minor_(); }
jl_function_t *JlGetFunction(jl_module_t *m, const std::string &name) {
jl_function_t *JlGetFunction(jl_module_t *m, const std::string &name) const {
return reinterpret_cast<jl_function_t *>(JlGetGlobal(m, JlSymbol(name)));
}
jl_value_t *Core(const std::string &name) {
jl_value_t *Core(const std::string &name) const {
jl_module_t *jl_core_module = reinterpret_cast<jl_module_t *>(JlEvalString("Main.Core"));
return JlGetGlobal(jl_core_module, JlSymbol(name.c_str()));
}
jl_datatype_t *GetType(const std::string &dtypes) {
jl_datatype_t *GetType(const std::string &dtypes) const {
jl_datatype_t *type = reinterpret_cast<jl_datatype_t *>(Core("Float32"));
std::unordered_map<std::string, std::string> m{
{"float16", "Float16"}, {"float32", "Float32"}, {"float64", "Float64"}, {"int8", "Int8"},
@ -360,9 +370,9 @@ class JuliaAPI {
return type;
}
jl_array_t *GetJuliaArray(void *params, int ndims, int64_t *shapes, const std::string &dtypes) {
jl_array_t *GetJuliaArray(void *params, size_t ndims, const int64_t *shapes, const std::string &dtypes) {
std::string shape_str = "(";
for (int j = 0; j < ndims; j++) {
for (size_t j = 0; j < ndims; j++) {
shape_str += std::to_string(shapes[j]);
shape_str += ",";
}
@ -373,7 +383,6 @@ class JuliaAPI {
return JlPtrToArray(array_type, params, shape, 0);
}
private:
// the thread which used to call julia func, will be created by first julia kernel,
// and will always exist until the JuliaAPI instance gone.
std::thread t_;
@ -391,11 +400,11 @@ class JuliaAPI {
// julia kernel's inputs
std::vector<void *> params_;
int nparam_{0};
size_t nparam_{0};
std::string file_;
std::string module_;
std::string func_;
std::vector<int> ndims_;
std::vector<size_t> ndims_;
std::vector<int64_t *> shapes_;
std::vector<const char *> dtypes_;