[MS][LITE] Add build function for model_path
This commit is contained in:
parent
5491b2c982
commit
88b670611e
|
@ -38,7 +38,6 @@ namespace dataset {
|
|||
class Dataset;
|
||||
} // namespace dataset
|
||||
|
||||
|
||||
class MS_API Model {
|
||||
public:
|
||||
Model();
|
||||
|
@ -73,6 +72,9 @@ class MS_API Model {
|
|||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
Status Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
|
|
@ -72,6 +72,12 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
|
|||
return kMCFailed;
|
||||
}
|
||||
|
||||
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kMCFailed;
|
||||
}
|
||||
|
||||
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||
|
|
|
@ -40,6 +40,20 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
Status ret = impl_->Build(model_path, model_type, model_context);
|
||||
if (ret != kSuccess) {
|
||||
return ret;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
|
||||
const std::shared_ptr<TrainCfg> &train_cfg) {
|
||||
std::stringstream err_msg;
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "src/cxx_api/tensor_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/train/train_session.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
@ -62,6 +63,33 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &ms_context) {
|
||||
lite::Context lite_context;
|
||||
auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
|
||||
if (status != kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
size_t data_size = 0;
|
||||
char *model_data = lite::ReadFile(model_path.c_str(), &data_size);
|
||||
if (model_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model path failed.";
|
||||
return kMDFileNotExist;
|
||||
}
|
||||
|
||||
auto session = std::shared_ptr<session::LiteSession>(
|
||||
session::LiteSession::CreateSession(const_cast<const char *>(model_data), data_size, &lite_context));
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "Allocate session failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
session_.swap(session);
|
||||
MS_LOG(DEBUG) << "Build model success.";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelImpl::Build() {
|
||||
MS_LOG(DEBUG) << "Start build model.";
|
||||
if (graph_ == nullptr || graph_->graph_data_ == nullptr) {
|
||||
|
|
|
@ -60,6 +60,7 @@ class ModelImpl {
|
|||
Status Build();
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context);
|
||||
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
|
||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
|
||||
|
|
Loading…
Reference in New Issue