!18631 [MS][LITE] fix a bug in tensor format transform
Merge pull request !18631 from XianglongZeng/myms_new_3
This commit is contained in:
commit
b53a114e83
|
@ -84,3 +84,4 @@ ml_video_edit_art_generate_20210513.onnx
|
||||||
ml_asr_encoder_int8_202103.onnx
|
ml_asr_encoder_int8_202103.onnx
|
||||||
rpnt_pdr_conv2d_16_fixed_last.onnx
|
rpnt_pdr_conv2d_16_fixed_last.onnx
|
||||||
hdc_efficientnet_b3_1w_class.onnx
|
hdc_efficientnet_b3_1w_class.onnx
|
||||||
|
yolov5s.onnx
|
||||||
|
|
|
@ -93,3 +93,4 @@ ml_asr_encoder_int8_202103.onnx 2.1
|
||||||
# The input range of hdc_efficientnet_b3_1w_class.onnx is [-5, 5], the computation of middle layers contains small
|
# The input range of hdc_efficientnet_b3_1w_class.onnx is [-5, 5], the computation of middle layers contains small
|
||||||
# values(<1e-5), The fp16 computation precision is low in this case.
|
# values(<1e-5), The fp16 computation precision is low in this case.
|
||||||
hdc_efficientnet_b3_1w_class.onnx 18
|
hdc_efficientnet_b3_1w_class.onnx 18
|
||||||
|
yolov5s.onnx 2
|
||||||
|
|
|
@ -150,6 +150,10 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
|
||||||
lite::DataInfo data_info;
|
lite::DataInfo data_info;
|
||||||
int status;
|
int status;
|
||||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||||
|
auto input_node = cnode->input(index)->cast<ParameterPtr>();
|
||||||
|
if (!input_node->has_default()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
|
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
|
||||||
} else {
|
} else {
|
||||||
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
|
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
|
||||||
|
@ -161,7 +165,7 @@ void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, s
|
||||||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
|
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<int> new_shape;
|
std::vector<int> new_shape = data_info.shape_;
|
||||||
if (data_info.shape_.size() == 1) {
|
if (data_info.shape_.size() == 1) {
|
||||||
new_shape = {1, 1, 1, data_info.shape_[0]};
|
new_shape = {1, 1, 1, data_info.shape_[0]};
|
||||||
} else if (data_info.shape_.size() == 2) {
|
} else if (data_info.shape_.size() == 2) {
|
||||||
|
|
Loading…
Reference in New Issue