forked from mindspore-Ecosystem/mindspore
add tflite model verify
This commit is contained in:
parent
e099bb52d5
commit
d677efb696
|
@ -56,6 +56,54 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
|
|||
return tflite::UnPackModel(tflite_model_buf_);
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::TfliteModelVerify() {
|
||||
if (tflite_model_->subgraphs.empty()) {
|
||||
MS_LOG(ERROR) << "tflite model does not has a main graph.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
const auto tflite_model_buffers_size = tflite_model_->buffers.size();
|
||||
const auto tflite_model_operator_codes_size = tflite_model_->operator_codes.size();
|
||||
|
||||
for (auto &subgraph : tflite_model_->subgraphs) {
|
||||
auto all_singraph_tensor_size = subgraph->tensors.size();
|
||||
if (std::any_of(subgraph->inputs.begin(), subgraph->inputs.end(), [&all_singraph_tensor_size](int32_t index) {
|
||||
return index >= static_cast<int32_t>(all_singraph_tensor_size) || index < 0;
|
||||
})) {
|
||||
MS_LOG(ERROR) << "tflite input illegal.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (std::any_of(subgraph->outputs.begin(), subgraph->outputs.end(), [&all_singraph_tensor_size](int32_t index) {
|
||||
return index >= static_cast<int32_t>(all_singraph_tensor_size) || index < 0;
|
||||
})) {
|
||||
MS_LOG(ERROR) << "tflite output illegal.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto &op : subgraph->operators) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite contain nullptr op.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (op->opcode_index >= tflite_model_operator_codes_size) {
|
||||
MS_LOG(ERROR) << "op is not a tflite opcode";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &tensor : subgraph->tensors) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite model contain nullptr tensor.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (tensor->buffer >= tflite_model_buffers_size) {
|
||||
MS_LOG(ERROR) << "tflite tensor buffer index beyond upper limit.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
// load graph
|
||||
|
@ -66,7 +114,14 @@ api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto status = ConvertTfliteGraph();
|
||||
auto status = TfliteModelVerify();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "tflite model verify failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertTfliteGraph();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert tflite graph failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
|
|
@ -46,10 +46,6 @@ class TfliteModelParser : public converter::ModelParser {
|
|||
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
private:
|
||||
std::unique_ptr<tflite::ModelT> tflite_model_;
|
||||
std::map<int, CNodePtr> control_flow_nodes_;
|
||||
std::map<CNodePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> control_flow_map_;
|
||||
char *tflite_model_buf_ = nullptr;
|
||||
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const std::string &model_path);
|
||||
STATUS ConvertConstTensor(const std::unique_ptr<tflite::TensorT> &tensor, const ParameterPtr ¶meter,
|
||||
const std::string &tensor_name, bool is_uint8_weight_quant);
|
||||
|
@ -73,6 +69,13 @@ class TfliteModelParser : public converter::ModelParser {
|
|||
ops::PrimitiveC *primitive_c);
|
||||
static STATUS SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
|
||||
std::vector<QuantParamT> *quant_params, int round_type = 1);
|
||||
STATUS TfliteModelVerify();
|
||||
|
||||
private:
|
||||
std::unique_ptr<tflite::ModelT> tflite_model_;
|
||||
std::map<int, CNodePtr> control_flow_nodes_;
|
||||
std::map<CNodePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> control_flow_map_;
|
||||
char *tflite_model_buf_ = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue