fix reshape and l2norm tflite parser bug

This commit is contained in:
lyvette 2020-08-29 11:55:51 +08:00
parent 81004f5e90
commit 34361e2fe4
3 changed files with 10 additions and 8 deletions

View File

@ -51,7 +51,6 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;
} else if (std::strcmp(node_name, "Relu6") == 0) {
MS_LOG(DEBUG) << "parse TfliteRelu6Parser";
attr->type = schema::ActivationType_RELU6;

View File

@ -51,10 +51,6 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
auto data_index = tflite_op->inputs[0];
if (static_cast<int>(tflite_op->inputs.size()) <= data_index) {
MS_LOG(ERROR) << "the size of input should be greater than " << data_index;
return RET_ERROR;
}
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the input tensor is null";

View File

@ -57,9 +57,16 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
MS_LOG(ERROR) << "shape_tensor is null";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) {
MS_LOG(ERROR) << "get reshape -> shape failed";
return RET_ERROR;
auto &buf_data = tflite_model_buffer[shape_tensor->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "buf_data is null";
return RET_NULL_PTR;
}
if (!buf_data->data.empty()) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) {
MS_LOG(ERROR) << "get reshape -> shape failed";
return RET_ERROR;
}
}
} else {
attr->format = schema::Format_NHWC;