forked from mindspore-Ecosystem/mindspore
[MSLITE] Fix bug of tf model‘s output names.
This commit is contained in:
parent
db19c60581
commit
f0828412bb
|
@ -167,8 +167,7 @@ int NPUInsertTransformPass::InsertNode(NPUOp *op, NPUOp *post_op, size_t post_in
|
|||
} else {
|
||||
// post_op nullptr mean output, we remain graph output tensor name unchanged
|
||||
auto graph_output_name = in_tensor.Name();
|
||||
in_tensor.SetTensorName(graph_output_name + "_before_" + name_);
|
||||
nc2nh_tensor->SetTensorName(graph_output_name);
|
||||
nc2nh_tensor->SetTensorName(graph_output_name + "_after_" + name_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -16,5 +16,5 @@ deconvs_model
|
|||
# onnx
|
||||
ml_2012_ocr_cn.onnx
|
||||
# aware_training
|
||||
Modify_Out_video_infer2.tflite;;;;aware_training
|
||||
video_infer2.tflite;;;;aware_training
|
||||
mobilenet_v1_1.0_224_quant.tflite;;;;aware_training
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
mobilenet_v1_1.0_224.tflite
|
||||
mobilenet_v2_1.0_224.tflite
|
||||
Modify_Out_mtk_age_gender_fp16.tflite
|
||||
mtk_age_gender_fp16.tflite
|
||||
mtk_isface.tflite
|
||||
mtk_landmark.tflite
|
||||
mtk_new_detect.tflite
|
||||
|
@ -25,7 +25,7 @@ Q_face_recognition.onnx
|
|||
Q888_iris_detect.onnx
|
||||
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
|
||||
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
|
||||
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
|
||||
Q_detect_fpn_add_inception-1448650.tflite
|
||||
Q888_face_dress_mv3y.tflite
|
||||
Q888_HADB_AADB_MBV2_model_fp32.tflite
|
||||
Q888_landmark.tflite
|
||||
|
@ -58,4 +58,4 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified
|
|||
mtk_model_normalize_object_scene_ps_20200519_f16.tflite
|
||||
mtk_AADB_HADB_MBV2_model_f16.tflite
|
||||
mtk_model_emotions_0725_fp16.tflite
|
||||
Modify_Out_Q888_age_gender_orderd.tflite
|
||||
Q888_age_gender_orderd.tflite
|
||||
|
|
|
@ -6,10 +6,10 @@ resnet.tflite
|
|||
squeezenet.tflite
|
||||
mtk_AADB_HADB_MBV2_model_fp32.tflite
|
||||
hiai_cn_recognize_modify_padv2.tflite
|
||||
Modify_Out_hiai_cv_focusShootOCRModel_08.tflite
|
||||
hiai_cv_focusShootOCRModel_08.tflite
|
||||
hiai_model_normalize_object_scene_ps_20200519.tflite
|
||||
inception_v3.tflite
|
||||
Modify_Out_mtk_age_gender_fp16.tflite
|
||||
mtk_age_gender_fp16.tflite
|
||||
mtk_isface.tflite
|
||||
mtk_landmark.tflite
|
||||
mtk_new_detect.tflite
|
||||
|
@ -102,7 +102,7 @@ Q_crnn_ori_v2_405001_notrans_nopre.pb
|
|||
bolt_segment.pb
|
||||
ml_location_lane_counter.onnx 2
|
||||
gts_detect_5k_tf115.tflite
|
||||
Modify_Out_smartreply.tflite
|
||||
smartreply.tflite
|
||||
ml_text_correction.tflite
|
||||
ml_ocr_jk_pb2tflite.tflite
|
||||
scan_hms_angle_pb2tflite.tflite
|
||||
|
@ -111,21 +111,21 @@ ml_face_openclose_tflite.tflite
|
|||
unet_mbv2_05_104pts.tflite
|
||||
hiai_AADB_HADB_MBV2_model_f16.tflite
|
||||
hiai_AADB_HADB_MBV2_model_fp32.tflite
|
||||
Modify_Out_hiai_detect_curve_model_float32.tflite
|
||||
Modify_Out_hiai_detectmodel_06_23_960_480_1180700.tflite
|
||||
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite
|
||||
hiai_detect_curve_model_float32.tflite
|
||||
hiai_detectmodel_06_23_960_480_1180700.tflite
|
||||
lite-model_aiy_vision_classifier_food_V1_1.tflite
|
||||
lite-model_disease-classification_1.tflite
|
||||
lite-model_models_mushroom-identification_v1_1.tflite
|
||||
Modify_Out_smartreply_1_default_1.tflite
|
||||
smartreply_1_default_1.tflite
|
||||
text_classification.tflite
|
||||
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
|
||||
Modify_Out_Q_hand_0812_pb2tflite.tflite
|
||||
Q_detect_fpn_add_inception-1448650.tflite
|
||||
Q_hand_0812_pb2tflite.tflite
|
||||
bloom_landmark.tflite
|
||||
Q888_face_dress_mv3y.tflite
|
||||
Q888_HADB_AADB_MBV2_model_fp32.tflite
|
||||
Q888_landmark.tflite
|
||||
Q888_pose.tflite
|
||||
Modify_Out_Q888_lapa158_unet_0924.tflite
|
||||
Q888_lapa158_unet_0924.tflite
|
||||
Q888_isface.tflite
|
||||
Q888_new_detect.tflite
|
||||
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
|
||||
|
@ -156,7 +156,7 @@ hiai_model_0909_kd_rot_ps_softmax.tflite
|
|||
hiai_chinese_english_recognize_model_float32.tflite
|
||||
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
|
||||
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
|
||||
Modify_Out_hiai_detectmodel_desnet_256_128_64_32.tflite
|
||||
hiai_detectmodel_desnet_256_128_64_32.tflite
|
||||
mtk_AADB_HADB_MBV3_model_fp32.tflite
|
||||
Q888_face_recognition.onnx
|
||||
mobilenet_v1_0.25_128.tflite
|
||||
|
@ -201,7 +201,7 @@ mtk_AADB_HADB_MBV2_model_f16.tflite
|
|||
mtk_AADB_HADB_MBV3_model_f16.tflite
|
||||
mtk_model_emotions_0725_fp16.tflite
|
||||
mtk_face_features_v1_fp16.tflite
|
||||
Modify_Out_Q888_age_gender_orderd.tflite
|
||||
Q888_age_gender_orderd.tflite
|
||||
emotion
|
||||
gender_res_large_deploy
|
||||
glasses
|
||||
|
|
|
@ -38,7 +38,7 @@ multi_person_mobilenet_v1_075_float.tflite
|
|||
ide_label_base.tflite
|
||||
ide_label_retrained.tflite
|
||||
ml_ei_headpose.tflite
|
||||
Modify_Out_ml_ei_landmark.tflite
|
||||
ml_ei_landmark.tflite
|
||||
mnist.tflite
|
||||
mobilenet.tflite
|
||||
resnet.tflite
|
||||
|
@ -58,7 +58,7 @@ hiai_cv_focusShootOCRModel_02.tflite
|
|||
hiai_cv_poseEstimation.tflite
|
||||
inception_v4.tflite
|
||||
mtk_model_normalize_object_scene_ps_20200519_f16.tflite
|
||||
Modify_Out_mtk_age_gender_fp16.tflite
|
||||
mtk_age_gender_fp16.tflite
|
||||
mtk_model_face_dress_fp16.tflite
|
||||
mtk_AADB_HADB_MBV2_model_f16.tflite
|
||||
mtk_AADB_HADB_MBV3_model_f16.tflite
|
||||
|
@ -77,7 +77,7 @@ hiai_cpu_face_emotion.tflite
|
|||
hiai_cpu_face_gazing.tflite
|
||||
hiai_cpu_face_headpose.tflite
|
||||
hiai_humanDetection.tflite
|
||||
Modify_Out_hiai_cv_focusShootOCRModel_08.tflite
|
||||
hiai_cv_focusShootOCRModel_08.tflite
|
||||
ml_face_openclose.tflite
|
||||
hiai_face_model_npu.tflite
|
||||
hiai_ctpn_feature_map.tflite
|
||||
|
@ -122,13 +122,13 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
|
|||
mtk_276landmark_0913.tflite
|
||||
mtk_face_recognition.tflite
|
||||
mtk_convert_model.tflite
|
||||
Modify_Out_smartreply.tflite
|
||||
smartreply.tflite
|
||||
mindspore_text_classification_tflite.tflite
|
||||
# ml_location.tflite
|
||||
ml_text_correction.tflite
|
||||
ml_pic_shopping.tflite
|
||||
ml_vision_guide_detection3_pb2tflite.tflite
|
||||
Modify_Out_ml_vision_guide_detection1_pb2tflite.tflite
|
||||
ml_vision_guide_detection1_pb2tflite.tflite
|
||||
ml_pic_shopping_pb2tflite.tflite
|
||||
ml_ocr_jk_pb2tflite.tflite
|
||||
ml_ocr_latin_pb2tflite.tflite
|
||||
|
@ -152,27 +152,27 @@ Q_language_model_hrmini_Q4_b4_17w.tflite
|
|||
Q_new_detect.tflite
|
||||
Q_object_scene.tflite
|
||||
Q_pose.tflite
|
||||
Modify_Out_ml_ei_landmark_pb2tflite.tflite
|
||||
ml_ei_landmark_pb2tflite.tflite
|
||||
unet_mbv2_05_104pts.tflite
|
||||
hiai_AADB_HADB_MBV2_model_f16.tflite
|
||||
hiai_AADB_HADB_MBV2_model_fp32.tflite
|
||||
Modify_Out_hiai_detect_curve_model_float32.tflite
|
||||
Modify_Out_hiai_detectmodel_06_23_960_480_1180700.tflite
|
||||
Modify_Out_hiai_detectmodel_desnet_256_128_64_32.tflite
|
||||
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite
|
||||
hiai_detect_curve_model_float32.tflite
|
||||
hiai_detectmodel_06_23_960_480_1180700.tflite
|
||||
hiai_detectmodel_desnet_256_128_64_32.tflite
|
||||
lite-model_aiy_vision_classifier_food_V1_1.tflite
|
||||
lite-model_disease-classification_1.tflite
|
||||
lite-model_models_mushroom-identification_v1_1.tflite
|
||||
Modify_Out_smartreply_1_default_1.tflite
|
||||
smartreply_1_default_1.tflite
|
||||
text_classification.tflite
|
||||
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
|
||||
Modify_Out_Q_hand_0812_pb2tflite.tflite
|
||||
Q_detect_fpn_add_inception-1448650.tflite
|
||||
Q_hand_0812_pb2tflite.tflite
|
||||
bloom_landmark.tflite
|
||||
Modify_Out_Q888_age_gender_orderd.tflite
|
||||
Q888_age_gender_orderd.tflite
|
||||
Q888_face_dress_mv3y.tflite
|
||||
Q888_HADB_AADB_MBV2_model_fp32.tflite
|
||||
Q888_landmark.tflite
|
||||
Q888_pose.tflite
|
||||
Modify_Out_Q888_lapa158_unet_0924.tflite
|
||||
Q888_lapa158_unet_0924.tflite
|
||||
Q888_isface.tflite
|
||||
Q888_new_detect.tflite
|
||||
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
|
||||
|
@ -180,7 +180,7 @@ Q888_face_emo_dress_mv3_orderd.tflite
|
|||
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
|
||||
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
|
||||
bloom_new_detect.tflite
|
||||
Modify_Out_bloom_model_age_gender.tflite
|
||||
bloom_model_age_gender.tflite
|
||||
bloom_isface.tflite
|
||||
hiai_object_detect_814.tflite
|
||||
hiai_object_tflite_graph_8bit.tflite
|
||||
|
@ -193,10 +193,10 @@ ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2:2,1
|
|||
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2:2,1
|
||||
hdc_tb_cn_neg.tflite;3:3,1,2 0.5
|
||||
hiai_cv_labelDetectorModel_v3.tflite;2:2,1
|
||||
Modify_Out_ml_tacotron_decoder_step_stf.tflite;9;1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1
|
||||
ml_tacotron_decoder_step_stf.tflite;9;1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1
|
||||
ml_headpose_pb2tflite.tflite;3:2,3,1;1,64,64,3:16:16
|
||||
ml_ei_headpose_pb2tflite.tflite;3:2,3,1;1,64,64,3:16:16
|
||||
lite-model_albert_lite_base_squadv1_metadata_1.tflite;3:2,3,1
|
||||
lite-model_mobilebert_1_metadata_1.tflite;3
|
||||
Modify_Out_hiai_vad.tflite;2
|
||||
hiai_vad.tflite;2
|
||||
add_uint8.tflite;2
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
Modify_Out_video_infer2.tflite
|
||||
video_infer2.tflite
|
||||
mobilenet_v1_0.25_128_quant.tflite
|
||||
mobilenet_v1_0.25_160_quant.tflite
|
||||
mobilenet_v1_0.25_192_quant.tflite
|
||||
|
|
|
@ -49,7 +49,7 @@ ide_label_base.tflite 22
|
|||
# dividing 0 in the following operator.
|
||||
#ide_label_retrained.tflite
|
||||
ml_ei_headpose.tflite 3
|
||||
Modify_Out_ml_ei_landmark.tflite 3
|
||||
ml_ei_landmark.tflite 3
|
||||
mnist.tflite 4
|
||||
mobilenet.tflite 0.1
|
||||
resnet.tflite 120
|
||||
|
@ -131,7 +131,7 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22
|
|||
mtk_276landmark_0913.tflite 16
|
||||
mtk_face_recognition.tflite 8
|
||||
mtk_convert_model.tflite 5
|
||||
Modify_Out_smartreply.tflite 0.1
|
||||
smartreply.tflite 0.1
|
||||
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
|
||||
#ml_location.tflite 0.1
|
||||
ml_text_correction.tflite 1
|
||||
|
@ -141,7 +141,7 @@ ml_text_correction.tflite 1
|
|||
# fp16: 27.6 - 27.4 = 0.2
|
||||
#ml_pic_shopping.tflite 0.1
|
||||
ml_vision_guide_detection3_pb2tflite.tflite 0.5
|
||||
Modify_Out_ml_vision_guide_detection1_pb2tflite.tflite 0.5
|
||||
ml_vision_guide_detection1_pb2tflite.tflite 0.5
|
||||
ml_pic_shopping_pb2tflite.tflite 95
|
||||
ml_ocr_jk_pb2tflite.tflite 0.5
|
||||
ml_ocr_latin_pb2tflite.tflite 11.5
|
||||
|
@ -158,17 +158,17 @@ lite-model_on_device_vision_classifier_landmarks_classifier_asia_V1_1.tflite 25
|
|||
lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V1_1.tflite 11
|
||||
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite 32
|
||||
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite 14
|
||||
Modify_Out_ml_ei_landmark_pb2tflite.tflite 2
|
||||
ml_ei_landmark_pb2tflite.tflite 2
|
||||
unet_mbv2_05_104pts.tflite 17
|
||||
hiai_AADB_HADB_MBV2_model_f16.tflite 3.5
|
||||
hiai_AADB_HADB_MBV2_model_fp32.tflite 4.5
|
||||
Modify_Out_mtk_age_gender_fp16.tflite 26
|
||||
Modify_Out_hiai_detect_curve_model_float32.tflite 9
|
||||
mtk_age_gender_fp16.tflite 26
|
||||
hiai_detect_curve_model_float32.tflite 9
|
||||
Q_language_model_hrmini_Q4_b4_17w.tflite 3.5
|
||||
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite 47.5
|
||||
lite-model_aiy_vision_classifier_food_V1_1.tflite 47.5
|
||||
lite-model_disease-classification_1.tflite 70
|
||||
lite-model_models_mushroom-identification_v1_1.tflite 5
|
||||
Modify_Out_smartreply_1_default_1.tflite 0.5
|
||||
smartreply_1_default_1.tflite 0.5
|
||||
text_classification.tflite 0.5
|
||||
Q_AADB_HADB_MBV2_model.tflite 5
|
||||
# the input of Q_convert model is between 0-255
|
||||
|
@ -190,16 +190,16 @@ Q_new_detect.tflite 3.5
|
|||
# the input of Q_object_scene model is between 0-255
|
||||
Q_object_scene.tflite 3
|
||||
Q_pose.tflite 4.1
|
||||
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite 1
|
||||
Q_detect_fpn_add_inception-1448650.tflite 1
|
||||
bloom_landmark.tflite 0.5
|
||||
# input data: 0~255
|
||||
Modify_Out_Q888_age_gender_orderd.tflite 1.5
|
||||
Q888_age_gender_orderd.tflite 1.5
|
||||
Q888_face_dress_mv3y.tflite 0.5
|
||||
Q888_HADB_AADB_MBV2_model_fp32.tflite 2.5
|
||||
Q888_landmark.tflite 0.5
|
||||
Q888_pose.tflite 6.1
|
||||
# the output contains value less than e-7
|
||||
Modify_Out_Q888_lapa158_unet_0924.tflite 19
|
||||
Q888_lapa158_unet_0924.tflite 19
|
||||
Q888_isface.tflite 1.0
|
||||
Q888_new_detect.tflite 1.5
|
||||
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
|
||||
|
@ -208,7 +208,7 @@ Q888_face_emo_dress_mv3_orderd.tflite 2.5
|
|||
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
|
||||
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
|
||||
bloom_new_detect.tflite 3.5
|
||||
Modify_Out_bloom_model_age_gender.tflite 0.5
|
||||
bloom_model_age_gender.tflite 0.5
|
||||
bloom_isface.tflite 0.5
|
||||
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
|
||||
# this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
Stack-8 5 1 1 8 4 2
|
||||
output 5 1 1 8 4 2
|
||||
0.115831 0.11307496 0.24593274 0.34630755 -0.156871 0.21111916 -0.1046219 0.01590158 0.2745127 0.17317073 0.1787783 0.36557162 -0.13658395 0.2911819 -0.17356569 0.06825469 0.30655888 0.29681587 0.0078597255 0.3846875 -0.09266291 0.26170188 -0.15063931 0.04322962 0.25661856 0.25256 0.023097975 0.32573196 -0.043139715 0.25530565 -0.17270242 0.06442319 0.16240332 0.14648464 0.09654196 0.31037596 -0.0539147 0.23819281 -0.15090092 0.048991375 0.11573871 0.078725 0.19393174 0.26017824 -0.053352155 0.23836473 -0.15971972 0.054956935 0.19800682 0.17823274 0.17631978 0.3600948 -0.057391744 0.30457845 -0.19889072 0.05244953 0.090213075 0.17350613 0.044377614 0.29630166 -0.06999667 0.28462386 -0.17194743 0.093742274
|
||||
Stack-10 5 1 1 8 4 1
|
||||
output2 5 1 1 8 4 1
|
||||
0.06387864 0.22883008 0.23308714 0.045865785 0.06820235 0.26621705 0.29714558 0.112830795 0.1669129 0.33512616 0.25788227 0.08388044 0.14331667 0.27875048 0.23716372 0.10920572 0.07898582 0.24287388 0.22543576 0.08901558 0.03376824 0.16912283 0.225415 0.09693983 0.09598104 0.26216167 0.28474298 0.10668853 0.12471523 0.24643728 0.27107987 0.13469991
|
||||
Stack-13 3 1 8 4
|
||||
output3 3 1 8 4
|
||||
-0.16171767 -0.3828573 0.08357508 0.10217983 -0.34800848 -0.3206381 0.03284559 0.15394436 -0.42709222 -0.15115751 -0.0015709695 0.13956246 -0.35903975 -0.14498001 -0.050358675 0.15447712 -0.22225751 -0.21515054 -0.03286325 0.13769037 -0.1488501 -0.29710612 -0.033508375 0.14458355 -0.27084687 -0.31606156 -0.053954814 0.18598628 -0.15771987 -0.15602258 -0.0335121 0.14279547
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#undef __STDC_FORMAT_MACROS
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "include/context.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/version.h"
|
||||
|
@ -115,7 +116,7 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
|
|||
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
|
||||
return RET_OK;
|
||||
}
|
||||
tensor::MSTensor *tensor = GetTensorByNameOrShape(tensor_name, dims);
|
||||
tensor::MSTensor *tensor = session_->GetOutputByTensorName(tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
|
@ -175,17 +176,17 @@ int Benchmark::CompareOutput() {
|
|||
float total_bias = 0;
|
||||
int total_size = 0;
|
||||
for (const auto &calib_tensor : benchmark_data_) {
|
||||
std::string node_or_tensor_name = calib_tensor.first;
|
||||
tensor::MSTensor *tensor = GetTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
|
||||
std::string tensor_name = calib_tensor.first;
|
||||
tensor::MSTensor *tensor = session_->GetOutputByTensorName(tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int ret;
|
||||
if (tensor->data_type() == kObjectTypeString) {
|
||||
ret = CompareStringData(node_or_tensor_name, tensor);
|
||||
ret = CompareStringData(tensor_name, tensor);
|
||||
} else {
|
||||
ret = CompareDataGetTotalBiasAndSize(node_or_tensor_name, tensor, &total_bias, &total_size);
|
||||
ret = CompareDataGetTotalBiasAndSize(tensor_name, tensor, &total_bias, &total_size);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Error in CompareData";
|
||||
|
@ -212,41 +213,6 @@ int Benchmark::CompareOutput() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
tensor::MSTensor *Benchmark::GetTensorByNodeShape(const std::vector<size_t> &node_shape) {
|
||||
std::vector<tensor::MSTensor *> match_tensors;
|
||||
std::vector<int> shape_vector;
|
||||
(void)std::transform(node_shape.begin(), node_shape.end(), std::back_inserter(shape_vector),
|
||||
[](const size_t &value) { return static_cast<int>(value); });
|
||||
auto tensors = session_->GetOutputs();
|
||||
for (auto &out_tensor_pair : tensors) {
|
||||
if (out_tensor_pair.second->shape() == shape_vector) {
|
||||
match_tensors.emplace_back(out_tensor_pair.second);
|
||||
}
|
||||
}
|
||||
if (match_tensors.empty() || match_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "get tensor by node shape failed";
|
||||
return nullptr;
|
||||
}
|
||||
return match_tensors.front();
|
||||
}
|
||||
|
||||
tensor::MSTensor *Benchmark::GetTensorByNameOrShape(const std::string &node_or_tensor_name,
|
||||
const std::vector<size_t> &dims) {
|
||||
tensor::MSTensor *tensor = nullptr;
|
||||
auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
|
||||
if (tensors.empty() || tensors.size() != 1) {
|
||||
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
|
||||
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
|
||||
tensor = session_->GetOutputByTensorName(node_or_tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
return GetTensorByNodeShape(dims);
|
||||
}
|
||||
} else {
|
||||
tensor = tensors.front();
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
int Benchmark::CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
|
||||
int *total_size) {
|
||||
float bias = 0;
|
||||
|
|
|
@ -60,10 +60,6 @@ class MS_API Benchmark : public BenchmarkBase {
|
|||
|
||||
int CompareOutput() override;
|
||||
|
||||
tensor::MSTensor *GetTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector<size_t> &dims);
|
||||
|
||||
tensor::MSTensor *GetTensorByNodeShape(const std::vector<size_t> &node_shape);
|
||||
|
||||
int CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
|
||||
int *total_size);
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ int BenchmarkUnifiedApi::ReadTensorData(std::ifstream &in_file_stream, const std
|
|||
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
|
||||
return RET_OK;
|
||||
}
|
||||
mindspore::MSTensor tensor = GetMSTensorByNameOrShape(tensor_name, dims);
|
||||
mindspore::MSTensor tensor = ms_model_.GetOutputByTensorName(tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
|
@ -178,10 +178,10 @@ int BenchmarkUnifiedApi::CompareOutput() {
|
|||
float total_bias = 0;
|
||||
int total_size = 0;
|
||||
for (const auto &calib_tensor : benchmark_data_) {
|
||||
std::string node_or_tensor_name = calib_tensor.first;
|
||||
mindspore::MSTensor tensor = GetMSTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
|
||||
std::string tensor_name = calib_tensor.first;
|
||||
mindspore::MSTensor tensor = ms_model_.GetOutputByTensorName(tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
|
||||
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int ret;
|
||||
|
@ -190,7 +190,7 @@ int BenchmarkUnifiedApi::CompareOutput() {
|
|||
MS_LOG(ERROR) << "Unsupported kObjectTypeString:";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
ret = CompareDataGetTotalBiasAndSize(node_or_tensor_name, &tensor, &total_bias, &total_size);
|
||||
ret = CompareDataGetTotalBiasAndSize(tensor_name, &tensor, &total_bias, &total_size);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Error in CompareData";
|
||||
|
@ -217,36 +217,6 @@ int BenchmarkUnifiedApi::CompareOutput() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
mindspore::MSTensor BenchmarkUnifiedApi::GetMSTensorByNodeShape(const std::vector<size_t> &node_shape) {
|
||||
std::vector<mindspore::MSTensor> match_tensors;
|
||||
std::vector<int64_t> shape_vector = ConverterToInt64Vector<size_t>(node_shape);
|
||||
auto tensors = ms_model_.GetOutputs();
|
||||
for (auto &out_tensor_pair : tensors) {
|
||||
if (out_tensor_pair.Shape() == shape_vector) {
|
||||
match_tensors.emplace_back(out_tensor_pair);
|
||||
}
|
||||
}
|
||||
|
||||
return match_tensors.front();
|
||||
}
|
||||
|
||||
mindspore::MSTensor BenchmarkUnifiedApi::GetMSTensorByNameOrShape(const std::string &node_or_tensor_name,
|
||||
const std::vector<size_t> &dims) {
|
||||
mindspore::MSTensor tensor;
|
||||
auto tensors = ms_model_.GetOutputsByNodeName(node_or_tensor_name);
|
||||
if (tensors.empty() || tensors.size() != 1) {
|
||||
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
|
||||
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
|
||||
tensor = ms_model_.GetOutputByTensorName(node_or_tensor_name);
|
||||
if (tensor == nullptr) {
|
||||
return GetMSTensorByNodeShape(dims);
|
||||
}
|
||||
} else {
|
||||
tensor = tensors.front();
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name, mindspore::MSTensor *tensor,
|
||||
float *total_bias, int *total_size) {
|
||||
float bias = 0;
|
||||
|
|
|
@ -52,8 +52,6 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
|
|||
int CompareDataGetTotalBiasAndSize(const std::string &name, mindspore::MSTensor *tensor, float *total_bias,
|
||||
int *total_size);
|
||||
void InitContext(const std::shared_ptr<mindspore::Context> &context);
|
||||
mindspore::MSTensor GetMSTensorByNodeShape(const std::vector<size_t> &node_shape);
|
||||
mindspore::MSTensor GetMSTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector<size_t> &dims);
|
||||
|
||||
// call GenerateRandomData to fill inputTensors
|
||||
int GenerateInputData() override;
|
||||
|
|
Loading…
Reference in New Issue