[MSLITE] Fix bug of TF model's output names.

wang_shaocong 2021-08-18 11:30:57 +08:00
parent db19c60581
commit f0828412bb
12 changed files with 63 additions and 134 deletions

View File

@@ -167,8 +167,7 @@ int NPUInsertTransformPass::InsertNode(NPUOp *op, NPUOp *post_op, size_t post_in
} else {
// post_op nullptr mean output, we remain graph output tensor name unchanged
auto graph_output_name = in_tensor.Name();
in_tensor.SetTensorName(graph_output_name + "_before_" + name_);
nc2nh_tensor->SetTensorName(graph_output_name);
nc2nh_tensor->SetTensorName(graph_output_name + "_after_" + name_);
}
return RET_OK;
}
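
The hunk above is the core of the naming fix: when an NCHW-to-NHWC transpose is inserted behind an op whose output tensor is also a graph output, the original tensor now keeps the graph-output name and the newly created tensor gets a derived name, rather than the other way around. A minimal sketch of that naming policy, using hypothetical stand-in types instead of the real NPU pass and tensor classes:

#include <string>

struct StubTensor {  // stand-in for the lite tensor wrapper used by the pass
  std::string name;
  const std::string &Name() const { return name; }
  void SetTensorName(const std::string &n) { name = n; }
};

// Keep the graph-output name on the original tensor and give the inserted
// NC2NH output a unique derived name, so lookups by output tensor name
// (e.g. in the benchmark changes below) still resolve after the pass runs.
void RenameForInsertedTranspose(StubTensor *in_tensor, StubTensor *nc2nh_tensor,
                                const std::string &pass_name) {
  const std::string graph_output_name = in_tensor->Name();
  nc2nh_tensor->SetTensorName(graph_output_name + "_after_" + pass_name);
}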

View File

@@ -16,5 +16,5 @@ deconvs_model
# onnx
ml_2012_ocr_cn.onnx
# aware_training
Modify_Out_video_infer2.tflite;;;;aware_training
video_infer2.tflite;;;;aware_training
mobilenet_v1_1.0_224_quant.tflite;;;;aware_training

View File

@@ -1,6 +1,6 @@
mobilenet_v1_1.0_224.tflite
mobilenet_v2_1.0_224.tflite
Modify_Out_mtk_age_gender_fp16.tflite
mtk_age_gender_fp16.tflite
mtk_isface.tflite
mtk_landmark.tflite
mtk_new_detect.tflite
@@ -25,7 +25,7 @@ Q_face_recognition.onnx
Q888_iris_detect.onnx
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
Q_detect_fpn_add_inception-1448650.tflite
Q888_face_dress_mv3y.tflite
Q888_HADB_AADB_MBV2_model_fp32.tflite
Q888_landmark.tflite
@@ -58,4 +58,4 @@ mtk_detect_mbv1_640_480_nopostprocess_simplified
mtk_model_normalize_object_scene_ps_20200519_f16.tflite
mtk_AADB_HADB_MBV2_model_f16.tflite
mtk_model_emotions_0725_fp16.tflite
Modify_Out_Q888_age_gender_orderd.tflite
Q888_age_gender_orderd.tflite

View File

@@ -6,10 +6,10 @@ resnet.tflite
squeezenet.tflite
mtk_AADB_HADB_MBV2_model_fp32.tflite
hiai_cn_recognize_modify_padv2.tflite
Modify_Out_hiai_cv_focusShootOCRModel_08.tflite
hiai_cv_focusShootOCRModel_08.tflite
hiai_model_normalize_object_scene_ps_20200519.tflite
inception_v3.tflite
Modify_Out_mtk_age_gender_fp16.tflite
mtk_age_gender_fp16.tflite
mtk_isface.tflite
mtk_landmark.tflite
mtk_new_detect.tflite
@@ -102,7 +102,7 @@ Q_crnn_ori_v2_405001_notrans_nopre.pb
bolt_segment.pb
ml_location_lane_counter.onnx 2
gts_detect_5k_tf115.tflite
Modify_Out_smartreply.tflite
smartreply.tflite
ml_text_correction.tflite
ml_ocr_jk_pb2tflite.tflite
scan_hms_angle_pb2tflite.tflite
@@ -111,21 +111,21 @@ ml_face_openclose_tflite.tflite
unet_mbv2_05_104pts.tflite
hiai_AADB_HADB_MBV2_model_f16.tflite
hiai_AADB_HADB_MBV2_model_fp32.tflite
Modify_Out_hiai_detect_curve_model_float32.tflite
Modify_Out_hiai_detectmodel_06_23_960_480_1180700.tflite
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite
hiai_detect_curve_model_float32.tflite
hiai_detectmodel_06_23_960_480_1180700.tflite
lite-model_aiy_vision_classifier_food_V1_1.tflite
lite-model_disease-classification_1.tflite
lite-model_models_mushroom-identification_v1_1.tflite
Modify_Out_smartreply_1_default_1.tflite
smartreply_1_default_1.tflite
text_classification.tflite
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
Modify_Out_Q_hand_0812_pb2tflite.tflite
Q_detect_fpn_add_inception-1448650.tflite
Q_hand_0812_pb2tflite.tflite
bloom_landmark.tflite
Q888_face_dress_mv3y.tflite
Q888_HADB_AADB_MBV2_model_fp32.tflite
Q888_landmark.tflite
Q888_pose.tflite
Modify_Out_Q888_lapa158_unet_0924.tflite
Q888_lapa158_unet_0924.tflite
Q888_isface.tflite
Q888_new_detect.tflite
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
@@ -156,7 +156,7 @@ hiai_model_0909_kd_rot_ps_softmax.tflite
hiai_chinese_english_recognize_model_float32.tflite
hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
Modify_Out_hiai_detectmodel_desnet_256_128_64_32.tflite
hiai_detectmodel_desnet_256_128_64_32.tflite
mtk_AADB_HADB_MBV3_model_fp32.tflite
Q888_face_recognition.onnx
mobilenet_v1_0.25_128.tflite
@@ -201,7 +201,7 @@ mtk_AADB_HADB_MBV2_model_f16.tflite
mtk_AADB_HADB_MBV3_model_f16.tflite
mtk_model_emotions_0725_fp16.tflite
mtk_face_features_v1_fp16.tflite
Modify_Out_Q888_age_gender_orderd.tflite
Q888_age_gender_orderd.tflite
emotion
gender_res_large_deploy
glasses

View File

@@ -38,7 +38,7 @@ multi_person_mobilenet_v1_075_float.tflite
ide_label_base.tflite
ide_label_retrained.tflite
ml_ei_headpose.tflite
Modify_Out_ml_ei_landmark.tflite
ml_ei_landmark.tflite
mnist.tflite
mobilenet.tflite
resnet.tflite
@@ -58,7 +58,7 @@ hiai_cv_focusShootOCRModel_02.tflite
hiai_cv_poseEstimation.tflite
inception_v4.tflite
mtk_model_normalize_object_scene_ps_20200519_f16.tflite
Modify_Out_mtk_age_gender_fp16.tflite
mtk_age_gender_fp16.tflite
mtk_model_face_dress_fp16.tflite
mtk_AADB_HADB_MBV2_model_f16.tflite
mtk_AADB_HADB_MBV3_model_f16.tflite
@@ -77,7 +77,7 @@ hiai_cpu_face_emotion.tflite
hiai_cpu_face_gazing.tflite
hiai_cpu_face_headpose.tflite
hiai_humanDetection.tflite
Modify_Out_hiai_cv_focusShootOCRModel_08.tflite
hiai_cv_focusShootOCRModel_08.tflite
ml_face_openclose.tflite
hiai_face_model_npu.tflite
hiai_ctpn_feature_map.tflite
@@ -122,13 +122,13 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
mtk_276landmark_0913.tflite
mtk_face_recognition.tflite
mtk_convert_model.tflite
Modify_Out_smartreply.tflite
smartreply.tflite
mindspore_text_classification_tflite.tflite
# ml_location.tflite
ml_text_correction.tflite
ml_pic_shopping.tflite
ml_vision_guide_detection3_pb2tflite.tflite
Modify_Out_ml_vision_guide_detection1_pb2tflite.tflite
ml_vision_guide_detection1_pb2tflite.tflite
ml_pic_shopping_pb2tflite.tflite
ml_ocr_jk_pb2tflite.tflite
ml_ocr_latin_pb2tflite.tflite
@@ -152,27 +152,27 @@ Q_language_model_hrmini_Q4_b4_17w.tflite
Q_new_detect.tflite
Q_object_scene.tflite
Q_pose.tflite
Modify_Out_ml_ei_landmark_pb2tflite.tflite
ml_ei_landmark_pb2tflite.tflite
unet_mbv2_05_104pts.tflite
hiai_AADB_HADB_MBV2_model_f16.tflite
hiai_AADB_HADB_MBV2_model_fp32.tflite
Modify_Out_hiai_detect_curve_model_float32.tflite
Modify_Out_hiai_detectmodel_06_23_960_480_1180700.tflite
Modify_Out_hiai_detectmodel_desnet_256_128_64_32.tflite
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite
hiai_detect_curve_model_float32.tflite
hiai_detectmodel_06_23_960_480_1180700.tflite
hiai_detectmodel_desnet_256_128_64_32.tflite
lite-model_aiy_vision_classifier_food_V1_1.tflite
lite-model_disease-classification_1.tflite
lite-model_models_mushroom-identification_v1_1.tflite
Modify_Out_smartreply_1_default_1.tflite
smartreply_1_default_1.tflite
text_classification.tflite
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite
Modify_Out_Q_hand_0812_pb2tflite.tflite
Q_detect_fpn_add_inception-1448650.tflite
Q_hand_0812_pb2tflite.tflite
bloom_landmark.tflite
Modify_Out_Q888_age_gender_orderd.tflite
Q888_age_gender_orderd.tflite
Q888_face_dress_mv3y.tflite
Q888_HADB_AADB_MBV2_model_fp32.tflite
Q888_landmark.tflite
Q888_pose.tflite
Modify_Out_Q888_lapa158_unet_0924.tflite
Q888_lapa158_unet_0924.tflite
Q888_isface.tflite
Q888_new_detect.tflite
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
@@ -180,7 +180,7 @@ Q888_face_emo_dress_mv3_orderd.tflite
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
Q_iMaxSR_RGB_385_p_pb2tflite.tflite
bloom_new_detect.tflite
Modify_Out_bloom_model_age_gender.tflite
bloom_model_age_gender.tflite
bloom_isface.tflite
hiai_object_detect_814.tflite
hiai_object_tflite_graph_8bit.tflite
@@ -193,10 +193,10 @@ ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2:2,1
ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2:2,1
hdc_tb_cn_neg.tflite;3:3,1,2 0.5
hiai_cv_labelDetectorModel_v3.tflite;2:2,1
Modify_Out_ml_tacotron_decoder_step_stf.tflite;9;1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1
ml_tacotron_decoder_step_stf.tflite;9;1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1
ml_headpose_pb2tflite.tflite;3:2,3,1;1,64,64,3:16:16
ml_ei_headpose_pb2tflite.tflite;3:2,3,1;1,64,64,3:16:16
lite-model_albert_lite_base_squadv1_metadata_1.tflite;3:2,3,1
lite-model_mobilebert_1_metadata_1.tflite;3
Modify_Out_hiai_vad.tflite;2
hiai_vad.tflite;2
add_uint8.tflite;2

View File

@@ -1,4 +1,4 @@
Modify_Out_video_infer2.tflite
video_infer2.tflite
mobilenet_v1_0.25_128_quant.tflite
mobilenet_v1_0.25_160_quant.tflite
mobilenet_v1_0.25_192_quant.tflite

View File

@@ -49,7 +49,7 @@ ide_label_base.tflite 22
# dividing 0 in the following operator.
#ide_label_retrained.tflite
ml_ei_headpose.tflite 3
Modify_Out_ml_ei_landmark.tflite 3
ml_ei_landmark.tflite 3
mnist.tflite 4
mobilenet.tflite 0.1
resnet.tflite 120
@@ -131,7 +131,7 @@ mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22
mtk_276landmark_0913.tflite 16
mtk_face_recognition.tflite 8
mtk_convert_model.tflite 5
Modify_Out_smartreply.tflite 0.1
smartreply.tflite 0.1
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
#ml_location.tflite 0.1
ml_text_correction.tflite 1
@@ -141,7 +141,7 @@ ml_text_correction.tflite 1
# fp16: 27.6 - 27.4 = 0.2
#ml_pic_shopping.tflite 0.1
ml_vision_guide_detection3_pb2tflite.tflite 0.5
Modify_Out_ml_vision_guide_detection1_pb2tflite.tflite 0.5
ml_vision_guide_detection1_pb2tflite.tflite 0.5
ml_pic_shopping_pb2tflite.tflite 95
ml_ocr_jk_pb2tflite.tflite 0.5
ml_ocr_latin_pb2tflite.tflite 11.5
@@ -158,17 +158,17 @@ lite-model_on_device_vision_classifier_landmarks_classifier_asia_V1_1.tflite 25
lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V1_1.tflite 11
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite 32
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite 14
Modify_Out_ml_ei_landmark_pb2tflite.tflite 2
ml_ei_landmark_pb2tflite.tflite 2
unet_mbv2_05_104pts.tflite 17
hiai_AADB_HADB_MBV2_model_f16.tflite 3.5
hiai_AADB_HADB_MBV2_model_fp32.tflite 4.5
Modify_Out_mtk_age_gender_fp16.tflite 26
Modify_Out_hiai_detect_curve_model_float32.tflite 9
mtk_age_gender_fp16.tflite 26
hiai_detect_curve_model_float32.tflite 9
Q_language_model_hrmini_Q4_b4_17w.tflite 3.5
Modify_Out_lite-model_aiy_vision_classifier_food_V1_1.tflite 47.5
lite-model_aiy_vision_classifier_food_V1_1.tflite 47.5
lite-model_disease-classification_1.tflite 70
lite-model_models_mushroom-identification_v1_1.tflite 5
Modify_Out_smartreply_1_default_1.tflite 0.5
smartreply_1_default_1.tflite 0.5
text_classification.tflite 0.5
Q_AADB_HADB_MBV2_model.tflite 5
# the input of Q_convert model is between 0-255
@@ -190,16 +190,16 @@ Q_new_detect.tflite 3.5
# the input of Q_object_scene model is between 0-255
Q_object_scene.tflite 3
Q_pose.tflite 4.1
Modify_Out_Q_detect_fpn_add_inception-1448650.tflite 1
Q_detect_fpn_add_inception-1448650.tflite 1
bloom_landmark.tflite 0.5
# input data: 0~255
Modify_Out_Q888_age_gender_orderd.tflite 1.5
Q888_age_gender_orderd.tflite 1.5
Q888_face_dress_mv3y.tflite 0.5
Q888_HADB_AADB_MBV2_model_fp32.tflite 2.5
Q888_landmark.tflite 0.5
Q888_pose.tflite 6.1
# the output contains value less than e-7
Modify_Out_Q888_lapa158_unet_0924.tflite 19
Q888_lapa158_unet_0924.tflite 19
Q888_isface.tflite 1.0
Q888_new_detect.tflite 1.5
Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
@@ -208,7 +208,7 @@ Q888_face_emo_dress_mv3_orderd.tflite 2.5
Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
bloom_new_detect.tflite 3.5
Modify_Out_bloom_model_age_gender.tflite 0.5
bloom_model_age_gender.tflite 0.5
bloom_isface.tflite 0.5
# The output values of conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16. In
# this range, the fp16 data will has big bias. And the accumulation of this bias lowers the final precision.

View File

@@ -1,6 +1,6 @@
Stack-8 5 1 1 8 4 2
output 5 1 1 8 4 2
0.115831 0.11307496 0.24593274 0.34630755 -0.156871 0.21111916 -0.1046219 0.01590158 0.2745127 0.17317073 0.1787783 0.36557162 -0.13658395 0.2911819 -0.17356569 0.06825469 0.30655888 0.29681587 0.0078597255 0.3846875 -0.09266291 0.26170188 -0.15063931 0.04322962 0.25661856 0.25256 0.023097975 0.32573196 -0.043139715 0.25530565 -0.17270242 0.06442319 0.16240332 0.14648464 0.09654196 0.31037596 -0.0539147 0.23819281 -0.15090092 0.048991375 0.11573871 0.078725 0.19393174 0.26017824 -0.053352155 0.23836473 -0.15971972 0.054956935 0.19800682 0.17823274 0.17631978 0.3600948 -0.057391744 0.30457845 -0.19889072 0.05244953 0.090213075 0.17350613 0.044377614 0.29630166 -0.06999667 0.28462386 -0.17194743 0.093742274
Stack-10 5 1 1 8 4 1
output2 5 1 1 8 4 1
0.06387864 0.22883008 0.23308714 0.045865785 0.06820235 0.26621705 0.29714558 0.112830795 0.1669129 0.33512616 0.25788227 0.08388044 0.14331667 0.27875048 0.23716372 0.10920572 0.07898582 0.24287388 0.22543576 0.08901558 0.03376824 0.16912283 0.225415 0.09693983 0.09598104 0.26216167 0.28474298 0.10668853 0.12471523 0.24643728 0.27107987 0.13469991
Stack-13 3 1 8 4
output3 3 1 8 4
-0.16171767 -0.3828573 0.08357508 0.10217983 -0.34800848 -0.3206381 0.03284559 0.15394436 -0.42709222 -0.15115751 -0.0015709695 0.13956246 -0.35903975 -0.14498001 -0.050358675 0.15447712 -0.22225751 -0.21515054 -0.03286325 0.13769037 -0.1488501 -0.29710612 -0.033508375 0.14458355 -0.27084687 -0.31606156 -0.053954814 0.18598628 -0.15771987 -0.15602258 -0.0335121 0.14279547
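
The calibration file above stores each output as a header line followed by a line of flattened values; the change only swaps the old TF node names (Stack-8, Stack-10, Stack-13) for the exported output tensor names (output, output2, output3), so the entries match what GetOutputByTensorName returns. A small self-contained parser for that assumed layout ("name dim_count dim0 dim1 ..." then the values), with a hypothetical file name:

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::ifstream in("benchmark_calib.out");  // hypothetical calibration file path
  std::string header;
  while (std::getline(in, header)) {
    std::istringstream hs(header);
    std::string tensor_name;
    size_t dim_count = 0;
    hs >> tensor_name >> dim_count;
    std::vector<int> dims(dim_count);
    for (auto &d : dims) hs >> d;  // e.g. "output 5 1 1 8 4 2" -> dims {1, 1, 8, 4, 2}
    std::string data_line;
    if (!std::getline(in, data_line)) break;
    std::istringstream ds(data_line);
    std::vector<float> values;
    for (float v = 0.0f; ds >> v;) values.push_back(v);
    std::cout << tensor_name << ": " << values.size() << " values\n";
  }
  return 0;
}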

View File

@@ -20,6 +20,7 @@
#undef __STDC_FORMAT_MACROS
#include <utility>
#include <functional>
#include <algorithm>
#include "include/context.h"
#include "include/ms_tensor.h"
#include "include/version.h"
@@ -115,7 +116,7 @@ int Benchmark::ReadTensorData(std::ifstream &in_file_stream, const std::string &
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
return RET_OK;
}
tensor::MSTensor *tensor = GetTensorByNameOrShape(tensor_name, dims);
tensor::MSTensor *tensor = session_->GetOutputByTensorName(tensor_name);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
return RET_ERROR;
@@ -175,17 +176,17 @@ int Benchmark::CompareOutput() {
float total_bias = 0;
int total_size = 0;
for (const auto &calib_tensor : benchmark_data_) {
std::string node_or_tensor_name = calib_tensor.first;
tensor::MSTensor *tensor = GetTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
std::string tensor_name = calib_tensor.first;
tensor::MSTensor *tensor = session_->GetOutputByTensorName(tensor_name);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
return RET_ERROR;
}
int ret;
if (tensor->data_type() == kObjectTypeString) {
ret = CompareStringData(node_or_tensor_name, tensor);
ret = CompareStringData(tensor_name, tensor);
} else {
ret = CompareDataGetTotalBiasAndSize(node_or_tensor_name, tensor, &total_bias, &total_size);
ret = CompareDataGetTotalBiasAndSize(tensor_name, tensor, &total_bias, &total_size);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Error in CompareData";
@@ -212,41 +213,6 @@ int Benchmark::CompareOutput() {
return RET_OK;
}
tensor::MSTensor *Benchmark::GetTensorByNodeShape(const std::vector<size_t> &node_shape) {
std::vector<tensor::MSTensor *> match_tensors;
std::vector<int> shape_vector;
(void)std::transform(node_shape.begin(), node_shape.end(), std::back_inserter(shape_vector),
[](const size_t &value) { return static_cast<int>(value); });
auto tensors = session_->GetOutputs();
for (auto &out_tensor_pair : tensors) {
if (out_tensor_pair.second->shape() == shape_vector) {
match_tensors.emplace_back(out_tensor_pair.second);
}
}
if (match_tensors.empty() || match_tensors.size() != 1) {
MS_LOG(ERROR) << "get tensor by node shape failed";
return nullptr;
}
return match_tensors.front();
}
tensor::MSTensor *Benchmark::GetTensorByNameOrShape(const std::string &node_or_tensor_name,
const std::vector<size_t> &dims) {
tensor::MSTensor *tensor = nullptr;
auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
if (tensors.empty() || tensors.size() != 1) {
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
tensor = session_->GetOutputByTensorName(node_or_tensor_name);
if (tensor == nullptr) {
return GetTensorByNodeShape(dims);
}
} else {
tensor = tensors.front();
}
return tensor;
}
int Benchmark::CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
int *total_size) {
float bias = 0;
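
With the shape-based fallback (GetTensorByNameOrShape / GetTensorByNodeShape) deleted, every calibration entry is now resolved directly with GetOutputByTensorName, which only works because the converter and the NPU pass above preserve the original output tensor names. A hedged end-to-end sketch against the 1.x LiteSession interface (header paths, the model path, and the output name "output" are assumptions; input filling is omitted):

#include <fstream>
#include <iostream>
#include <vector>
#include "include/context.h"
#include "include/lite_session.h"
#include "include/model.h"

int main() {
  // Hypothetical converted-model path; in practice this comes from the converter.
  std::ifstream ifs("model.ms", std::ios::binary | std::ios::ate);
  if (!ifs) return 1;
  std::vector<char> buf(static_cast<size_t>(ifs.tellg()));
  ifs.seekg(0, std::ios::beg);
  ifs.read(buf.data(), buf.size());

  auto *model = mindspore::lite::Model::Import(buf.data(), buf.size());
  mindspore::lite::Context context;
  auto *session = mindspore::session::LiteSession::CreateSession(&context);
  if (model == nullptr || session == nullptr || session->CompileGraph(model) != 0) return 1;
  // ... fill session->GetInputs() with real data, then:
  if (session->RunGraph() != 0) return 1;

  // With the fix, the TF graph's original output name resolves directly.
  auto *out = session->GetOutputByTensorName("output");
  std::cout << (out != nullptr ? "output tensor found" : "output tensor missing") << std::endl;
  delete session;
  delete model;
  return 0;
}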

View File

@@ -60,10 +60,6 @@ class MS_API Benchmark : public BenchmarkBase {
int CompareOutput() override;
tensor::MSTensor *GetTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector<size_t> &dims);
tensor::MSTensor *GetTensorByNodeShape(const std::vector<size_t> &node_shape);
int CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
int *total_size);

View File

@@ -120,7 +120,7 @@ int BenchmarkUnifiedApi::ReadTensorData(std::ifstream &in_file_stream, const std
if (this->benchmark_data_.find(tensor_name) != this->benchmark_data_.end()) {
return RET_OK;
}
mindspore::MSTensor tensor = GetMSTensorByNameOrShape(tensor_name, dims);
mindspore::MSTensor tensor = ms_model_.GetOutputByTensorName(tensor_name);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
return RET_ERROR;
@@ -178,10 +178,10 @@ int BenchmarkUnifiedApi::CompareOutput() {
float total_bias = 0;
int total_size = 0;
for (const auto &calib_tensor : benchmark_data_) {
std::string node_or_tensor_name = calib_tensor.first;
mindspore::MSTensor tensor = GetMSTensorByNameOrShape(node_or_tensor_name, calib_tensor.second->shape);
std::string tensor_name = calib_tensor.first;
mindspore::MSTensor tensor = ms_model_.GetOutputByTensorName(tensor_name);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << node_or_tensor_name;
MS_LOG(ERROR) << "Get tensor failed, tensor name: " << tensor_name;
return RET_ERROR;
}
int ret;
@@ -190,7 +190,7 @@
MS_LOG(ERROR) << "Unsupported kObjectTypeString:";
return RET_ERROR;
} else {
ret = CompareDataGetTotalBiasAndSize(node_or_tensor_name, &tensor, &total_bias, &total_size);
ret = CompareDataGetTotalBiasAndSize(tensor_name, &tensor, &total_bias, &total_size);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Error in CompareData";
@@ -217,36 +217,6 @@ int BenchmarkUnifiedApi::CompareOutput() {
return RET_OK;
}
mindspore::MSTensor BenchmarkUnifiedApi::GetMSTensorByNodeShape(const std::vector<size_t> &node_shape) {
std::vector<mindspore::MSTensor> match_tensors;
std::vector<int64_t> shape_vector = ConverterToInt64Vector<size_t>(node_shape);
auto tensors = ms_model_.GetOutputs();
for (auto &out_tensor_pair : tensors) {
if (out_tensor_pair.Shape() == shape_vector) {
match_tensors.emplace_back(out_tensor_pair);
}
}
return match_tensors.front();
}
mindspore::MSTensor BenchmarkUnifiedApi::GetMSTensorByNameOrShape(const std::string &node_or_tensor_name,
const std::vector<size_t> &dims) {
mindspore::MSTensor tensor;
auto tensors = ms_model_.GetOutputsByNodeName(node_or_tensor_name);
if (tensors.empty() || tensors.size() != 1) {
MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
<< " or node has more than one output tensor, switch to GetOutputByTensorName";
tensor = ms_model_.GetOutputByTensorName(node_or_tensor_name);
if (tensor == nullptr) {
return GetMSTensorByNodeShape(dims);
}
} else {
tensor = tensors.front();
}
return tensor;
}
int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name, mindspore::MSTensor *tensor,
float *total_bias, int *total_size) {
float bias = 0;
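
The unified (new C++) API side mirrors the same simplification: the calibration name is passed straight to Model::GetOutputByTensorName, and the returned MSTensor value is checked against nullptr, as in the hunk above. A brief hedged sketch, assuming the unified-API headers and a model buffer already loaded; Build arguments are abbreviated and "output" is the placeholder name from the calibration file above:

#include <memory>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/types.h"

// Returns 0 if the named output tensor can be resolved after Build (and, in a
// real run, Predict).
int LookupOutput(const void *model_data, size_t data_size) {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  mindspore::Model model;
  if (model.Build(model_data, data_size, mindspore::kMindIR, context) != mindspore::kSuccess) {
    return -1;
  }
  // ... fill model.GetInputs() and call model.Predict(...) before reading outputs ...
  mindspore::MSTensor tensor = model.GetOutputByTensorName("output");
  return tensor == nullptr ? -1 : 0;
}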

View File

@@ -52,8 +52,6 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
int CompareDataGetTotalBiasAndSize(const std::string &name, mindspore::MSTensor *tensor, float *total_bias,
int *total_size);
void InitContext(const std::shared_ptr<mindspore::Context> &context);
mindspore::MSTensor GetMSTensorByNodeShape(const std::vector<size_t> &node_shape);
mindspore::MSTensor GetMSTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector<size_t> &dims);
// call GenerateRandomData to fill inputTensors
int GenerateInputData() override;