!44138 fix kernel executor compile error

Merge pull request !44138 from 王平安/kernel_executor_api
This commit is contained in:
i-robot 2022-10-20 02:27:14 +00:00 committed by Gitee
commit f696a828cf
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 12 additions and 15 deletions

View File

@ -47,10 +47,6 @@ std::unordered_map<std::string, int> ops_output_num = {
} // namespace
KernelExecutorImpl::~KernelExecutorImpl() {
if (context_ != nullptr) {
delete context_;
context_ = nullptr;
}
FreeAllResource();
inputs_.clear();
}
@ -76,13 +72,13 @@ Status KernelExecutorImpl::Build(const std::shared_ptr<ops::BaseOperator> &op, c
}
if (status != kSuccess) {
MS_LOG(ERROR) << "get kernel error.";
MS_LOG(ERROR) << "get cpu kernel error.";
FreeAllResource();
return status;
}
int ret = kernel_->Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "kernel Prepare error.";
MS_LOG(ERROR) << "cpu kernel Prepare error.";
FreeAllResource();
return static_cast<StatusCode>(ret);
}
@ -102,13 +98,13 @@ Status KernelExecutorImpl::Build(const std::shared_ptr<ops::Custom> &op, const s
InitTensors(inputs, output_num);
status = GetCustomKernel(ms_context);
if (status != kSuccess) {
MS_LOG(ERROR) << "get kernel error.";
MS_LOG(ERROR) << "get custom kernel error.";
FreeAllResource();
return status;
}
int ret = kernel_->Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "kernel Prepare error.";
MS_LOG(ERROR) << "custom kernel Prepare error.";
FreeAllResource();
return static_cast<StatusCode>(ret);
}
@ -248,16 +244,16 @@ Status KernelExecutorImpl::GetCustomKernel(const std::shared_ptr<Context> &ms_co
if (!device.provider_.empty() && !device.provider_device_.empty()) {
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_,
device.provider_device_, device.provider_};
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(),
desc, nullptr, &kernel_, primitive_);
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(
inputs_, outputs_, context_.get(), ms_context.get(), desc, nullptr, &kernel_, primitive_);
}
}
// find kernel only match arch and data_type
if (get_kernel != RET_OK) {
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_, "", ""};
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(), desc,
nullptr, &kernel_, primitive_);
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_.get(), ms_context.get(),
desc, nullptr, &kernel_, primitive_);
}
// if found kernel, do infershape
@ -276,8 +272,8 @@ Status KernelExecutorImpl::GetCpuKernel(const std::shared_ptr<Context> &ms_conte
}
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_};
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(),
desc, parameter_, &kernel_);
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_.get(),
ms_context.get(), desc, parameter_, &kernel_);
if (get_kernel == RET_OK) {
int ret = KernelInferShape(inputs_, outputs_, parameter_);
return static_cast<StatusCode>(ret);
@ -304,6 +300,7 @@ void KernelExecutorImpl::InitTensors(const std::vector<MSTensor> &inputs, const
if (output_tensor == nullptr) {
MS_LOG(ERROR) << "Failed to allocate tensor.";
}
output_tensor->set_category(lite::Category::VAR);
if (data_type_ == kNumberTypeInt8) {
Int8TensorAddQuantParam(output_tensor);
}

View File

@ -51,7 +51,7 @@ class KernelExecutorImpl {
const schema::Primitive *primitive_ = nullptr;
int prim_type_;
OpParameter *parameter_ = nullptr;
lite::InnerContext *context_ = nullptr;
std::shared_ptr<lite::InnerContext> context_ = nullptr;
TypeId data_type_;
kernel::KernelExec *kernel_ = nullptr;
std::vector<lite::Tensor *> inputs_;