add tflite model verify

This commit is contained in:
mengyuanli 2021-09-29 17:08:29 +08:00
parent e099bb52d5
commit d677efb696
2 changed files with 63 additions and 5 deletions

View File

@ -56,6 +56,54 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
return tflite::UnPackModel(tflite_model_buf_);
}
STATUS TfliteModelParser::TfliteModelVerify() {
if (tflite_model_->subgraphs.empty()) {
MS_LOG(ERROR) << "tflite model does not has a main graph.";
return RET_ERROR;
}
const auto tflite_model_buffers_size = tflite_model_->buffers.size();
const auto tflite_model_operator_codes_size = tflite_model_->operator_codes.size();
for (auto &subgraph : tflite_model_->subgraphs) {
auto all_singraph_tensor_size = subgraph->tensors.size();
if (std::any_of(subgraph->inputs.begin(), subgraph->inputs.end(), [&all_singraph_tensor_size](int32_t index) {
return index >= static_cast<int32_t>(all_singraph_tensor_size) || index < 0;
})) {
MS_LOG(ERROR) << "tflite input illegal.";
return RET_ERROR;
}
if (std::any_of(subgraph->outputs.begin(), subgraph->outputs.end(), [&all_singraph_tensor_size](int32_t index) {
return index >= static_cast<int32_t>(all_singraph_tensor_size) || index < 0;
})) {
MS_LOG(ERROR) << "tflite output illegal.";
return RET_ERROR;
}
for (auto &op : subgraph->operators) {
if (op == nullptr) {
MS_LOG(ERROR) << "tflite contain nullptr op.";
return RET_ERROR;
}
if (op->opcode_index >= tflite_model_operator_codes_size) {
MS_LOG(ERROR) << "op is not a tflite opcode";
return RET_ERROR;
}
}
for (auto &tensor : subgraph->tensors) {
if (tensor == nullptr) {
MS_LOG(ERROR) << "tflite model contain nullptr tensor.";
return RET_ERROR;
}
if (tensor->buffer >= tflite_model_buffers_size) {
MS_LOG(ERROR) << "tflite tensor buffer index beyond upper limit.";
return RET_ERROR;
}
}
}
return RET_OK;
}
api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
// load graph
@ -66,7 +114,14 @@ api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters
return nullptr;
}
auto status = ConvertTfliteGraph();
auto status = TfliteModelVerify();
if (status != RET_OK) {
MS_LOG(ERROR) << "tflite model verify failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertTfliteGraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert tflite graph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@ -46,10 +46,6 @@ class TfliteModelParser : public converter::ModelParser {
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
private:
std::unique_ptr<tflite::ModelT> tflite_model_;
std::map<int, CNodePtr> control_flow_nodes_;
std::map<CNodePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> control_flow_map_;
char *tflite_model_buf_ = nullptr;
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const std::string &model_path);
STATUS ConvertConstTensor(const std::unique_ptr<tflite::TensorT> &tensor, const ParameterPtr &parameter,
const std::string &tensor_name, bool is_uint8_weight_quant);
@ -73,6 +69,13 @@ class TfliteModelParser : public converter::ModelParser {
ops::PrimitiveC *primitive_c);
static STATUS SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
std::vector<QuantParamT> *quant_params, int round_type = 1);
STATUS TfliteModelVerify();
private:
std::unique_ptr<tflite::ModelT> tflite_model_;
std::map<int, CNodePtr> control_flow_nodes_;
std::map<CNodePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> control_flow_map_;
char *tflite_model_buf_ = nullptr;
};
} // namespace lite
} // namespace mindspore