!4416 fixed weight tensor format bug and gather parser bug
Merge pull request !4416 from lyvette/master
This commit is contained in:
commit
0ff1000b24
|
@ -57,7 +57,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
if (!x_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
|
||||
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the first tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
if (!y_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the second tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
|
|
|
@ -123,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
auto bias_index = tflite_op->inputs[2];
|
||||
const auto &bias_tensor = tflite_tensors[bias_index];
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse bias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
|
|||
}
|
||||
if (!y_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the second tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -49,6 +49,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
|
||||
attr->batchDims = 0;
|
||||
|
||||
auto y_index = tfliteOp->inputs[1];
|
||||
const auto &y_tensor = tfliteTensors[y_index];
|
||||
if (y_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the second input is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
|
||||
if (y_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the data of the second input is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (!y_data->data.empty()) {
|
||||
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
|
||||
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse the second tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Gather;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
|
|
|
@ -45,7 +45,8 @@ STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr<
|
|||
STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
mindspore::lite::TensorCache *tensor_cache,
|
||||
int node_type) {
|
||||
int node_type,
|
||||
bool isWeight) {
|
||||
for (const auto &t : ts) {
|
||||
auto idx = tensor_cache->FindTensor(t->name);
|
||||
if (idx < 0) {
|
||||
|
@ -53,6 +54,12 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts,
|
|||
tensor->dataType = GetTfliteDataType(t->type);
|
||||
tensor->dims = t->shape;
|
||||
|
||||
if (isWeight) {
|
||||
tensor->format = schema::Format_KHWC;
|
||||
} else {
|
||||
tensor->format = schema::Format_NHWC;
|
||||
}
|
||||
|
||||
if (t->buffer > 0) {
|
||||
CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get());
|
||||
}
|
||||
|
|
|
@ -47,7 +47,8 @@ class TfliteNodeParser {
|
|||
STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
mindspore::lite::TensorCache *tensor_cache,
|
||||
int node_type);
|
||||
int node_type,
|
||||
bool isWeight);
|
||||
|
||||
STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
|
||||
const tflite::TensorT *tflite_tensor,
|
||||
|
|
|
@ -50,7 +50,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
|
|||
return RET_ERROR;
|
||||
}
|
||||
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
|
||||
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) {
|
||||
MS_LOG(ERROR) << "parse weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue