mindspore/third_party/patch/predict/0001-RetBugFix-CustomRuntim...

204 lines
8.9 KiB
Diff

diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
index d668984..3676a61 100644
--- a/include/tvm/runtime/registry.h
+++ b/include/tvm/runtime/registry.h
@@ -319,6 +319,19 @@ class Registry {
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>()
+/* Translate a TVM runtime symbol name into its custom-runtime name. If the global
+ * "codegen.GetTransRTFunc" is registered, yields the result of calling it with OrigFuncStr, read as a
+ * const char*; otherwise yields OrigFuncStr. NOTE(review): the `({ ... })` form looks like the GNU statement-expression extension - confirm target compilers. */
+#define TVM_RT_FUNC_TRANS(OrigFuncStr) ({ \
+ const runtime::PackedFunc* trans_func = runtime::Registry::Get("codegen.GetTransRTFunc");\
+ const char* dst_func_str = nullptr; \
+ if( trans_func != nullptr){ \
+ dst_func_str = ((*trans_func)(OrigFuncStr)).ptr<const char>(); \
+ }else{ \
+ dst_func_str = OrigFuncStr; \
+ } \
+ dst_func_str; \
+})
} // namespace runtime
} // namespace tvm
diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
index 0ba0c58..2850ad4 100644
--- a/src/codegen/llvm/codegen_cpu.cc
+++ b/src/codegen/llvm/codegen_cpu.cc
@@ -99,26 +99,26 @@ void CodeGenCPU::Init(const std::string& module_name,
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
- llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib) {
f_tvm_func_call_ = llvm::Function::Create(
ftype_tvm_func_call_,
- llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get());  // NOTE(review): "TVMBackendGetFuncFromEnv" just below is left untranslated - confirm intended
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_get_func_from_env_,
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
- llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get());
f_tvm_parallel_launch_ = llvm::Function::Create(
ftype_tvm_parallel_launch_,
- llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get());
f_tvm_parallel_barrier_ = llvm::Function::Create(
ftype_tvm_parallel_barrier_,
- llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
+ llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get());
}
this->InitGlobalContext(dynamic_lookup);
}
@@ -461,11 +461,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
}
std::swap(function_, fcompute);
std::swap(new_vmap, var_map_);
+ std::stack<bool*> br_ret_flg;  // starts empty
+ std::swap(br_ret_flg, br_ret_flg_);  // the compute function body is visited with its own, empty flag stack
BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
+ std::swap(br_ret_flg, br_ret_flg_);  // put the previous flag stack back
std::swap(new_vmap, var_map_);
std::swap(function_, fcompute);
builder_->SetInsertPoint(compute_call_end);
@@ -542,9 +545,12 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
std::swap(function_, f);
std::swap(parallel_env_, par_env);
std::swap(var_map_, new_vmap);
+ std::stack<bool*> br_ret_flg;  // starts empty
+ std::swap(br_ret_flg, br_ret_flg_);  // the parallel body is visited with its own, empty flag stack
this->VisitStmt(body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
+ std::swap(br_ret_flg, br_ret_flg_);  // put the previous flag stack back
std::swap(var_map_, new_vmap);
std::swap(parallel_env_, par_env);
std::swap(function_, f);
@@ -794,7 +800,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
return CreateStaticHandle();
} else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
- builder_->CreateRet(ConstInt32(-1));
+ llvm::Value* pRetCode = (op->args.size() == 0) ? ConstInt32(-1) : MakeValue(op->args[0]);  // return args[0] when supplied, else -1
+ builder_->CreateRet(pRetCode);
+ CodeGenLLVM::SetRetTrFlg(true);  // record that a ret was just emitted so no br is appended after it
return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 2cff88b..e26812d 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -1110,23 +1110,36 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
*ctx_, "if_then", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
+ // Flag set through SetRetTrFlg when the branch being visited has already emitted a ret terminator.
+ bool cur_br_ret_flg = false;
+ br_ret_flg_.push(&cur_br_ret_flg);
if (op->else_case.defined()) {
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
+ cur_br_ret_flg = false;
this->VisitStmt(op->then_case);
- builder_->CreateBr(end_block);
+ if ( !cur_br_ret_flg ){  // only branch to if_end when the then-branch did not return
+ builder_->CreateBr(end_block);
+ }
builder_->SetInsertPoint(else_block);
+ cur_br_ret_flg = false;
this->VisitStmt(op->else_case);
- builder_->CreateBr(end_block);
+ if ( !cur_br_ret_flg ){
+ builder_->CreateBr(end_block);
+ }
} else {
builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
builder_->SetInsertPoint(then_block);
+ cur_br_ret_flg = false;
this->VisitStmt(op->then_case);
- builder_->CreateBr(end_block);
+ if ( !cur_br_ret_flg ){
+ builder_->CreateBr(end_block);
+ }
}
builder_->SetInsertPoint(end_block);
+ br_ret_flg_.pop();
}
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
index b7d091b..6fba863 100644
--- a/src/codegen/llvm/codegen_llvm.h
+++ b/src/codegen/llvm/codegen_llvm.h
@@ -143,6 +143,11 @@ class CodeGenLLVM :
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
+ // Write RetFlg into the flag on top of br_ret_flg_ (the innermost IfThenElse branch being generated); no-op when the stack is empty.
+ void SetRetTrFlg(bool RetFlg){
+ if( !br_ret_flg_.empty() )
+ *(br_ret_flg_.top()) = RetFlg;
+ }
protected:
/*! \brief The storage information */
@@ -304,6 +309,12 @@ class CodeGenLLVM :
* initializes file and compilation_unit_ to TVM defaults.
*/
static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
+
+ /*
+ * Stack of pointers to per-IfThenElse "branch already returned" flags.
+ * When the top flag is set, no br terminator is appended to that branch.
+ */
+ std::stack<bool*> br_ret_flg_;
};
} // namespace codegen
} // namespace tvm
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index e73956c..3a7b46c 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -104,7 +104,7 @@ class BuiltinLower : public IRMutator {
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
- intrinsic::tvm_throw_last_error, {},
+ intrinsic::tvm_throw_last_error, {make_const(Int(32), 1001)},  // explicit Int(32) constant; the CPU codegen returns args[0] as the error code
Call::Intrinsic));
Stmt body = Block::make(
@@ -117,7 +117,7 @@ class BuiltinLower : public IRMutator {
Stmt alloca = LetStmt::make(
op->buffer_var,
Call::make(op->buffer_var.type(),
- "TVMBackendAllocWorkspace",
+ TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"),  // name is remapped when "codegen.GetTransRTFunc" is registered
{cast(Int(32), device_type_),
cast(Int(32), device_id_),
cast(UInt(64), total_bytes),
@@ -127,7 +127,7 @@ class BuiltinLower : public IRMutator {
body);
Expr free_op = Call::make(Int(32),
- "TVMBackendFreeWorkspace",
+ TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"),  // name is remapped when "codegen.GetTransRTFunc" is registered
{cast(Int(32), device_type_),
cast(Int(32), device_id_),
op->buffer_var},