diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 217ad775131..9c0d1e37885 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -291,6 +291,11 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { } int LiteSession::CompileGraph(Model *model) { + if (!ModelVerify(*model)) { + MS_LOG(ERROR) << "wrong model input, please check"; + return RET_ERROR; + } + bool expected = false; if (!is_running_.compare_exchange_strong(expected, true)) { MS_LOG(ERROR) << "Not support multi-threading"; diff --git a/mindspore/lite/src/model_common.cc b/mindspore/lite/src/model_common.cc index 4ed019458b4..b7a682bac67 100644 --- a/mindspore/lite/src/model_common.cc +++ b/mindspore/lite/src/model_common.cc @@ -138,14 +138,7 @@ int SubGraphVerify(const Model &model) { return RET_OK; } -int ModelVerify(const Model &model, const int &schema_version) { - if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { - return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; - } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { - return NodeVerify(model) == RET_OK; - } - return RET_ERROR; -} +bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; } const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { if (buf == nullptr) { @@ -230,6 +223,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { return nullptr; } - return ModelVerify(*model, schema_version) ? model : nullptr; + return ModelVerify(*model) ? model : nullptr; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h index b51455129a7..51c9317e6ad 100644 --- a/mindspore/lite/src/model_common.h +++ b/mindspore/lite/src/model_common.h @@ -181,7 +181,7 @@ int NodeVerify(const Model &model); int SubGraphVerify(const Model &model); -int ModelVerify(const Model &model, const int &schema_version); +bool ModelVerify(const Model &model); const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);