From 3d0039b153943cb744ded1ab5c11d203a4cc93b8 Mon Sep 17 00:00:00 2001 From: lyvette Date: Thu, 13 Aug 2020 21:10:41 +0800 Subject: [PATCH] finxed weight tensor format bug and gather parser bug. --- .../parser/tflite/tflite_arithmetic_parser.cc | 4 ++-- .../parser/tflite/tflite_conv_parser.cc | 4 ++-- .../parser/tflite/tflite_deconv_parser.cc | 2 +- .../tflite/tflite_depthwise_conv_parser.cc | 4 ++-- .../parser/tflite/tflite_fakequant_parser.cc | 4 ++-- .../tflite/tflite_fullyconnected_parser.cc | 4 ++-- .../parser/tflite/tflite_gather_nd_parser.cc | 2 +- .../parser/tflite/tflite_gather_parser.cc | 19 +++++++++++++++++++ .../parser/tflite/tflite_node_parser.cc | 9 ++++++++- .../parser/tflite/tflite_node_parser.h | 3 ++- .../parser/tflite/tflite_transpose_parser.cc | 2 +- 11 files changed, 42 insertions(+), 15 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 8a5a16ccf77..99eeb7f860b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -57,7 +57,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } if (!x_data->data.empty()) { std::vector x_tensors{x_tensor.get()}; - if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the first tensor failed"; return RET_ERROR; } @@ -76,7 +76,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } if (!y_data->data.empty()) { std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the second tensor failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 3a60f9281ad..f075f10e850 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -59,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -79,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 5856077cb7e..61b2e3baf6b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -59,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { return RET_ERROR; } auto weight_shape = weight_tensor->shape; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 96215944d1a..d3187aefc01 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -123,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -133,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrinputs[2]; const auto &bias_tensor = tflite_tensors[bias_index]; std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc index 58ca290f364..fa0e90ae11a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc @@ -44,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -58,7 +58,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 9a6341b148c..403dcbbee6a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -45,7 +45,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -59,7 +59,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index 044c3f53718..a954baef012 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -58,7 +58,7 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfl } if (!y_data->data.empty()) { std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the second tensor failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index c95e3782160..e49f95d5116 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -49,6 +49,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflit attr->batchDims = 0; + auto y_index = tfliteOp->inputs[1]; + const auto &y_tensor = tfliteTensors[y_index]; + if (y_tensor == nullptr) { + MS_LOG(ERROR) << "the second input is null"; + return RET_NULL_PTR; + } + auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); + if (y_data == nullptr) { + MS_LOG(ERROR) << "the data of the second input is null"; + return RET_NULL_PTR; + } + if (!y_data->data.empty()) { + std::vector y_tensors{y_tensor.get()}; + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { + MS_LOG(ERROR) << "parse the second tensor failed"; + return RET_ERROR; + } + } + op->primitive->value.type = schema::PrimitiveType_Gather; op->primitive->value.value = attr.release(); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc index 45f3cdd2751..cd24097896f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc @@ -45,7 +45,8 @@ STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector &ts, const std::vector> &tfliteModelBuffer, mindspore::lite::TensorCache *tensor_cache, - int node_type) { + int node_type, + bool isWeight) { for (const auto &t : ts) { auto idx = tensor_cache->FindTensor(t->name); if (idx < 0) { @@ -53,6 +54,12 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector &ts, tensor->dataType = GetTfliteDataType(t->type); tensor->dims = t->shape; + if (isWeight) { + tensor->format = schema::Format_KHWC; + } else { + tensor->format = schema::Format_NHWC; + } + if (t->buffer > 0) { CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 94ae5f8c55c..3a3828e0cd7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -47,7 +47,8 @@ class TfliteNodeParser { STATUS ParseTensor(const std::vector &ts, const std::vector> &tfliteModelBuffer, mindspore::lite::TensorCache *tensor_cache, - int node_type); + int node_type, + bool isWeight); STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, const tflite::TensorT *tflite_tensor, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index cf11963fba5..5f3d90b8898 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -50,7 +50,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf return RET_ERROR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; }