!18631 [MS][LITE] fix a bug in tensor format transform

Merge pull request !18631 from XianglongZeng/myms_new_3
This commit is contained in:
i-robot 2021-06-22 03:02:08 +00:00 committed by Gitee
commit b53a114e83
3 changed files with 7 additions and 1 deletions

View File

@ -84,3 +84,4 @@ ml_video_edit_art_generate_20210513.onnx
ml_asr_encoder_int8_202103.onnx ml_asr_encoder_int8_202103.onnx
rpnt_pdr_conv2d_16_fixed_last.onnx rpnt_pdr_conv2d_16_fixed_last.onnx
hdc_efficientnet_b3_1w_class.onnx hdc_efficientnet_b3_1w_class.onnx
yolov5s.onnx

View File

@ -93,3 +93,4 @@ ml_asr_encoder_int8_202103.onnx 2.1
# The input range of hdc_efficientnet_b3_1w_class.onnx is [-5, 5], the computation of middle layers contains small # The input range of hdc_efficientnet_b3_1w_class.onnx is [-5, 5], the computation of middle layers contains small
# values(<1e-5), The fp16 computation precision is low in this case. # values(<1e-5), The fp16 computation precision is low in this case.
hdc_efficientnet_b3_1w_class.onnx 18 hdc_efficientnet_b3_1w_class.onnx 18
yolov5s.onnx 2

View File

@ -150,6 +150,10 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
lite::DataInfo data_info; lite::DataInfo data_info;
int status; int status;
if (utils::isa<ParameterPtr>(cnode->input(index))) { if (utils::isa<ParameterPtr>(cnode->input(index))) {
auto input_node = cnode->input(index)->cast<ParameterPtr>();
if (!input_node->has_default()) {
return;
}
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info); status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
} else { } else {
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info); status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
@ -161,7 +165,7 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) { (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
return; return;
} }
std::vector<int> new_shape; std::vector<int> new_shape = data_info.shape_;
if (data_info.shape_.size() == 1) { if (data_info.shape_.size() == 1) {
new_shape = {1, 1, 1, data_info.shape_[0]}; new_shape = {1, 1, 1, data_info.shape_[0]};
} else if (data_info.shape_.size() == 2) { } else if (data_info.shape_.size() == 2) {