[MS][LITE] Add build function for model_path

This commit is contained in:
cjh9368 2021-07-12 10:58:50 +08:00
parent 5491b2c982
commit 88b670611e
5 changed files with 52 additions and 1 deletions

View File

@ -38,7 +38,6 @@ namespace dataset {
class Dataset;
} // namespace dataset
class MS_API Model {
public:
Model();
@ -73,6 +72,9 @@ class MS_API Model {
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
private:
friend class Serialization;

View File

@ -72,6 +72,12 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
return kMCFailed;
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed;
}
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";

View File

@ -40,6 +40,20 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
return kSuccess;
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
Status ret = impl_->Build(model_path, model_type, model_context);
if (ret != kSuccess) {
return ret;
}
return kSuccess;
}
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &train_cfg) {
std::stringstream err_msg;

View File

@ -28,6 +28,7 @@
#include "src/cxx_api/tensor_utils.h"
#include "src/common/log_adapter.h"
#include "src/train/train_session.h"
#include "src/common/file_utils.h"
namespace mindspore {
using mindspore::lite::RET_ERROR;
@ -62,6 +63,33 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
return kSuccess;
}
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &ms_context) {
lite::Context lite_context;
auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
if (status != kSuccess) {
return status;
}
size_t data_size = 0;
char *model_data = lite::ReadFile(model_path.c_str(), &data_size);
if (model_data == nullptr) {
MS_LOG(ERROR) << "Read model path failed.";
return kMDFileNotExist;
}
auto session = std::shared_ptr<session::LiteSession>(
session::LiteSession::CreateSession(const_cast<const char *>(model_data), data_size, &lite_context));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
session_.swap(session);
MS_LOG(DEBUG) << "Build model success.";
return kSuccess;
}
Status ModelImpl::Build() {
MS_LOG(DEBUG) << "Start build model.";
if (graph_ == nullptr || graph_->graph_data_ == nullptr) {

View File

@ -60,6 +60,7 @@ class ModelImpl {
Status Build();
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context);
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,