!15271 [MS][LITE]fix memory leak and add models to the entrance guard

From: @probiotics_53
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-29 09:03:42 +08:00 committed by Gitee
commit a22b89ef89
17 changed files with 202 additions and 179 deletions

View File

@ -16,6 +16,7 @@
#include "nnacl/infer/tensorlist_reserve_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/tensorlist_parameter.h"
int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) {
@ -26,6 +27,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
}
#endif
TensorListParameter *reserve_param = (TensorListParameter *)parameter;
const TensorC *input0 = inputs[0];
int ele_shape_type = input0->data_type_;
if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) {
@ -35,6 +37,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
TensorListC *output = (TensorListC *)(outputs[0]);
output->data_type_ = kObjectTypeTensorType;
output->format_ = Format_NHWC;
output->tensors_data_type_ = reserve_param->element_dtype_;
if (input0->data_ == NULL) {
return NNACL_INFER_INVALID;

View File

@ -39,6 +39,9 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
TensorC *output = outputs[0];
SetDataTypeFormat(output, input);
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
output->data_type_ = kNumberTypeFloat32;
}
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}

View File

@ -179,31 +179,35 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) {
return kernel_creator != nullptr;
}
kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const InnerContext *ctx,
const kernel::KernelKey &key, OpParameter *parameter,
const void *primitive) {
int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter,
kernel::LiteKernel **kernel, const void *primitive) {
MS_ASSERT(ctx != nullptr);
MS_ASSERT(kernel != nullptr);
if (key.vendor == kBuiltin) {
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (kernel != nullptr) {
kernel->set_desc(key);
return kernel;
*kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
if (*kernel != nullptr) {
(*kernel)->set_desc(key);
return RET_OK;
}
return RET_ERROR;
}
} else {
auto creator = GetDelegateCreator(key);
if (creator == nullptr) {
return nullptr;
if (creator != nullptr) {
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
*kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
if (*kernel != nullptr) {
return RET_OK;
}
return RET_ERROR;
}
std::vector<tensor::MSTensor *> tensors_in;
Tensor2MSTensor(std::move(in_tensors), &tensors_in);
std::vector<tensor::MSTensor *> tensors_out;
Tensor2MSTensor(std::move(out_tensors), &tensors_out);
return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
}
return nullptr;
return RET_NOT_SUPPORT;
}
} // namespace mindspore::lite

View File

@ -48,9 +48,9 @@ class KernelRegistry {
kernel::CreateKernel creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
bool SupportKernel(const kernel::KernelKey &key);
kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
const void *primitive = nullptr);
int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
kernel::LiteKernel **kernel, const void *primitive = nullptr);
protected:
static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};

View File

@ -151,6 +151,12 @@ lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) {
lite::Tensor *dst_tensor = nullptr;
if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) {
dst_tensor = new (std::nothrow) TensorList(shape, std::vector<int>(), src_category);
// set tensor list datatype
auto tensor_list = reinterpret_cast<TensorList *>(dst_tensor);
if (src_tensor.data() != nullptr) {
auto tensor_data_type = TypeId(reinterpret_cast<const int *>(src_tensor.data()->data())[0]);
tensor_list->set_tensors_data_type(tensor_data_type);
}
} else {
dst_tensor = new (std::nothrow) Tensor(TypeId(src_tensor.dataType()), shape, src_tensor.format(), src_category);
}

View File

@ -425,6 +425,7 @@ int ArithmeticCPUKernel::Run() {
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SubFusion, LiteKernelCreator<ArithmeticCPUKernel>)

View File

@ -340,26 +340,26 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
}
} // namespace
kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc, TypeId kernel_data_type) {
int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);
auto op_type = op_parameter->type_;
if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
return nullptr;
return RET_NOT_SUPPORT;
}
kernel::KernelKey cpu_desc = desc;
if (kernel_data_type == kNumberTypeFloat16) {
if (!context_->IsCpuFloat16Enabled() ||
(cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
return nullptr;
return RET_NOT_SUPPORT;
}
cpu_desc.data_type = kNumberTypeFloat16;
}
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
std::map<Tensor *, Tensor *> restored_origin_tensors;
@ -367,28 +367,27 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
FreeRestoreTensors(&restored_origin_tensors);
} else {
RestoreTensorData(&restored_origin_tensors);
}
return kernel;
return ret;
} // namespace mindspore::lite
kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc) {
int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);
if (context_->IsGpuEnabled()) {
@ -402,30 +401,27 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_parameter->type_);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
} else {
MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
}
return kernel;
} else {
return nullptr;
return ret;
}
return RET_NOT_SUPPORT;
}
kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
const kernel::KernelKey &desc) {
int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
MS_ASSERT(op_parameter != nullptr);
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
if (context_->IsNpuEnabled()) {
@ -435,23 +431,22 @@ kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_ten
auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return nullptr;
return RET_NOT_SUPPORT;
}
for (auto tensor : in_tensors) {
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
if (kernel != nullptr) {
ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter, kernel);
if (ret == RET_OK) {
MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
}
return kernel;
} else {
return nullptr;
return ret;
}
return RET_NOT_SUPPORT;
}
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
@ -469,49 +464,16 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
bool infer_shape_interrupt = !op_parameter->infer_flag_;
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
kernel::LiteKernel *kernel = nullptr;
int status;
#ifdef SUPPORT_GPU
// if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
kernel = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
// }
#endif
#ifdef SUPPORT_NPU
// if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
kernel = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc);
if (kernel != nullptr) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
// }
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16);
if (kernel != nullptr) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
@ -521,15 +483,55 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
}
}
}
// }
#endif
#ifdef SUPPORT_NPU
// if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
status = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
}
// }
#endif
if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
if (status == RET_OK) {
return kernel;
} else {
MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
<< node->name_;
if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (ret == RET_INFER_INVALID || ret == RET_OK) {
op_parameter = op_parameters_[node->output_indices_.at(0)];
} else {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
return nullptr;
}
}
}
}
if (data_type == kNumberTypeFloat16) {
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32;
}
if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) {
kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32);
if (kernel != nullptr) {
status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
if (status == RET_OK) {
return kernel;
} else {
} else if (status == RET_ERROR) {
auto ret = InferNodeShape(node, &infer_shape_interrupt);
if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;

View File

@ -60,12 +60,13 @@ class Scheduler {
kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const Model::Node *node,
TypeId prefer_data_type = kTypeUnknown);
kernel::LiteKernel *FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type);
kernel::LiteKernel *FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc);
kernel::LiteKernel *FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc);
int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel);
int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
// schedule a partial node to a subgraph_kernel
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
// schedule a node to a kernel

View File

@ -95,7 +95,7 @@ ml_video_edit_img_segment 3
ml_video_edit_video_segment_gauss_adaptis_part1 5
# When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
ml_handpose 175
hdc_Face_Aesthetic_MTI_Aesthetic 22
hdc_Face_Aesthetic_MTI_Aesthetic 0.5
ml_face_compare 5.5
ml_face_tracking 2.5
ml_face_beard 0.5

View File

@ -73,7 +73,7 @@ mtk_face_recognition_v3.onnx
mtk_face_recognition_v2.onnx
ml_2012_ocr_detection_tmp.onnx
ml_video_edit_enhance_update_tmp.onnx
#Harmony_Voiceprint_resnet18.onnx
Harmony_Voiceprint_resnet18.onnx;1,150,40,1
bloom_hongmo_detection_tmp.onnx
Q_face_recognition.onnx
Q888_face_recognition.onnx

View File

@ -79,7 +79,7 @@ mtk_face_features_v2.onnx;1,256,192,3 0.5
mtk_face_recognition_v3.onnx 0.5
mtk_face_recognition_v2.onnx 2.5
ml_2012_ocr_detection_tmp.onnx 0.5
#Harmony_Voiceprint_resnet18.onnx;1,1,200,40 4.5
Harmony_Voiceprint_resnet18.onnx;1,150,40,1 4.5
bloom_hongmo_detection_tmp.onnx 0.5
Q_face_recognition.onnx 2
ml_video_edit_enhance_update_tmp.onnx 0.5

View File

@ -36,7 +36,6 @@ mnasnet_1.3_224.tflite
inception_v3.tflite
deeplabv3_257_mv_gpu.tflite
multi_person_mobilenet_v1_075_float.tflite
#hiai_vad.tflite
ide_label_base.tflite
ide_label_retrained.tflite
ml_ei_headpose.tflite
@ -164,8 +163,6 @@ hiai_detectmodel_desnet_256_128_64_32.tflite
lite-model_aiy_vision_classifier_food_V1_1.tflite
lite-model_disease-classification_1.tflite
lite-model_models_mushroom-identification_v1_1.tflite
#lite-model_albert_lite_base_squadv1_metadata_1.tflite
#lite-model_mobilebert_1_metadata_1.tflite
smartreply_1_default_1.tflite
text_classification.tflite
Q_detect_fpn_add_inception-1448650.tflite
@ -183,3 +180,8 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
Q888_face_emo_dress_mv3_orderd.tflite
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
bloom_new_detect.tflite
bloom_model_age_gender.tflite
bloom_isface.tflite
hiai_object_detect_814.tflite
hiai_object_tflite_graph_8bit.tflite

View File

@ -209,3 +209,9 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
Q888_face_emo_dress_mv3_orderd.tflite 2.5
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
bloom_new_detect.tflite 3.5
bloom_model_age_gender.tflite 0.5
bloom_isface.tflite 0.5
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision.
hiai_object_detect_814.tflite 14

View File

@ -8,6 +8,7 @@ ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2
ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
decoder.onnx;2;1,7,512:1,7
#fasterrcnn_crop.pb is the same model as gts_object_detect_Ics.pb.
fasterrcnn_crop.pb;1;420,630,3
ml_video_edit_person_divison_video;2
hdc_tb_cn_neg.tflite;3
@ -31,4 +32,11 @@ add_uint8.tflite;2
ml_Heatmap_depth_240180;2
ml_Heatmap_depth_180240;2
hiai_nlu_model.pb;3;1,16:1,16:1,16
gts_object_detect_lcs.pb;1;420,630,3
#calib data file in server incorrect
#gts_object_detect_Ics.pb;1;420,630,3
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
hiai_transformer_encoder.pb;15
lite-model_albert_lite_base_squadv1_metadata_1.tflite;3
lite-model_mobilebert_1_metadata_1.tflite;3
hiai_vad.tflite;2

View File

@ -26,3 +26,6 @@ ml_tts_vocoder.pb;66 53
# The outputs of two Heatmap_depth models have small value
ml_Heatmap_depth_240180;2 10 16
ml_Heatmap_depth_180240;2 7 7
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 1
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 0.5
hiai_transformer_encoder.pb;15 4

View File

@ -63,7 +63,7 @@ void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, st
if (!tensorT->data.empty()) {
int *data = reinterpret_cast<int *>(tensorT->data.data());
type = TypeId(data[0]);
if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size())) {
if (tensorT->data.size() < 8 || (data[1] != 0 && (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size()))) {
MS_LOG(ERROR) << "tensorlist data length illegal";
*convert_succ = false;
return;
@ -229,29 +229,36 @@ void SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
output_tensor->dataType = output_tensors[i]->data_type();
if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
if (tensor_list->tensors_data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_infer_ = false;
if (output_tensor->data.empty()) {
output_tensor->data.resize(8, 0);
}
if (tensor_list->tensors_data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_inferred_ = false;
return;
}
output_tensor->data.at(0) = tensor_list->tensors_data_type();
} else if (output_tensors[i]->data_type() == kTypeUnknown) {
tensors_->at(node->outputIndex[i]).is_inferred_ = false;
return;
}
tensors_->at(node->outputIndex[i]).is_inferred_ = true;
return;
}
} // namespace
STATUS InferShapePass::Run(MetaGraphT *graph) {
graph_ = graph;
InitSearchTensor(graph);
MS_ASSERT(graph != nullptr);
for (auto idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[idx].get();
InitSearchTensor(graph);
for (auto input_idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[input_idx].get();
for (auto &dim : input_tensor->dims) {
if (dim == 0) {
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
dim = DEFAULT_DIM_VALUE;
}
}
}
for (auto g_input_idx : graph->inputIndex) {
auto g_input_shape = graph->allTensors.at(g_input_idx)->dims;
if (std::find(g_input_shape.begin(), g_input_shape.end(), -1) != g_input_shape.end() || fmk_type_ == FmkType_TF) {
auto input_shape = graph->allTensors.at(input_idx)->dims;
if (std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() || fmk_type_ == FmkType_TF) {
infer_interrupt_ = true;
}
}
@ -286,11 +293,11 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
auto output_dims = output_tensors[i]->shape();
auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
output_tensor->dims.swap(output_dims);
SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
}
} else if (status == RET_INFER_INVALID) {
for (size_t i = 0; i < output_tensors.size(); i++) {
SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
}
infer_interrupt_ = true;
} else {
@ -300,7 +307,7 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
return RET_INFER_ERR;
}
FreeTensors(&input_tensors, &output_tensors);
AddOutputNode(infer_node_index);
AddOutputNodes(graph, infer_node_index);
}
return RET_OK;
}
@ -313,82 +320,60 @@ void InferShapePass::InitSearchTensor(MetaGraphT *graph) {
auto node_input_indexes = node->inputIndex;
// init in_nodes index
for (size_t j = 0; j < node_input_indexes.size(); j++) {
tensors_[node_input_indexes[j]].in_nodes_.push_back(i);
tensors_[node_input_indexes[j]].next_nodes_.push_back(i);
}
auto node_output_indexes = node->outputIndex;
for (size_t j = 0; j < node_output_indexes.size(); j++) {
tensors_[node_output_indexes[j]].out_nodes_.push_back(i);
all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
node_output_indexes.end());
tensors_[node_output_indexes[j]].prev_nodes_.push_back(i);
}
all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
node_output_indexes.end());
}
for (uint32_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].prev_nodes_.empty() || IsContain(graph->inputIndex, i) || !graph->allTensors.at(i)->data.empty()) {
tensors_[i].is_inferred_ = true;
}
}
for (size_t i = 0; i < graph->nodes.size(); i++) {
auto &node = graph->nodes[i];
auto &node = graph->nodes.at(i);
if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
[&](uint32_t index) { return !IsContain(all_node_output_tensor_indexes, index); })) {
[&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
infer_node_indexes_.push_back(i);
}
}
for (size_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].out_nodes_.empty()) {
tensors_[i].is_infer_ = true;
}
}
}
void InferShapePass::AddOutputNode(uint32_t infer_node_index) {
auto &node = graph_->nodes[infer_node_index];
void InferShapePass::AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index) {
auto &node = graph->nodes.at(infer_node_index);
for (size_t i = 0; i < node->outputIndex.size(); i++) {
auto output_tensor_node_indexes = tensors_[node->outputIndex[i]].in_nodes_;
tensors_[node->outputIndex[i]].is_infer_ = true;
for (size_t j = 0; j < output_tensor_node_indexes.size(); j++) {
bool flag = false;
auto &output_tensor_node = graph_->nodes[output_tensor_node_indexes[j]];
for (size_t k = 0; k < output_tensor_node->outputIndex.size(); k++) {
if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType != kObjectTypeTensorType) {
if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType == kTypeUnknown ||
tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
flag = true;
break;
}
} else {
if (tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
flag = true;
break;
}
}
}
if (flag) {
AddNextInferShapeNode(output_tensor_node_indexes, j);
auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
[&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
AddNextInferShapeNode(graph, next_nodes_indexes, j);
}
}
}
}
void InferShapePass::AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index) {
auto &output_tensor_node = graph_->nodes.at(output_tensor_node_indexes[index]);
if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), output_tensor_node_indexes[index]) ==
void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index) {
auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), next_nodes_indexes[index]) ==
infer_node_indexes_.end()) {
auto output_tensor_node_type = output_tensor_node->primitive->value.type;
if (output_tensor_node_type == schema::PrimitiveType_Merge) {
if (std::all_of(output_tensor_node->inputIndex.begin(),
output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
[&](uint32_t k) { return tensors_[k].is_infer_; }) ||
std::all_of(output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
output_tensor_node->inputIndex.end(), [&](uint32_t k) { return tensors_[k].is_infer_; })) {
infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
}
} else {
bool flag = true;
for (size_t i = 0; i < output_tensor_node->inputIndex.size(); i++) {
if (!(tensors_[output_tensor_node->inputIndex[i]].is_infer_)) {
flag = false;
break;
}
}
if (flag) {
infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
auto next_node_type = next_node->primitive->value.type;
if (next_node_type == schema::PrimitiveType_Merge) {
if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.begin() + next_node->inputIndex.size() / 2,
[&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
std::all_of(next_node->inputIndex.begin() + next_node->inputIndex.size() / 2, next_node->inputIndex.end(),
[&](uint32_t i) { return tensors_[i].is_inferred_; })) {
infer_node_indexes_.push_back(next_nodes_indexes[index]);
}
} else if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
[&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
std::any_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
[&](uint32_t i) { return graph->allTensors.at(i)->dataType == kObjectTypeTensorType; })) {
infer_node_indexes_.push_back(next_nodes_indexes[index]);
}
}
}

View File

@ -32,9 +32,9 @@ namespace mindspore {
namespace lite {
struct InferTensor {
std::vector<uint32_t> in_nodes_; /* used current tensor as input */
std::vector<uint32_t> out_nodes_;
bool is_infer_;
std::vector<uint32_t> next_nodes_;
std::vector<uint32_t> prev_nodes_;
bool is_inferred_;
};
class InferShapePass : public GraphPass {
@ -45,11 +45,10 @@ class InferShapePass : public GraphPass {
private:
void InitSearchTensor(MetaGraphT *graph);
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);
void AddOutputNode(uint32_t infer_node_index);
void AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index);
void AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index);
lite::converter::FmkType fmk_type_ = FmkType_TF;
MetaGraphT *graph_ = nullptr;
std::vector<InferTensor> tensors_ = {};
std::vector<uint32_t> infer_node_indexes_ = {};
bool infer_interrupt_ = false;