!21635 [MSLITE] Modify output tensor names of ms model.

Merge pull request !21635 from wangshaocong/convert_name
This commit is contained in:
i-robot 2021-08-12 08:58:17 +00:00 committed by Gitee
commit bb406a14b5
25 changed files with 148 additions and 248 deletions

View File

@ -224,7 +224,6 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/multi_head_attention_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/reshape_reshape_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc

View File

@ -1,4 +1,4 @@
hiai_model_0909_kd_rot_ps_softmax.tflite
Modify_Out_hiai_model_0909_kd_rot_ps_softmax.tflite
# hiai_chinese_english_recognize_model_float32.tflite
# hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
# hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
@ -6,44 +6,44 @@ hiai_model_0909_kd_rot_ps_softmax.tflite
# hiai_model_normalize_object_scene_ps_20200519.tflite
# mtk_AADB_HADB_MBV2_model_fp32.tflite
# mtk_AADB_HADB_MBV3_model_fp32.tflite
mobilenet_v1_0.25_128.tflite
mobilenet_v1_0.25_160.tflite
mobilenet_v1_0.25_192.tflite
mobilenet_v1_0.25_224.tflite
mobilenet_v1_0.5_128.tflite
mobilenet_v1_0.5_160.tflite
mobilenet_v1_0.5_192.tflite
mobilenet_v1_0.5_224.tflite
mobilenet_v1_0.75_128.tflite
mobilenet_v1_0.75_160.tflite
mobilenet_v1_0.75_192.tflite
mobilenet_v1_0.75_224.tflite
mobilenet_v1_1.0_128.tflite
mobilenet_v1_1.0_160.tflite
mobilenet_v1_1.0_192.tflite
mobilenet_v1_1.0_224.tflite
mobilenet_v2_1.0_224.tflite
Modify_Out_mobilenet_v1_0.25_128.tflite
Modify_Out_mobilenet_v1_0.25_160.tflite
Modify_Out_mobilenet_v1_0.25_192.tflite
Modify_Out_mobilenet_v1_0.25_224.tflite
Modify_Out_mobilenet_v1_0.5_128.tflite
Modify_Out_mobilenet_v1_0.5_160.tflite
Modify_Out_mobilenet_v1_0.5_192.tflite
Modify_Out_mobilenet_v1_0.5_224.tflite
Modify_Out_mobilenet_v1_0.75_128.tflite
Modify_Out_mobilenet_v1_0.75_160.tflite
Modify_Out_mobilenet_v1_0.75_192.tflite
Modify_Out_mobilenet_v1_0.75_224.tflite
Modify_Out_mobilenet_v1_1.0_128.tflite
Modify_Out_mobilenet_v1_1.0_160.tflite
Modify_Out_mobilenet_v1_1.0_192.tflite
Modify_Out_mobilenet_v1_1.0_224.tflite
Modify_Out_mobilenet_v2_1.0_224.tflite
# mtk_model_normalize_object_scene_ps_20200519_f32.tflite
# mtk_model_ckpt.tflite
mtk_age_gender.tflite
Modify_Out_mtk_age_gender.tflite
# mtk_model_face_dress.tflite
# mtk_face_features_v1.tflite
# densenet.tflite
squeezenet.tflite
Modify_Out_squeezenet.tflite
# resnet_v2_101_299.tflite
# mnasnet_1.3_224.tflite
inception_v3.tflite
Modify_Out_inception_v3.tflite
# deeplabv3_257_mv_gpu.tflite
# multi_person_mobilenet_v1_075_float.tflite
# hiai_vad.tflite
# ide_label_base.tflite
# ide_label_retrained.tflite
ml_ei_headpose.tflite
Modify_Out_ml_ei_headpose.tflite
# ml_ei_landmark.tflite
mnist.tflite
mobilenet.tflite
resnet.tflite
scan_hms_angle1.tflite
Modify_Out_mnist.tflite
Modify_Out_mobilenet.tflite
Modify_Out_resnet.tflite
Modify_Out_scan_hms_angle1.tflite
# scan_hms_detect.tflite
# hiai_latin_ocr.tflite
# hiai_latin_ocr_1.tflite
@ -57,25 +57,25 @@ scan_hms_angle1.tflite
# hiai_ssd_mobilenetv2_object.tflite
# hiai_cv_focusShootOCRModel_02.tflite
# hiai_cv_poseEstimation.tflite
inception_v4.tflite
Modify_Out_inception_v4.tflite
# mtk_model_normalize_object_scene_ps_20200519_f16.tflite
# mtk_age_gender_fp16.tflite
# mtk_model_face_dress_fp16.tflite
mtk_AADB_HADB_MBV2_model_f16.tflite
Modify_Out_mtk_AADB_HADB_MBV2_model_f16.tflite
# mtk_AADB_HADB_MBV3_model_f16.tflite
# mtk_model_emotions_0725_fp16.tflite
# mtk_face_features_v1_fp16.tflite
# siteAI_digcom_AI_ECN.tflite
siteAI_digcom_g2v_keras.tflite
siteAI_trans_nonlinear.tflite
siteAI_trans_tcpclassify.tflite
siteAI_wireless_depress_w.tflite
siteAI_wireless_restore_w.tflite
Modify_Out_siteAI_digcom_g2v_keras.tflite
Modify_Out_siteAI_trans_nonlinear.tflite
Modify_Out_siteAI_trans_tcpclassify.tflite
Modify_Out_siteAI_wireless_depress_w.tflite
Modify_Out_siteAI_wireless_restore_w.tflite
# magenta_arbitrary-image-stylization-v1-256_fp16_prediction_1.tflite
# ml_object_detect.tflite
# ml_object_detect_1.tflite
hiai_cpu_face_emotion.tflite
hiai_cpu_face_gazing.tflite
Modify_Out_hiai_cpu_face_emotion.tflite
Modify_Out_hiai_cpu_face_gazing.tflite
# hiai_cpu_face_headpose.tflite
# hiai_humanDetection.tflite
# hiai_cv_focusShootOCRModel_08.tflite
@ -83,20 +83,20 @@ hiai_cpu_face_gazing.tflite
# hiai_face_model_npu.tflite
# hiai_ctpn_feature_map.tflite
# hiai_cv_labelDetectorModel_v2.tflite
hiai_cv_labelDetectorModel_v4.tflite
Modify_Out_hiai_cv_labelDetectorModel_v4.tflite
# hiai_dress_detect.tflite
# hiai_cv_saliencyDetectorModel.tflite
# hiai_frozen_inference_graph.tflite
# hiai_ghostnet.tflite
# hiai_iMaxDN_RGB.tflite
# hiai_iMaxSR_RGB.tflite
hiai_label_and_video.tflite
Modify_Out_hiai_label_and_video.tflite
# hiai_lm_inference_graph.tflite
efficientnet_lite0_fp32_2.tflite
efficientnet_lite1_fp32_2.tflite
efficientnet_lite2_fp32_2.tflite
efficientnet_lite3_fp32_2.tflite
efficientnet_lite4_fp32_2.tflite
Modify_Out_efficientnet_lite0_fp32_2.tflite
Modify_Out_efficientnet_lite1_fp32_2.tflite
Modify_Out_efficientnet_lite2_fp32_2.tflite
Modify_Out_efficientnet_lite3_fp32_2.tflite
Modify_Out_efficientnet_lite4_fp32_2.tflite
# mnasnet_0.50_224_1_metadata_1.tflite
# mnasnet_0.75_224_1_metadata_1.tflite
# mnasnet_1.0_128_1_metadata_1.tflite
@ -138,7 +138,7 @@ efficientnet_lite4_fp32_2.tflite
# ml_location.tflite
# ml_face_openclose_tflite.tflite
# ml_object_detect_pb2tflite.tflite
Q_AADB_HADB_MBV2_model.tflite
Modify_Out_Q_AADB_HADB_MBV2_model.tflite
# Q_convert.tflite
# Q_crnn_ori_75w_slim_norm_pb2tflite.tflite
# Q_crnn_ori_v2_405001_notrans_nopre_pb2tflite.tflite
@ -169,6 +169,6 @@ Q_AADB_HADB_MBV2_model.tflite
# text_classification.tflite
# Q_detect_fpn_add_inception-1448650.tflite
# Q_hand_0812_pb2tflite.tflite
intent_detect_hi_v2.tflite
simple_IPS_model_4D_input.onnx
rpnt_pdr_conv2d_16_fixed_last.onnx
Modify_Out_intent_detect_hi_v2.tflite
Modify_Out_simple_IPS_model_4D_input.onnx
Modify_Out_rpnt_pdr_conv2d_16_fixed_last.onnx

View File

@ -1 +1 @@
intent_detect_hi_v2.tflite
Modify_Out_intent_detect_hi_v2.tflite

View File

@ -101,7 +101,7 @@ Q_crnn_ori_75w_slim_norm.pb
Q_crnn_ori_v2_405001_notrans_nopre.pb
bolt_segment.pb
ml_location_lane_counter.onnx 2
gts_detect_5k_tf115.tflite
Modify_Out_gts_detect_5k_tf115.tflite
smartreply.tflite
ml_text_correction.tflite
ml_ocr_jk_pb2tflite.tflite
@ -152,7 +152,7 @@ hiai_iMaxDN_RGB.pb
hiai_iMaxSR_RGB.pb
hiai_lm_inference_graph.pb
hiai_PoseEstimation_Pcm.pb
hiai_model_0909_kd_rot_ps_softmax.tflite
Modify_Out_hiai_model_0909_kd_rot_ps_softmax.tflite
hiai_chinese_english_recognize_model_float32.tflite
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite

View File

@ -88,7 +88,7 @@ ml_video_edit_hair_dyeing_segmodel_v2 0.5
ml_video_edit_makeup_mobilenetv203.onnx 2
ml_video_edit_hairline_segmentation;3 0.5
ml_video_edit_hair_dyeing_migrate_v2.onnx;4 0.5
ml_audio_kit_encoder_v5.pb;6;1,32:1,32:1,32:1,32:1:1
Modify_Out_ml_audio_kit_encoder_v5.pb;6;1,32:1,32:1,32:1,32:1:1
fsr_270_mindspore.pb 1
fsr_360_mindspore.pb 1
fsr_720_mindspore.pb 1

View File

@ -91,7 +91,7 @@ tacotron_encoder_stf.pb;5;1:1,62:1,62:1,62:1,62;;input_dependent
female_model_step2_int16_noiseout.pb;66
ml_female_model_step6_noiseout.pb;66
ml_male_model_step6_noiseout.pb;66
ml_tts_decoder_control_flow.pb;5
Modify_Out_ml_tts_decoder_control_flow.pb;5
ml_tts_decoder.pb;5
ml_tts_encoder_control_flow.pb;4;1:1,22:1:1;;input_dependent
ml_tts_vocoder.pb;66
@ -99,7 +99,7 @@ hiai_nlu_model.pb;3;1,16:1,16:1,16
gts_object_detect_Ics.pb;1;420,630,3;;input_dependent
hiai_transformer_encoder.pb;15
decoder_step_nocumsum_v5.pb;13;1:1,512:1,1429,2:1,127:1,127:1,127:1,127,320:1,80:1,512:1,512:1,512:1,512:1,512
ml_audio_kit_encoder_v5.pb;6;1,32:1,32:1,32:1,32:1:1
Modify_Out_ml_audio_kit_encoder_v5.pb;6;1,32:1,32:1,32:1,32:1:1
hiai_nlu_model_v1.pb;3;1,16:1,16:1,16 2.0
hiai_nlu_model_v2.pb;7;1,5:1,6:1,174:1,98:1,5:1,5:1,5
hiai_nlu_model_multi.pb;6;1,32:1,32:1,6:1,11:1,74:1,32

View File

@ -81,7 +81,7 @@ ml_video_edit_oneclick_adaptis.pb;3 6
ml_female_model_step6_noiseout.pb;66 2
ml_male_model_step6_noiseout.pb;66 2.5
ml_tts_encoder_control_flow.pb;4;1:1,22:1:1 1.5
ml_tts_decoder_control_flow.pb;5 1
Modify_Out_ml_tts_decoder_control_flow.pb;5 1
ml_tts_decoder.pb;5 2.5
ml_tts_vocoder.pb;66 53
hiai_transformer_encoder.pb;15 4

View File

@ -1,11 +1,11 @@
hiai_model_0909_kd_rot_ps_softmax.tflite
hiai_chinese_english_recognize_model_float32.tflite
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
hiai_cn_recognize_modify_padv2.tflite
hiai_model_normalize_object_scene_ps_20200519.tflite
mtk_AADB_HADB_MBV2_model_fp32.tflite
mtk_AADB_HADB_MBV3_model_fp32.tflite
Modify_Out_hiai_model_0909_kd_rot_ps_softmax.tflite
Modify_Out_hiai_chinese_english_recognize_model_float32.tflite
Modify_Out_hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
Modify_Out_hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
Modify_Out_hiai_cn_recognize_modify_padv2.tflite
Modify_Out_hiai_model_normalize_object_scene_ps_20200519.tflite
#mtk_AADB_HADB_MBV2_model_fp32.tflite The model has three outputs, but the benchmark file has one output only.
Modify_Out_mtk_AADB_HADB_MBV3_model_fp32.tflite
mobilenet_v1_0.25_128.tflite
mobilenet_v1_0.25_160.tflite
mobilenet_v1_0.25_192.tflite
@ -112,7 +112,7 @@ lite-model_deeplabv3-mobilenetv2-float16_1_default_1.tflite
lite-model_east-text-detector_fp16_1.tflite
lite-model_cartoongan_fp16_1.tflite
lite-model_arbitrary-image-stylization-inceptionv3_fp16_predict_1.tflite
gts_detect_5k_tf115.tflite
Modify_Out_gts_detect_5k_tf115.tflite
mtk_isface.tflite
mtk_landmark.tflite
mtk_new_detect.tflite
@ -182,13 +182,13 @@ Q_iMaxSR_RGB_385_p_pb2tflite.tflite
bloom_new_detect.tflite
bloom_model_age_gender.tflite
bloom_isface.tflite
hiai_object_detect_814.tflite
hiai_object_tflite_graph_8bit.tflite
Modify_Out_hiai_object_detect_814.tflite
Modify_Out_hiai_object_tflite_graph_8bit.tflite
lma_tsec_shallow_channels16_ds2.1.1_model-best-f1.tflite
lite-model_arbitrary-image-stylization-inceptionv3_fp16_transfer_1.tflite;2
magenta_arbitrary-image-stylization-v1-256_fp16_transfer_1.tflite;2
albert_lite_base_squadv1_1.tflite;3
mobilebert_1_default_1.tflite;3
Modify_Out_albert_lite_base_squadv1_1.tflite;3
Modify_Out_mobilebert_1_default_1.tflite;3
ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
hdc_tb_cn_neg.tflite;3
@ -196,7 +196,7 @@ hiai_cv_labelDetectorModel_v3.tflite;2
ml_tacotron_decoder_step_stf.tflite;9;1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1
ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
lite-model_albert_lite_base_squadv1_metadata_1.tflite;3
lite-model_mobilebert_1_metadata_1.tflite;3
Modify_Out_lite-model_albert_lite_base_squadv1_metadata_1.tflite;3
Modify_Out_lite-model_mobilebert_1_metadata_1.tflite;3
hiai_vad.tflite;2
add_uint8.tflite;2

View File

@ -33,9 +33,9 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite
vision_classifier_fungi_mobile_V1_1_default_1.tflite
detect.tflite
ssd_mobilenet_v1_1_default_1.tflite
object_detection_mobile_object_localizer_v1_1_default_1.tflite
gts_detect_0730_quant_frozen.tflite
Modify_Out_detect.tflite
Modify_Out_ssd_mobilenet_v1_1_default_1.tflite
Modify_Out_object_detection_mobile_object_localizer_v1_1_default_1.tflite
Modify_Out_gts_detect_0730_quant_frozen.tflite
gts_model_quant_frozen.tflite
inception_v2_224_quant.tflite

View File

@ -1,7 +1,7 @@
# [first column]:model_name;input_num;input_shape;threads;extra_info. If there is no need to set these parameters, the
# content after ";" can be omitted.
# [second column]:accuracy limit for float16 in arm64 device
hiai_model_0909_kd_rot_ps_softmax.tflite 10
Modify_Out_hiai_model_0909_kd_rot_ps_softmax.tflite 10
hiai_chinese_english_recognize_model_float32.tflite 13
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite 10
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite 10
@ -121,7 +121,7 @@ lite-model_deeplabv3-mobilenetv2-float16_1_default_1.tflite 60
lite-model_east-text-detector_fp16_1.tflite 60
lite-model_cartoongan_fp16_1.tflite 3
lite-model_arbitrary-image-stylization-inceptionv3_fp16_predict_1.tflite 6
gts_detect_5k_tf115.tflite 9.5
Modify_Out_gts_detect_5k_tf115.tflite 9.5
mtk_isface.tflite 0.2
mtk_landmark.tflite 0.3
mtk_new_detect.tflite 3
@ -212,7 +212,7 @@ bloom_model_age_gender.tflite 0.5
bloom_isface.tflite 0.5
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision.
hiai_object_detect_814.tflite 14
Modify_Out_hiai_object_detect_814.tflite 14
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11
ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5
hdc_tb_cn_neg.tflite;3 295

View File

@ -82,7 +82,7 @@ TEST_F(GraphTest, UserSetGraphOutput1) {
string name = out_data.first;
void *data = out_data.second;
float *fp32_data = reinterpret_cast<float *>(data);
if (name == "Stack-8") {
if (name == "output") {
output_count++;
ASSERT_LE(fabs(fp32_data[0] - (0.115831)), 0.01);
ASSERT_LE(fabs(fp32_data[1] - (0.113074)), 0.01);
@ -90,7 +90,7 @@ TEST_F(GraphTest, UserSetGraphOutput1) {
ASSERT_LE(fabs(fp32_data[3] - (0.346307)), 0.01);
ASSERT_LE(fabs(fp32_data[4] - (-0.15687)), 0.01);
}
if (name == "Stack-10") {
if (name == "output2") {
output_count++;
ASSERT_LE(fabs(fp32_data[0] - (0.06387864)), 0.01);
ASSERT_LE(fabs(fp32_data[1] - (0.22883008)), 0.01);
@ -98,7 +98,7 @@ TEST_F(GraphTest, UserSetGraphOutput1) {
ASSERT_LE(fabs(fp32_data[3] - (0.04586578)), 0.01);
ASSERT_LE(fabs(fp32_data[4] - (0.06820235)), 0.01);
}
if (name == "Stack-13") {
if (name == "output3") {
output_count++;
ASSERT_LE(fabs(fp32_data[0] - (-0.1617176)), 0.01);
ASSERT_LE(fabs(fp32_data[1] - (-0.3828573)), 0.01);

View File

@ -532,8 +532,12 @@ int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph,
auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index);
meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end());
for (auto &output_index : meta_graphT->outputIndex) {
auto &tensor = meta_graphT->allTensors.at(output_index);
// set output tensor names to the original names, the output_names is null in nnie converter.
auto output_names = ConverterContext::GetInstance()->GetGraphOutputTensorNames();
MS_ASSERT(output_names.size() == meta_graphT->outputIndex.size());
for (size_t idx = 0; idx < output_names.size(); idx++) {
auto &tensor = meta_graphT->allTensors.at(meta_graphT->outputIndex.at(idx));
tensor->name = output_names.at(idx);
ConverterContext::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
}

View File

@ -335,5 +335,23 @@ size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_
return 1;
}
}
STATUS GetInputIndexOfTupleGetItem(const AnfNodePtr &node, int *index) {
MS_ASSERT(node != nullptr);
if (!opt::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
MS_LOG(ERROR) << "The node is not a tupleGetItem node.";
return RET_ERROR;
}
auto index_vnode = node->cast<CNodePtr>()->input(2);
if (!utils::isa<ValueNodePtr>(index_vnode)) {
MS_LOG(ERROR) << "TupleGetItem's input 2 is not a valueNode.";
return RET_ERROR;
}
auto value_node = utils::cast<ValueNodePtr>(index_vnode);
MS_ASSERT(value_node != nullptr);
*index = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
: GetValue<int>(value_node->value());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -413,6 +413,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type)
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
STATUS GetInputIndexOfTupleGetItem(const AnfNodePtr &node, int *index);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H

View File

@ -62,7 +62,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_bn_fusion.cc
../optimizer/fusion/conv_tuplegetitem_fusion.cc
../optimizer/fusion/constant_folding_fusion.cc
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/fusion/norm_fusion.cc
../optimizer/fusion/batchmatmul_fusion.cc
../optimizer/fusion/sigmoid_mul_fusion.cc

View File

@ -106,6 +106,11 @@ class ConverterContext {
}
size_t GetGraphInputTensorShapeMapSize() { return graph_input_tensor_shape_map_.size(); }
void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
graph_output_tensor_names_ = output_names;
}
const std::vector<std::string> GetGraphOutputTensorNames() { return graph_output_tensor_names_; }
private:
ConverterContext() {}
virtual ~ConverterContext() = default;
@ -113,6 +118,7 @@ class ConverterContext {
std::map<int32_t, int32_t> graph_input_data_type_map_;
std::map<int32_t, int32_t> graph_output_data_type_map_;
std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
std::vector<std::string> graph_output_tensor_names_;
};
} // namespace lite
} // namespace mindspore

View File

@ -53,8 +53,9 @@ STATUS CaffeInspector::FindGraphInputsAndOutputs() {
}
}
for (const auto &iter : layerTops) {
if (layerBottoms.find(iter) == layerBottoms.end()) {
graphOutput.insert(iter);
if (layerBottoms.find(iter) == layerBottoms.end() &&
std::find(graphOutput.begin(), graphOutput.end(), iter) == graphOutput.end()) {
graphOutput.push_back(iter);
}
}
return RET_OK;

View File

@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>
#include <memory>
#include <vector>
#include "proto/caffe.pb.h"
#include "include/errorcode.h"
@ -37,7 +38,7 @@ class CaffeInspector {
STATUS SetLayerTopsAndBottoms();
std::set<std::string> GetGraphInput() { return graphInput; }
std::set<std::string> GetGraphOutput() { return graphOutput; }
std::vector<std::string> GetGraphOutput() { return graphOutput; }
private:
caffe::NetParameter net;
@ -46,7 +47,7 @@ class CaffeInspector {
std::set<std::string> layerBottoms;
std::set<std::string> graphInput;
std::set<std::string> graphOutput;
std::vector<std::string> graphOutput;
};
using CaffeInspectorPtr = std::shared_ptr<CaffeInspector>;

View File

@ -385,11 +385,11 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
}
auto valueNode = NewValueNode(returnPrim);
std::vector<AnfNodePtr> opInputs{valueNode};
if (nodes_.find(*caffeInspector.GetGraphOutput().begin()) == nodes_.end()) {
if (nodes_.find(caffeInspector.GetGraphOutput().front()) == nodes_.end()) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
auto cnode = nodes_.find(*caffeInspector.GetGraphOutput().begin())->second;
auto cnode = nodes_.find(caffeInspector.GetGraphOutput().front())->second;
if (cnode == nullptr) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
@ -399,6 +399,8 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
returnCnode->set_fullname_with_scope("Return");
res_graph_->set_return(returnCnode);
}
// save original output tensor names.
ConverterContext::GetInstance()->SetGraphOutputTensorNames(caffeInspector.GetGraphOutput());
return RET_OK;
}

View File

@ -157,6 +157,13 @@ STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, con
MS_LOG(ERROR) << "convert graph outputs failed.";
return RET_ERROR;
}
// save original output tensor names.
if (root_node_name == "root_node") {
std::vector<std::string> output_names;
std::transform(onnx_graph.output().begin(), onnx_graph.output().end(), std::back_inserter(output_names),
[](auto &graph_output) { return graph_output.name(); });
ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_names);
}
return status;
}
STATUS OnnxModelParser::ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,

View File

@ -33,6 +33,7 @@
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/common/node_util.h"
#include "tools/converter/parser/unify_format.h"
using mindspore::lite::converter::FmkType_TF;
@ -1054,7 +1055,18 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
return RET_ERROR;
}
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
auto name = anf_node->fullname_with_scope();
if (utils::isa<CNodePtr>(anf_node) && opt::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
int index = 0;
if (GetInputIndexOfTupleGetItem(anf_node, &index) != RET_OK) {
MS_LOG(ERROR) << "Get input index of tupleGetItem failed.";
return RET_ERROR;
}
auto in_node = anf_node->cast<CNodePtr>()->input(1);
MS_ASSERT(in_node != nullptr);
name = in_node->fullname_with_scope() + ":" + std::to_string(index);
}
graph_output_names_.push_back(name);
}
}
}
@ -1063,6 +1075,8 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;
}
// save original output tensor names.
ConverterContext::GetInstance()->SetGraphOutputTensorNames(graph_output_names_);
return RET_OK;
}
STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph) {

View File

@ -398,6 +398,12 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
returnCnode->set_fullname_with_scope("Return");
res_graph_->set_return(returnCnode);
}
// save original output tensor names.
std::vector<std::string> output_names;
auto output_idx = tflite_subgraph->outputs;
std::transform(output_idx.begin(), output_idx.end(), std::back_inserter(output_names),
[&](auto out_idx) { return tflite_subgraph->tensors.at(out_idx)->name; });
ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_names);
return RET_OK;
}

View File

@ -1,77 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/pooling_activation_fusion.h"
#include <memory>
#include "src/ops/pooling.h"
#include "src/ops/activation.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
constexpr size_t kActivationInputsLength = 2;
}
const BaseRef PoolingActivationFusion::DefinePattern() const {
auto pooling_var = std::make_shared<CondVar>(IsPoolingNode);
auto prim = new (std::nothrow) schema::PrimitiveT();
if (prim == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return nullptr;
}
prim->value.type = primitive_type;
auto prim_value = std::make_shared<lite::PrimitiveC>(prim);
return VectorRef({prim_value, pooling_var});
}
const AnfNodePtr PoolingActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "pooling activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
CheckIfFuncGraphIsNull(func_graph);
CheckIfAnfNodeIsNull(node);
auto act_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(act_node);
CheckInputSize(act_node, kActivationInputsLength);
auto primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(act_node->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
MS_ASSERT(act_primitivec != nullptr);
if (act_primitivec->GetType() != activation_type) {
return node;
}
AnfNodePtr pre_node = act_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
if (pre_node != nullptr && pre_node->isa<CNode>()) {
if (IsMultiOutputTensors(func_graph, pre_node)) {
return node;
}
auto pooling_node = pre_node->cast<CNodePtr>();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(pooling_node->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Pooling>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
primc->SetActivationType(activation_type);
return pre_node;
}
}
return node;
}
} // namespace mindspore::opt

View File

@ -1,48 +0,0 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include <memory>
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
constexpr size_t kActivationInputsLength = 2;
}
const BaseRef QuantDtypeCastFusion::DefinePattern() const {
auto quant_var = std::make_shared<CondVar>(IsQuantNode);
auto input_var = std::make_shared<Var>();
return VectorRef({quant_var, input_var});
}
const AnfNodePtr QuantDtypeCastFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "quant dtype cast fusion pass process";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
auto act_node = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(act_node) != lite::RET_OK ||
CheckInputSize(act_node, kActivationInputsLength) != lite::RET_OK) {
return nullptr;
}
AnfNodePtr pre_node = act_node->input(1);
if (CheckIfAnfNodeIsNull(pre_node) != lite::RET_OK) {
return nullptr;
}
return pre_node;
}
} // namespace mindspore::opt

View File

@ -1,34 +0,0 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_QUANT_DTYPE_CAST_FUSION_H
#define LITE_QUANT_DTYPE_CAST_FUSION_H
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class QuantDtypeCastFusion : public PatternProcessPass {
public:
explicit QuantDtypeCastFusion(bool multigraph = true, const std::string &name = "quant_dtype_cast_fusion")
: PatternProcessPass(name, multigraph) {}
~QuantDtypeCastFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // LITE_QUANT_DTYPE_CAST_FUSION_H