diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 8a5a16ccf7..99eeb7f860 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -57,7 +57,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } if (!x_data->data.empty()) { std::vector x_tensors{x_tensor.get()}; - if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the first tensor failed"; return RET_ERROR; } @@ -76,7 +76,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } if (!y_data->data.empty()) { std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the second tensor failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 3a60f9281a..f075f10e85 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -59,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -79,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 5856077cb7..61b2e3baf6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -59,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { return RET_ERROR; } auto weight_shape = weight_tensor->shape; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 96215944d1..d3187aefc0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -123,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -133,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrinputs[2]; const auto &bias_tensor = tflite_tensors[bias_index]; std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc index 58ca290f36..fa0e90ae11 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc @@ -44,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -58,7 +58,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 7f039d0367..0084ad8c7e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -52,7 +52,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; } @@ -66,7 +66,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse bias failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index 044c3f5371..a954baef01 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -58,7 +58,7 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfl } if (!y_data->data.empty()) { std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse the second tensor failed"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index c95e378216..e49f95d511 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -49,6 +49,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflit attr->batchDims = 0; + auto y_index = tfliteOp->inputs[1]; + const auto &y_tensor = tfliteTensors[y_index]; + if (y_tensor == nullptr) { + MS_LOG(ERROR) << "the second input is null"; + return RET_NULL_PTR; + } + auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); + if (y_data == nullptr) { + MS_LOG(ERROR) << "the data of the second input is null"; + return RET_NULL_PTR; + } + if (!y_data->data.empty()) { + std::vector y_tensors{y_tensor.get()}; + if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { + MS_LOG(ERROR) << "parse the second tensor failed"; + return RET_ERROR; + } + } + op->primitive->value.type = schema::PrimitiveType_Gather; op->primitive->value.value = attr.release(); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc index 45f3cdd275..cd24097896 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc @@ -45,7 +45,8 @@ STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector &ts, const std::vector> &tfliteModelBuffer, mindspore::lite::TensorCache *tensor_cache, - int node_type) { + int node_type, + bool isWeight) { for (const auto &t : ts) { auto idx = tensor_cache->FindTensor(t->name); if (idx < 0) { @@ -53,6 +54,12 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector &ts, tensor->dataType = GetTfliteDataType(t->type); tensor->dims = t->shape; + if (isWeight) { + tensor->format = schema::Format_KHWC; + } else { + tensor->format = schema::Format_NHWC; + } + if (t->buffer > 0) { CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 94ae5f8c55..3a3828e0cd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -47,7 +47,8 @@ class TfliteNodeParser { STATUS ParseTensor(const std::vector &ts, const std::vector> &tfliteModelBuffer, mindspore::lite::TensorCache *tensor_cache, - int node_type); + int node_type, + bool isWeight); STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, const tflite::TensorT *tflite_tensor, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index cf11963fba..5f3d90b889 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -50,7 +50,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf return RET_ERROR; } std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { MS_LOG(ERROR) << "parse weight failed"; return RET_ERROR; }