forked from mindspore-Ecosystem/mindspore
!15271 [MS][LITE]fix memory leak and add models to the entrance guard
From: @probiotics_53
Reviewed-by:
Signed-off-by:
commit a22b89ef89

@@ -16,6 +16,7 @@
#include "nnacl/infer/tensorlist_reserve_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/tensorlist_parameter.h"

int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) {

@@ -26,6 +27,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
}
#endif

TensorListParameter *reserve_param = (TensorListParameter *)parameter;
const TensorC *input0 = inputs[0];
int ele_shape_type = input0->data_type_;
if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) {

@@ -35,6 +37,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
TensorListC *output = (TensorListC *)(outputs[0]);
output->data_type_ = kObjectTypeTensorType;
output->format_ = Format_NHWC;
output->tensors_data_type_ = reserve_param->element_dtype_;

if (input0->data_ == NULL) {
return NNACL_INFER_INVALID;

@@ -39,6 +39,9 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
TensorC *output = outputs[0];

SetDataTypeFormat(output, input);
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
output->data_type_ = kNumberTypeFloat32;
}
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}

@@ -179,31 +179,35 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) {
return kernel_creator != nullptr;
}

kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const InnerContext *ctx,
const kernel::KernelKey &key, OpParameter *parameter,
const void *primitive) {
int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter,
kernel::LiteKernel **kernel, const void *primitive) {
MS_ASSERT(ctx != nullptr);
MS_ASSERT(kernel != nullptr);
if (key.vendor == kBuiltin) {
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (kernel != nullptr) {
kernel->set_desc(key);
return kernel;
*kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (*kernel != nullptr) {
(*kernel)->set_desc(key);
return RET_OK;
}
return RET_ERROR;
}
} else {
auto creator = GetDelegateCreator(key);
if (creator == nullptr) {
return nullptr;
if (creator != nullptr) {
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
if (*kernel != nullptr) {
return RET_OK;
}
return RET_ERROR;
}
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
}
return nullptr;
return RET_NOT_SUPPORT;
}
} // namespace mindspore::lite

@@ -48,9 +48,9 @@ class KernelRegistry {
kernel::CreateKernel creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
bool SupportKernel(const kernel::KernelKey &key);
kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
const void *primitive = nullptr);
int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
kernel::LiteKernel **kernel, const void *primitive = nullptr);

protected:
static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};

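Note on the kernel_registry change above: GetKernel no longer returns a raw kernel::LiteKernel pointer. It now reports a status code (RET_OK / RET_ERROR / RET_NOT_SUPPORT) and hands the created kernel back through an out-parameter, so callers can tell "creation failed" apart from "not supported" without juggling a possibly leaked pointer. A minimal caller-side sketch of the new contract follows; the variable names are illustrative placeholders, not code taken verbatim from the patch.

// Sketch only: how a caller is expected to use the reworked registry API.
// RET_OK / RET_NOT_SUPPORT and the surrounding types are the existing
// MindSpore Lite runtime definitions; desc, op_parameter, etc. stand for
// whatever the caller already has in scope.
kernel::LiteKernel *kernel = nullptr;
int ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, desc, op_parameter, &kernel);
if (ret != RET_OK) {
  // Nothing was handed back, so there is no half-constructed kernel to free.
  return ret;
}
// On RET_OK the caller takes ownership; for built-in kernels set_desc(key)
// has already been called inside GetKernel.
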
@@ -151,6 +151,12 @@ lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) {
lite::Tensor *dst_tensor = nullptr;
if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) {
dst_tensor = new (std::nothrow) TensorList(shape, std::vector<int>(), src_category);
// set tensor list datatype
auto tensor_list = reinterpret_cast<TensorList *>(dst_tensor);
if (src_tensor.data() != nullptr) {
auto tensor_data_type = TypeId(reinterpret_cast<const int *>(src_tensor.data()->data())[0]);
tensor_list->set_tensors_data_type(tensor_data_type);
}
} else {
dst_tensor = new (std::nothrow) Tensor(TypeId(src_tensor.dataType()), shape, src_tensor.format(), src_category);
}

@@ -425,6 +425,7 @@ int ArithmeticCPUKernel::Run() {

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SubFusion, LiteKernelCreator<ArithmeticCPUKernel>)

@@ -340,26 +340,26 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
}
} // namespace

kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc, TypeId kernel_data_type) {
int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);
auto op_type = op_parameter->type_;
if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
return nullptr;
return RET_NOT_SUPPORT;
}
kernel::KernelKey cpu_desc = desc;
if (kernel_data_type == kNumberTypeFloat16) {
if (!context_->IsCpuFloat16Enabled() ||
(cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
return nullptr;
return RET_NOT_SUPPORT;
}
cpu_desc.data_type = kNumberTypeFloat16;
}
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
std::map<Tensor *, Tensor *> restored_origin_tensors;

@@ -367,28 +367,27 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
FreeRestoreTensors(&restored_origin_tensors);
} else {
RestoreTensorData(&restored_origin_tensors);
}
return kernel;
return ret;
} // namespace mindspore::lite

kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc) {
int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);

if (context_->IsGpuEnabled()) {

@@ -402,30 +401,27 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}

// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_parameter->type_);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
} else {
MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
}
return kernel;
} else {
return nullptr;
return ret;
}
return RET_NOT_SUPPORT;
}

kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc) {
int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
if (context_->IsNpuEnabled()) {

@@ -435,23 +431,22 @@ kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_ten
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
for (auto tensor : in_tensors) {
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
}
return kernel;
} else {
return nullptr;
return ret;
}
return RET_NOT_SUPPORT;
}

kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,

@@ -469,49 +464,16 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
bool infer_shape_interrupt = !op_parameter->infer_flag_;
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
kernel::LiteKernel *kernel = nullptr;
int status;
#ifdef SUPPORT_GPU
// if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
kernel = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
// }
#endif
#ifdef SUPPORT_NPU
// if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
kernel = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
// }
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16);
if (kernel != nullptr) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];

@@ -521,15 +483,55 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
}
}
}
// }
#endif
#ifdef SUPPORT_NPU
// if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
status = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
}
// }
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
}
}
if (data_type == kNumberTypeFloat16) {
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32;
}
if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) {
kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32);
if (kernel != nullptr) {
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
if (status == RET_OK) {
return kernel;
} else {
} else if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;

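The scheduler hunks above move FindCpuKernel / FindGpuKernel / FindNpuKernel to the same status-code convention, which lets FindBackendKernel distinguish a real creation error (RET_ERROR, worth re-running shape inference) from a plain RET_NOT_SUPPORT before falling back to the next backend. Below is a condensed sketch of the resulting fallback order with the error-recovery and logging paths omitted; it illustrates the control flow only and is not the verbatim implementation.

// GPU -> NPU -> fp16 CPU -> fp32 CPU, stopping at the first backend that
// reports RET_OK. The real code also re-runs InferNodeShape and refreshes
// op_parameter when a backend returns RET_ERROR, and guards the fp32 path
// with prefer_data_type as in the hunk above.
kernel::LiteKernel *kernel = nullptr;
int status;
#ifdef SUPPORT_GPU
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
  return kernel;
}
#endif
#ifdef SUPPORT_NPU
status = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
  return kernel;
}
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
  status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
  if (status == RET_OK) {
    return kernel;
  }
}
desc.data_type = kNumberTypeFloat32;
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
return status == RET_OK ? kernel : nullptr;
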
@@ -60,12 +60,13 @@ class Scheduler {
kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
TypeId prefer_data_type = kTypeUnknown);
kernel::LiteKernel *FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type);
kernel::LiteKernel *FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc);
kernel::LiteKernel *FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc);
int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel);
int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
// schedule a partial node to a subgraph_kernel
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
// schedule a node to a kernel

@@ -95,7 +95,7 @@ ml_video_edit_img_segment 3
ml_video_edit_video_segment_gauss_adaptis_part1 5
# When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
ml_handpose 175
hdc_Face_Aesthetic_MTI_Aesthetic 22
hdc_Face_Aesthetic_MTI_Aesthetic 0.5
ml_face_compare 5.5
ml_face_tracking 2.5
ml_face_beard 0.5

@@ -73,7 +73,7 @@ mtk_face_recognition_v3.onnx
mtk_face_recognition_v2.onnx
ml_2012_ocr_detection_tmp.onnx
ml_video_edit_enhance_update_tmp.onnx
#Harmony_Voiceprint_resnet18.onnx
Harmony_Voiceprint_resnet18.onnx;1,150,40,1
bloom_hongmo_detection_tmp.onnx
Q_face_recognition.onnx
Q888_face_recognition.onnx

@@ -79,7 +79,7 @@ mtk_face_features_v2.onnx;1,256,192,3 0.5
mtk_face_recognition_v3.onnx 0.5
mtk_face_recognition_v2.onnx 2.5
ml_2012_ocr_detection_tmp.onnx 0.5
#Harmony_Voiceprint_resnet18.onnx;1,1,200,40 4.5
Harmony_Voiceprint_resnet18.onnx;1,150,40,1 4.5
bloom_hongmo_detection_tmp.onnx 0.5
Q_face_recognition.onnx 2
ml_video_edit_enhance_update_tmp.onnx 0.5

@@ -36,7 +36,6 @@ mnasnet_1.3_224.tflite
inception_v3.tflite
deeplabv3_257_mv_gpu.tflite
multi_person_mobilenet_v1_075_float.tflite
#hiai_vad.tflite
ide_label_base.tflite
ide_label_retrained.tflite
ml_ei_headpose.tflite

@@ -164,8 +163,6 @@ hiai_detectmodel_desnet_256_128_64_32.tflite
lite-model_aiy_vision_classifier_food_V1_1.tflite
lite-model_disease-classification_1.tflite
lite-model_models_mushroom-identification_v1_1.tflite
#lite-model_albert_lite_base_squadv1_metadata_1.tflite
#lite-model_mobilebert_1_metadata_1.tflite
smartreply_1_default_1.tflite
text_classification.tflite
Q_detect_fpn_add_inception-1448650.tflite

@@ -183,3 +180,8 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
Q888_face_emo_dress_mv3_orderd.tflite
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
bloom_new_detect.tflite
bloom_model_age_gender.tflite
bloom_isface.tflite
hiai_object_detect_814.tflite
hiai_object_tflite_graph_8bit.tflite

@@ -209,3 +209,9 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
Q888_face_emo_dress_mv3_orderd.tflite 2.5
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
bloom_new_detect.tflite 3.5
bloom_model_age_gender.tflite 0.5
bloom_isface.tflite 0.5
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will have a big bias, and the accumulation of this bias lowers the final precision.
hiai_object_detect_814.tflite 14

@@ -8,6 +8,7 @@ ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
decoder.onnx;2;1,7,512:1,7
#fasterrcnn_crop.pb is the same model as gts_object_detect_Ics.pb.
fasterrcnn_crop.pb;1;420,630,3
ml_video_edit_person_divison_video;2
hdc_tb_cn_neg.tflite;3

@@ -31,4 +32,11 @@ add_uint8.tflite;2
ml_Heatmap_depth_240180;2
ml_Heatmap_depth_180240;2
hiai_nlu_model.pb;3;1,16:1,16:1,16
gts_object_detect_lcs.pb;1;420,630,3
#calib data file in server incorrect
#gts_object_detect_Ics.pb;1;420,630,3
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
hiai_transformer_encoder.pb;15
lite-model_albert_lite_base_squadv1_metadata_1.tflite;3
lite-model_mobilebert_1_metadata_1.tflite;3
hiai_vad.tflite;2

@@ -26,3 +26,6 @@ ml_tts_vocoder.pb;66 53
# The outputs of the two Heatmap_depth models have small values
ml_Heatmap_depth_240180;2 10 16
ml_Heatmap_depth_180240;2 7 7
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 1
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 0.5
hiai_transformer_encoder.pb;15 4

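The model-list hunks above are the "entrance guard" additions from the commit title: every newly listed model must now pass the converter/benchmark CI checks before related changes can merge. Judging from the existing entries, each line names the model file, optionally followed by the input count and input shapes separated by semicolons, and an optional trailing accuracy tolerance, e.g. "Harmony_Voiceprint_resnet18.onnx;1,150,40,1 4.5" or "ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 1". The exact field semantics are an inference from the surrounding entries, not something this patch states.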
@@ -63,7 +63,7 @@ void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, st
if (!tensorT->data.empty()) {
int *data = reinterpret_cast<int *>(tensorT->data.data());
type = TypeId(data[0]);
if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size())) {
if (tensorT->data.size() < 8 || (data[1] != 0 && (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size()))) {
MS_LOG(ERROR) << "tensorlist data length illegal";
*convert_succ = false;
return;

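The converter hunk above relaxes the TensorList length check so that a serialized TensorList whose element shape is empty (data[1] == 0) is no longer rejected. Read literally, the int32 buffer appears to be laid out as [element data type, element shape size, element shape values...]. The following standalone sketch spells out the validity condition implied by the check; it is an interpretation of that condition, not code from the patch.

#include <cstdint>
#include <vector>

// Valid when the buffer holds at least dtype + shape-size (8 bytes) and the
// declared element shape either is empty or matches the buffer length:
// total bytes == (shape_size + 2) * sizeof(int32_t).
bool TensorListDataLengthValid(const std::vector<uint8_t> &buf) {
  if (buf.size() < 8) {
    return false;
  }
  const int32_t *data = reinterpret_cast<const int32_t *>(buf.data());
  const int32_t shape_size = data[1];
  return shape_size == 0 || (shape_size + 2) * 4 == static_cast<int32_t>(buf.size());
}
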
@@ -229,29 +229,36 @@ void SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
output_tensor->dataType = output_tensors[i]->data_type();
if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
if (tensor_list->tensors_data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_infer_ = false;
if (output_tensor->data.empty()) {
output_tensor->data.resize(8, 0);
}
if (tensor_list->tensors_data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_inferred_ = false;
return;
}
output_tensor->data.at(0) = tensor_list->tensors_data_type();
} else if (output_tensors[i]->data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_inferred_ = false;
return;
}
tensors_->at(node->outputIndex[i]).is_inferred_ = true;
return;
}
} // namespace

STATUS InferShapePass::Run(MetaGraphT *graph) {
graph_ = graph;
InitSearchTensor(graph);
MS_ASSERT(graph != nullptr);
for (auto idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[idx].get();
InitSearchTensor(graph);
for (auto input_idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[input_idx].get();
for (auto &dim : input_tensor->dims) {
if (dim == 0) {
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
dim = DEFAULT_DIM_VALUE;
}
}
}
for (auto g_input_idx : graph->inputIndex) {
auto g_input_shape = graph->allTensors.at(g_input_idx)->dims;
if (std::find(g_input_shape.begin(), g_input_shape.end(), -1) != g_input_shape.end() || fmk_type_ == FmkType_TF) {
auto input_shape = graph->allTensors.at(input_idx)->dims;
if (std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() || fmk_type_ == FmkType_TF) {
infer_interrupt_ = true;
}
}

@@ -286,11 +293,11 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
auto output_dims = output_tensors[i]->shape();
auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
output_tensor->dims.swap(output_dims);
SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
}
} else if (status == RET_INFER_INVALID) {
for (size_t i = 0; i < output_tensors.size(); i++) {
SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
}
infer_interrupt_ = true;
} else {

@@ -300,7 +307,7 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
return RET_INFER_ERR;
}
FreeTensors(&input_tensors, &output_tensors);
AddOutputNode(infer_node_index);
AddOutputNodes(graph, infer_node_index);
}
return RET_OK;
}

@@ -313,82 +320,60 @@ void InferShapePass::InitSearchTensor(MetaGraphT *graph) {
auto node_input_indexes = node->inputIndex;
// init in_nodes index
for (size_t j = 0; j < node_input_indexes.size(); j++) {
tensors_[node_input_indexes[j]].in_nodes_.push_back(i);
tensors_[node_input_indexes[j]].next_nodes_.push_back(i);
}
auto node_output_indexes = node->outputIndex;
for (size_t j = 0; j < node_output_indexes.size(); j++) {
tensors_[node_output_indexes[j]].out_nodes_.push_back(i);
all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
node_output_indexes.end());
tensors_[node_output_indexes[j]].prev_nodes_.push_back(i);
}
all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
node_output_indexes.end());
}
for (uint32_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].prev_nodes_.empty() || IsContain(graph->inputIndex, i) || !graph->allTensors.at(i)->data.empty()) {
tensors_[i].is_inferred_ = true;
}
}
for (size_t i = 0; i < graph->nodes.size(); i++) {
auto &node = graph->nodes[i];
auto &node = graph->nodes.at(i);
if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
[&](uint32_t index) { return !IsContain(all_node_output_tensor_indexes, index); })) {
[&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
infer_node_indexes_.push_back(i);
}
}
for (size_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].out_nodes_.empty()) {
tensors_[i].is_infer_ = true;
}
}
}

void InferShapePass::AddOutputNode(uint32_t infer_node_index) {
auto &node = graph_->nodes[infer_node_index];
void InferShapePass::AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index) {
auto &node = graph->nodes.at(infer_node_index);
for (size_t i = 0; i < node->outputIndex.size(); i++) {
auto output_tensor_node_indexes = tensors_[node->outputIndex[i]].in_nodes_;
tensors_[node->outputIndex[i]].is_infer_ = true;
for (size_t j = 0; j < output_tensor_node_indexes.size(); j++) {
bool flag = false;
auto &output_tensor_node = graph_->nodes[output_tensor_node_indexes[j]];
for (size_t k = 0; k < output_tensor_node->outputIndex.size(); k++) {
if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType != kObjectTypeTensorType) {
if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType == kTypeUnknown ||
tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
flag = true;
break;
}
} else {
if (tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
flag = true;
break;
}
}
}
if (flag) {
AddNextInferShapeNode(output_tensor_node_indexes, j);
auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
[&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
AddNextInferShapeNode(graph, next_nodes_indexes, j);
}
}
}
}

void InferShapePass::AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index) {
auto &output_tensor_node = graph_->nodes.at(output_tensor_node_indexes[index]);
if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), output_tensor_node_indexes[index]) ==
void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index) {
auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), next_nodes_indexes[index]) ==
infer_node_indexes_.end()) {
auto output_tensor_node_type = output_tensor_node->primitive->value.type;
if (output_tensor_node_type == schema::PrimitiveType_Merge) {
if (std::all_of(output_tensor_node->inputIndex.begin(),
output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
[&](uint32_t k) { return tensors_[k].is_infer_; }) ||
std::all_of(output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
output_tensor_node->inputIndex.end(), [&](uint32_t k) { return tensors_[k].is_infer_; })) {
infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
}
} else {
bool flag = true;
for (size_t i = 0; i < output_tensor_node->inputIndex.size(); i++) {
if (!(tensors_[output_tensor_node->inputIndex[i]].is_infer_)) {
flag = false;
break;
}
}
if (flag) {
infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
auto next_node_type = next_node->primitive->value.type;
if (next_node_type == schema::PrimitiveType_Merge) {
if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.begin() + next_node->inputIndex.size() / 2,
[&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
std::all_of(next_node->inputIndex.begin() + next_node->inputIndex.size() / 2, next_node->inputIndex.end(),
[&](uint32_t i) { return tensors_[i].is_inferred_; })) {
infer_node_indexes_.push_back(next_nodes_indexes[index]);
}
} else if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
[&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
std::any_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
[&](uint32_t i) { return graph->allTensors.at(i)->dataType == kObjectTypeTensorType; })) {
infer_node_indexes_.push_back(next_nodes_indexes[index]);
}
}
}

@@ -32,9 +32,9 @@ namespace mindspore {
namespace lite {

struct InferTensor {
std::vector<uint32_t> in_nodes_; /* used current tensor as input */
std::vector<uint32_t> out_nodes_;
bool is_infer_;
std::vector<uint32_t> next_nodes_;
std::vector<uint32_t> prev_nodes_;
bool is_inferred_;
};

class InferShapePass : public GraphPass {

@@ -45,11 +45,10 @@ class InferShapePass : public GraphPass {

private:
void InitSearchTensor(MetaGraphT *graph);
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);
void AddOutputNode(uint32_t infer_node_index);
void AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index);
void AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index);

lite::converter::FmkType fmk_type_ = FmkType_TF;
MetaGraphT *graph_ = nullptr;
std::vector<InferTensor> tensors_ = {};
std::vector<uint32_t> infer_node_indexes_ = {};
bool infer_interrupt_ = false;

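In the InferTensor struct above, the fields are renamed from in_nodes_ / out_nodes_ / is_infer_ to next_nodes_ (nodes that consume the tensor), prev_nodes_ (nodes that produce it) and is_inferred_, and the pass now seeds and propagates readiness through is_inferred_ instead of through the set of all node-output tensor indexes. The core readiness test, condensed from InitSearchTensor and AddNextInferShapeNode into a small standalone helper, is shown below; this is illustrative only, since the real pass also special-cases Merge nodes and TensorList inputs.

#include <algorithm>
#include <cstdint>
#include <vector>

// A node can be scheduled for shape inference once every tensor it reads has
// been marked is_inferred_ (graph inputs, constant tensors and producer-less
// tensors are pre-marked in InitSearchTensor).
bool NodeReadyForInferShape(const std::vector<uint32_t> &node_input_index,
                            const std::vector<InferTensor> &tensors) {
  return std::all_of(node_input_index.begin(), node_input_index.end(),
                     [&](uint32_t idx) { return tensors[idx].is_inferred_; });
}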