!34714 kernel_executor: update the called functions.
Merge pull request !34714 from 王平安/single_op_bug
This commit is contained in:
commit
b119815f4c
|
@ -626,7 +626,7 @@ if(PLATFORM_ARM64)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/runtime/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -1005,7 +1005,7 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/runtime/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
|
|
@ -15,9 +15,9 @@ endif()
|
|||
add_library(kernel_executor SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor_impl.cc
|
||||
${TOP_DIR}/mindspore/lite/src/ops/ops_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/common/ops/ops_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/common/primitive_t_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/ops/ops_def.cc)
|
||||
${TOP_DIR}/mindspore/lite/src/common/ops/ops_def.cc)
|
||||
|
||||
add_dependencies(kernel_executor fbs_inner_src fbs_src mindspore_core)
|
||||
|
||||
|
|
|
@ -43,6 +43,22 @@ KernelExecutorImpl::~KernelExecutorImpl() {
|
|||
|
||||
Status KernelExecutorImpl::Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
|
||||
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "base operator is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "wrong inputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (outputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "wrong outputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "context is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
data_type_ = static_cast<enum TypeId>(inputs[FIRST_INPUT].DataType());
|
||||
std::unique_ptr<mindspore::schema::PrimitiveT> prim_t = lite::GetPrimitiveT(op);
|
||||
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
|
||||
|
@ -85,11 +101,23 @@ Status KernelExecutorImpl::Build(const std::shared_ptr<ops::BaseOperator> &op, c
|
|||
}
|
||||
|
||||
Status KernelExecutorImpl::ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "wrong inputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (outputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "wrong outputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
Status status = InitInOutTensor(inputs, outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "InitInOutTensor error.";
|
||||
return status;
|
||||
}
|
||||
if (kernel_ == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
kernel_->set_in_tensors(inputs_);
|
||||
kernel_->set_out_tensors(outputs_);
|
||||
int ret;
|
||||
|
@ -106,6 +134,14 @@ Status KernelExecutorImpl::ReSize(const std::vector<MSTensor> &inputs, const std
|
|||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
Status KernelExecutorImpl::Infer(std::vector<MSTensor> *outputs) {
|
||||
if (outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "outputs is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (outputs->size() != outputs_.size()) {
|
||||
MS_LOG(ERROR) << "wrong outputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
auto user_output = outputs->at(i);
|
||||
auto output = outputs_[i];
|
||||
|
@ -120,15 +156,35 @@ Status KernelExecutorImpl::Infer(std::vector<MSTensor> *outputs) {
|
|||
}
|
||||
|
||||
Status KernelExecutorImpl::Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
if (inputs.size() != inputs_.size()) {
|
||||
MS_LOG(ERROR) << "wrong inputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (outputs.size() != outputs_.size()) {
|
||||
MS_LOG(ERROR) << "wrong outputs size.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (kernel_ == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
auto user_input = inputs[i];
|
||||
auto input = inputs_[i];
|
||||
if (!TensorIsValid(user_input, input)) {
|
||||
MS_LOG(ERROR) << "inputs is invalid.";
|
||||
return kLiteError;
|
||||
}
|
||||
input->set_data(user_input.MutableData());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
auto user_output = outputs[i];
|
||||
auto output = outputs_[i];
|
||||
if (!TensorIsValid(user_output, output)) {
|
||||
MS_LOG(ERROR) << "outputs is invalid.";
|
||||
return kLiteError;
|
||||
}
|
||||
output->set_data(user_output.MutableData());
|
||||
}
|
||||
int ret = kernel_->Execute();
|
||||
|
@ -152,6 +208,7 @@ Status KernelExecutorImpl::GetOpParameter() {
|
|||
<< lite::GetPrimitiveTypeName(primitive_, schema_version_);
|
||||
return kLiteNullptr;
|
||||
}
|
||||
parameter_->thread_num_ = context_->thread_num_;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
|
@ -161,18 +218,18 @@ Status KernelExecutorImpl::GetCustomKernel(const std::shared_ptr<Context> &ms_co
|
|||
// find kernel match arch, data_type, kernel_arch and provider
|
||||
for (auto &&device : context_->device_list_) {
|
||||
if (!device.provider_.empty() && !device.provider_device_.empty()) {
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, device.provider_device_,
|
||||
device.provider_};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
nullptr, &kernel_, primitive_);
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_,
|
||||
device.provider_device_, device.provider_};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(),
|
||||
desc, nullptr, &kernel_, primitive_);
|
||||
}
|
||||
}
|
||||
|
||||
// find kernel only match arch and data_type
|
||||
if (get_kernel != RET_OK) {
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, "", ""};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
nullptr, &kernel_, primitive_);
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_, "", ""};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
nullptr, &kernel_, primitive_);
|
||||
}
|
||||
|
||||
// if found kernel, do infershape
|
||||
|
@ -190,9 +247,9 @@ Status KernelExecutorImpl::GetCpuKernel(const std::shared_ptr<Context> &ms_conte
|
|||
return status;
|
||||
}
|
||||
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_};
|
||||
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
parameter_, &kernel_);
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, NHWC, prim_type_};
|
||||
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernelExec(inputs_, outputs_, context_, ms_context.get(),
|
||||
desc, parameter_, &kernel_);
|
||||
if (get_kernel == RET_OK) {
|
||||
int ret = KernelInferShape(inputs_, outputs_, parameter_);
|
||||
return static_cast<StatusCode>(ret);
|
||||
|
@ -250,4 +307,24 @@ Status KernelExecutorImpl::InitInOutTensor(const std::vector<MSTensor> &inputs,
|
|||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
// Checks whether a user-supplied MSTensor is compatible with an internal lite tensor.
//
// Parameters:
//   ms_tensor   - tensor passed in by the caller of the executor API.
//   lite_tensor - internal tensor it is compared against; may be nullptr.
//
// Returns true only when the data type, the format, the rank and every shape
// dimension of both tensors agree. A null lite_tensor is treated as a mismatch
// (false) instead of being dereferenced.
bool KernelExecutorImpl::TensorIsValid(const MSTensor &ms_tensor, const lite::Tensor *lite_tensor) {
  if (lite_tensor == nullptr) {
    MS_LOG(ERROR) << "lite tensor is nullptr.";
    return false;
  }
  if (static_cast<enum TypeId>(ms_tensor.DataType()) != lite_tensor->data_type()) {
    return false;
  }
  if (ms_tensor.format() != lite_tensor->format()) {
    return false;
  }
  // Shapes must have the same rank before the per-dimension comparison below.
  auto ms_tensor_shape = ms_tensor.Shape();
  auto lite_tensor_shape = lite_tensor->shape();
  if (ms_tensor_shape.size() != lite_tensor_shape.size()) {
    return false;
  }
  for (size_t i = 0; i < ms_tensor_shape.size(); i++) {
    if (ms_tensor_shape[i] != lite_tensor_shape[i]) {
      return false;
    }
  }
  return true;
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,6 +40,7 @@ class KernelExecutorImpl {
|
|||
Status GetOpParameter();
|
||||
Status InitInOutTensor(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
void FreeInOutTensor();
|
||||
bool TensorIsValid(const MSTensor &ms_tensor, const lite::Tensor *lite_tensor);
|
||||
|
||||
private:
|
||||
const schema::Primitive *primitive_ = nullptr;
|
||||
|
|
Loading…
Reference in New Issue