forked from mindspore-Ecosystem/mindspore
204 lines
8.9 KiB
Diff
204 lines
8.9 KiB
Diff
diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
|
|
index d668984..3676a61 100644
|
|
--- a/include/tvm/runtime/registry.h
|
|
+++ b/include/tvm/runtime/registry.h
|
|
@@ -319,6 +319,37 @@ class Registry {
 #define TVM_REGISTER_EXT_TYPE(T) \
   TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
       ::tvm::runtime::ExtTypeVTable::Register_<T>()
+
+/*!
+ * \brief Translate a TVM runtime API symbol name for a custom runtime.
+ *
+ * Looks up the global function "codegen.GetTransRTFunc". If it is registered,
+ * it is called with the original name and the pointer obtained from its
+ * result via ptr<const char>() is returned; otherwise the original name is
+ * returned unchanged.
+ *
+ * NOTE(review): the pointer is read from the call's temporary return value,
+ * so the translator presumably returns storage that outlives the call (e.g.
+ * a static string) -- confirm against its implementation.
+ *
+ * \param orig_func_str The TVM runtime API symbol name.
+ * \return The symbol name to use in its place.
+ */
+inline const char* TransRuntimeFuncName(const char* orig_func_str) {
+  const PackedFunc* trans_func = Registry::Get("codegen.GetTransRTFunc");
+  if (trans_func == nullptr) {
+    return orig_func_str;
+  }
+  return ((*trans_func)(orig_func_str)).ptr<const char>();
+}
+
+/*!
+ * \brief Macro spelling of TransRuntimeFuncName, kept for existing call sites.
+ *
+ * Expands to an ordinary function call instead of a GNU statement expression,
+ * which is a compiler extension and not standard C++.
+ */
+#define TVM_RT_FUNC_TRANS(OrigFuncStr) \
+  (::tvm::runtime::TransRuntimeFuncName(OrigFuncStr))
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
|
|
index 0ba0c58..2850ad4 100644
|
|
--- a/src/codegen/llvm/codegen_cpu.cc
|
|
+++ b/src/codegen/llvm/codegen_cpu.cc
|
|
@@ -99,26 +99,26 @@ void CodeGenCPU::Init(const std::string& module_name,
|
|
// We will need this in environment for backward registration.
|
|
f_tvm_register_system_symbol_ = llvm::Function::Create(
|
|
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
|
|
- llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
|
|
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get());
|
|
} else {
|
|
f_tvm_register_system_symbol_ = nullptr;
|
|
}
|
|
if (dynamic_lookup || system_lib) {
|
|
f_tvm_func_call_ = llvm::Function::Create(
|
|
ftype_tvm_func_call_,
|
|
- llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
|
|
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get());
|
|
f_tvm_get_func_from_env_ = llvm::Function::Create(
|
|
ftype_tvm_get_func_from_env_,
|
|
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
|
|
f_tvm_api_set_last_error_ = llvm::Function::Create(
|
|
ftype_tvm_api_set_last_error_,
|
|
- llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
|
|
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get());
|
|
f_tvm_parallel_launch_ = llvm::Function::Create(
|
|
ftype_tvm_parallel_launch_,
|
|
- llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
|
|
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get());
|
|
f_tvm_parallel_barrier_ = llvm::Function::Create(
|
|
ftype_tvm_parallel_barrier_,
|
|
- llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
|
|
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get());
|
|
}
|
|
this->InitGlobalContext(dynamic_lookup);
|
|
}
|
|
@@ -461,11 +461,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
|
|
}
|
|
std::swap(function_, fcompute);
|
|
std::swap(new_vmap, var_map_);
|
|
+ std::stack<bool*> br_ret_flg;
|
|
+ std::swap(br_ret_flg, br_ret_flg_);
|
|
BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
|
|
builder_->SetInsertPoint(compute_entry);
|
|
this->VisitStmt(op->body);
|
|
builder_->CreateRet(ConstInt32(0));
|
|
// swap the var map back, now we are back on track.
|
|
+ std::swap(br_ret_flg, br_ret_flg_);
|
|
std::swap(new_vmap, var_map_);
|
|
std::swap(function_, fcompute);
|
|
builder_->SetInsertPoint(compute_call_end);
|
|
@@ -542,9 +545,12 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
|
|
std::swap(function_, f);
|
|
std::swap(parallel_env_, par_env);
|
|
std::swap(var_map_, new_vmap);
|
|
+ std::stack<bool*> br_ret_flg;
|
|
+ std::swap(br_ret_flg, br_ret_flg_);
|
|
this->VisitStmt(body);
|
|
builder_->CreateRet(ConstInt32(0));
|
|
// swap the var map back, now we are back on track.
|
|
+ std::swap(br_ret_flg, br_ret_flg_);
|
|
std::swap(var_map_, new_vmap);
|
|
std::swap(parallel_env_, par_env);
|
|
std::swap(function_, f);
|
|
@@ -794,7 +800,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
|
|
} else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
|
|
return CreateStaticHandle();
|
|
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
|
|
- builder_->CreateRet(ConstInt32(-1));
|
|
+ llvm::Value* pRetCode = (op->args.size() == 0) ? ConstInt32(-1) : MakeValue(op->args[0]);
|
|
+ builder_->CreateRet(pRetCode);
|
|
+ CodeGenLLVM::SetRetTrFlg(true);
|
|
return ConstInt32(-1);
|
|
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
|
|
CHECK_EQ(op->args.size(), 3U);
|
|
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
|
|
index 2cff88b..e26812d 100644
|
|
--- a/src/codegen/llvm/codegen_llvm.cc
|
|
+++ b/src/codegen/llvm/codegen_llvm.cc
|
|
@@ -1110,23 +1110,36 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
       *ctx_, "if_then", function_);
   BasicBlock* end_block = BasicBlock::Create(
       *ctx_, "if_end", function_);
+  // Set when the branch being visited has already emitted a ret terminator.
+  bool cur_br_ret_flg = false;
+  br_ret_flg_.push(&cur_br_ret_flg);
   if (op->else_case.defined()) {
     BasicBlock* else_block = BasicBlock::Create(
         *ctx_, "if_else", function_);
     builder_->CreateCondBr(cond, then_block, else_block);
     builder_->SetInsertPoint(then_block);
+    cur_br_ret_flg = false;
     this->VisitStmt(op->then_case);
-    builder_->CreateBr(end_block);
+    if (!cur_br_ret_flg) {
+      builder_->CreateBr(end_block);
+    }
     builder_->SetInsertPoint(else_block);
+    cur_br_ret_flg = false;
     this->VisitStmt(op->else_case);
-    builder_->CreateBr(end_block);
+    if (!cur_br_ret_flg) {
+      builder_->CreateBr(end_block);
+    }
   } else {
     builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
     builder_->SetInsertPoint(then_block);
+    cur_br_ret_flg = false;
     this->VisitStmt(op->then_case);
-    builder_->CreateBr(end_block);
+    if (!cur_br_ret_flg) {
+      builder_->CreateBr(end_block);
+    }
   }
   builder_->SetInsertPoint(end_block);
+  br_ret_flg_.pop();
 }
 
 
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
|
|
index b7d091b..6fba863 100644
|
|
--- a/src/codegen/llvm/codegen_llvm.h
|
|
+++ b/src/codegen/llvm/codegen_llvm.h
|
|
@@ -143,6 +143,11 @@ class CodeGenLLVM :
|
|
void VisitStmt_(const Block* op) override;
|
|
void VisitStmt_(const Evaluate* op) override;
|
|
void VisitStmt_(const ProducerConsumer* op) override;
|
|
+ // Mark whether the innermost IfThenElse branch being visited has already returned.
|
|
+ void SetRetTrFlg(bool RetFlg){
|
|
+ if( !br_ret_flg_.empty() )
|
|
+ *(br_ret_flg_.top()) = RetFlg;
|
|
+ }
|
|
|
|
protected:
|
|
/*! \brief The storage information */
|
|
@@ -304,6 +309,12 @@ class CodeGenLLVM :
|
|
* initializes file and compilation_unit_ to TVM defaults.
|
|
*/
|
|
static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
|
|
+
|
|
+ /*
|
|
+ * Stack of return flags, one entry per IfThenElse statement currently being visited.
|
|
+ * If a branch has already returned, no br terminator may be added after it.
|
|
+ */
|
|
+ std::stack<bool*> br_ret_flg_;
|
|
};
|
|
} // namespace codegen
|
|
} // namespace tvm
|
|
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
|
|
index e73956c..3a7b46c 100644
|
|
--- a/src/pass/lower_tvm_builtin.cc
|
|
+++ b/src/pass/lower_tvm_builtin.cc
|
|
@@ -104,7 +104,7 @@ class BuiltinLower : public IRMutator {
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
     Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
-                                           intrinsic::tvm_throw_last_error, {},
+                                           intrinsic::tvm_throw_last_error, {make_const(Int(32), 1001)},
                                            Call::Intrinsic));
 
     Stmt body = Block::make(
@@ -117,7 +117,7 @@ class BuiltinLower : public IRMutator {
|
|
Stmt alloca = LetStmt::make(
|
|
op->buffer_var,
|
|
Call::make(op->buffer_var.type(),
|
|
- "TVMBackendAllocWorkspace",
|
|
+ TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"),
|
|
{cast(Int(32), device_type_),
|
|
cast(Int(32), device_id_),
|
|
cast(UInt(64), total_bytes),
|
|
@@ -127,7 +127,7 @@ class BuiltinLower : public IRMutator {
|
|
body);
|
|
|
|
Expr free_op = Call::make(Int(32),
|
|
- "TVMBackendFreeWorkspace",
|
|
+ TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"),
|
|
{cast(Int(32), device_type_),
|
|
cast(Int(32), device_id_),
|
|
op->buffer_var},
|