!4416 fixed weight tensor format bug and gather parser bug

Merge pull request !4416 from lyvette/master
mindspore-ci-bot 2020-08-14 09:24:20 +08:00 committed by Gitee
commit 0ff1000b24
11 changed files with 42 additions and 15 deletions
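Taken together, the hunks below make one behavioural change: TfliteNodeParser::ParseTensor gains a bool isWeight parameter, and every caller now states whether the constant it registers is a weight. Tensors parsed with isWeight set to true (the weight inputs of the Conv, DeConv, DepthwiseConv2D, FakeQuant, and FullyConnected parsers) are tagged schema::Format_KHWC, everything else keeps schema::Format_NHWC, and the Gather parser additionally starts validating and parsing its constant second input. A minimal sketch of the layout rule, condensed from the ParseTensor hunk below; the helper name SelectConstTensorFormat is illustrative only, not part of the commit, and it assumes the Format_* values belong to an enum named schema::Format:

// Illustrative helper only: the format decision ParseTensor now makes inline.
schema::Format SelectConstTensorFormat(bool isWeight) {
  // Weight tensors get the KHWC tag; every other parsed constant keeps NHWC.
  return isWeight ? schema::Format_KHWC : schema::Format_NHWC;
}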

View File

@@ -57,7 +57,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}
if (!x_data->data.empty()) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
- if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
@@ -76,7 +76,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
- if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}

View File

@@ -59,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
@@ -79,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
- if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}

View File

@@ -59,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape;

View File

@@ -123,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
@@ -133,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
auto bias_index = tflite_op->inputs[2];
const auto &bias_tensor = tflite_tensors[bias_index];
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
- if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}

View File

@@ -44,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
@@ -58,7 +58,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
- if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}

View File

@@ -52,7 +52,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
@@ -66,7 +66,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
- if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}

View File

@@ -58,7 +58,7 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
}
if (!y_data->data.empty()) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
- if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}

View File

@@ -49,6 +49,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->batchDims = 0;
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
+ if (y_tensor == nullptr) {
+ MS_LOG(ERROR) << "the second input is null";
+ return RET_NULL_PTR;
+ }
+ auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
+ if (y_data == nullptr) {
+ MS_LOG(ERROR) << "the data of the second input is null";
+ return RET_NULL_PTR;
+ }
+ if (!y_data->data.empty()) {
+ std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
+ if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
+ MS_LOG(ERROR) << "parse the second tensor failed";
+ return RET_ERROR;
+ }
+ }
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
return RET_OK;

View File

@@ -45,7 +45,8 @@ STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr<
STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
mindspore::lite::TensorCache *tensor_cache,
- int node_type) {
+ int node_type,
+ bool isWeight) {
for (const auto &t : ts) {
auto idx = tensor_cache->FindTensor(t->name);
if (idx < 0) {
@@ -53,6 +54,12 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
tensor->dataType = GetTfliteDataType(t->type);
tensor->dims = t->shape;
+ if (isWeight) {
+ tensor->format = schema::Format_KHWC;
+ } else {
+ tensor->format = schema::Format_NHWC;
+ }
if (t->buffer > 0) {
CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get());
}

View File

@@ -47,7 +47,8 @@ class TfliteNodeParser {
STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
mindspore::lite::TensorCache *tensor_cache,
- int node_type);
+ int node_type,
+ bool isWeight);
STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const tflite::TensorT *tflite_tensor,
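For orientation, a caller-side sketch of the declaration above, mirroring the convention the earlier hunks establish (true for weight tensors, false for biases and other constant inputs); the surrounding local variables are illustrative:

// Illustrative usage only, following the pattern of the conv-style parsers in this commit.
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {  // weight -> Format_KHWC
  return RET_ERROR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {  // bias -> Format_NHWC
  return RET_ERROR;
}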

View File

@@ -50,7 +50,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
return RET_ERROR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
- if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
+ if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}