diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 4a0c74546b9..e968c5322ad 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -56,6 +56,54 @@ std::unique_ptr TfliteModelParser::ReadTfliteModel(const std::st return tflite::UnPackModel(tflite_model_buf_); } +STATUS TfliteModelParser::TfliteModelVerify() { + if (tflite_model_->subgraphs.empty()) { + MS_LOG(ERROR) << "tflite model does not has a main graph."; + return RET_ERROR; + } + const auto tflite_model_buffers_size = tflite_model_->buffers.size(); + const auto tflite_model_operator_codes_size = tflite_model_->operator_codes.size(); + + for (auto &subgraph : tflite_model_->subgraphs) { + auto all_singraph_tensor_size = subgraph->tensors.size(); + if (std::any_of(subgraph->inputs.begin(), subgraph->inputs.end(), [&all_singraph_tensor_size](int32_t index) { + return index >= static_cast(all_singraph_tensor_size) || index < 0; + })) { + MS_LOG(ERROR) << "tflite input illegal."; + return RET_ERROR; + } + if (std::any_of(subgraph->outputs.begin(), subgraph->outputs.end(), [&all_singraph_tensor_size](int32_t index) { + return index >= static_cast(all_singraph_tensor_size) || index < 0; + })) { + MS_LOG(ERROR) << "tflite output illegal."; + return RET_ERROR; + } + for (auto &op : subgraph->operators) { + if (op == nullptr) { + MS_LOG(ERROR) << "tflite contain nullptr op."; + return RET_ERROR; + } + if (op->opcode_index >= tflite_model_operator_codes_size) { + MS_LOG(ERROR) << "op is not a tflite opcode"; + return RET_ERROR; + } + } + + for (auto &tensor : subgraph->tensors) { + if (tensor == nullptr) { + MS_LOG(ERROR) << "tflite model contain nullptr tensor."; + return RET_ERROR; + } + if (tensor->buffer >= tflite_model_buffers_size) { + MS_LOG(ERROR) << "tflite tensor buffer index beyond upper limit."; + return RET_ERROR; + } + } + } + + return RET_OK; +} + api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) { auto model_file = flag.model_file; // load graph @@ -66,7 +114,14 @@ api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters return nullptr; } - auto status = ConvertTfliteGraph(); + auto status = TfliteModelVerify(); + if (status != RET_OK) { + MS_LOG(ERROR) << "tflite model verify failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + status = ConvertTfliteGraph(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert tflite graph failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 36da058334f..6848178d995 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -46,10 +46,6 @@ class TfliteModelParser : public converter::ModelParser { static int Tflite2AnfAdjust(const std::set &all_func_graphs); private: - std::unique_ptr tflite_model_; - std::map control_flow_nodes_; - std::map> control_flow_map_; - char *tflite_model_buf_ = nullptr; std::unique_ptr ReadTfliteModel(const std::string &model_path); STATUS ConvertConstTensor(const std::unique_ptr &tensor, const ParameterPtr ¶meter, const std::string &tensor_name, bool is_uint8_weight_quant); @@ -73,6 +69,13 @@ class TfliteModelParser : public converter::ModelParser { ops::PrimitiveC *primitive_c); static STATUS SetTensorQuantParam(const std::unique_ptr &tflite_tensor, std::vector *quant_params, int round_type = 1); + STATUS TfliteModelVerify(); + + private: + std::unique_ptr tflite_model_; + std::map control_flow_nodes_; + std::map> control_flow_map_; + char *tflite_model_buf_ = nullptr; }; } // namespace lite } // namespace mindspore