!46158 Add Lite API for exporting model to buffer

Merge pull request !46158 from zuochuanyong/lite_export_model
i-robot 2022-12-05 07:37:44 +00:00 committed by Gitee
commit 3f2a199b81
14 changed files with 225 additions and 46 deletions


@@ -79,10 +79,16 @@ class MS_API Serialization {
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[out] model_data The model buffer.
  /// \param[in] quantization_type The quantization type.
  /// \param[in] export_inference_only Whether to export an inference-only model.
  /// \param[in] output_tensor_name The output tensor names of the exported inference model; empty by default,
  /// in which case the complete inference model is exported.
  ///
  /// \return Status.
  inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   const std::vector<std::string> &output_tensor_name = {});
  /// \brief Export training model from file.
  ///
@@ -111,6 +117,9 @@ class MS_API Serialization {
  static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
};
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -139,5 +148,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
                     VectorStringToChar(output_tensor_name));
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::string> &output_tensor_name) {
  return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
                     VectorStringToChar(output_tensor_name));
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
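For orientation, a minimal usage sketch of the new buffer-export overload (illustrative only: `model` is assumed to be a trained Lite model built elsewhere, and error handling is reduced to a single status check):

  // Export the inference graph of a trained model directly into a heap buffer
  // instead of writing it to a file first.
  mindspore::Buffer exported;
  mindspore::Status status = mindspore::Serialization::ExportModel(
      model, mindspore::kMindIR, &exported, mindspore::kNoQuant, /*export_inference_only=*/true);
  if (status == mindspore::kSuccess) {
    // exported.Data() and exported.DataSize() now describe the serialized model.
  }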


@@ -333,7 +333,8 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &,
  return kMEFailed;
}
Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
                                  const std::vector<std::vector<char>> & /* output_tensor_name */) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kMEFailed;
}


@@ -331,7 +331,8 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &,
  return kMEFailed;
}
Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
                                  const std::vector<std::vector<char>> & /* output_tensor_name */) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kMEFailed;
}


@@ -159,9 +159,34 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &p
  return kMEFailed;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::vector<char>> &output_tensor_name) {
  if (model.impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implement is null.";
    return kLiteUninitializedObj;
  }
  if (!model.impl_->IsTrainModel()) {
    MS_LOG(ERROR) << "Model is not TrainModel.";
    return kLiteError;
  }
  if (model_data == nullptr) {
    MS_LOG(ERROR) << "model_data is nullptr.";
    return kLiteParamInvalid;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
    return kLiteParamInvalid;
  }
  if (model.impl_->session_ == nullptr) {
    MS_LOG(ERROR) << "Model session is nullptr.";
    return kLiteError;
  }
  auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS,
                                           VectorCharToString(output_tensor_name));
  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,


@@ -101,6 +101,11 @@ class LiteSession {
                     std::vector<std::string> out_put_tensor_name = {}) {
    return mindspore::lite::RET_ERROR;
  }
  virtual int Export(Buffer *model_buffer, lite::ModelType model_type = lite::MT_TRAIN,
                     lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
                     std::vector<std::string> out_put_tensor_name = {}) {
    return mindspore::lite::RET_ERROR;
  }
  virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
  virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
    std::vector<lite::Tensor *> features;


@@ -612,8 +612,39 @@ int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
  return status;
}

int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
  MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be empty.");
  auto *liteModel = reinterpret_cast<LiteModel *>(model);
  auto size = liteModel->buf_size_;
  model_buffer->ResizeData(size);
  size_t out_size = model_buffer->DataSize();
  int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
  if (out_size != size) {
    MS_LOG(ERROR) << "model_buffer resize failed.";
    return RET_ERROR;
  }
  return status;
}

int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

int TrainExport::SaveToBuffer() {
  constexpr size_t kFbBuilderInitSize = 1024;
  flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
  auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  size_t size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "context cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "context cannot be empty.");
  model_buffer_->SetData(content, size);
  return RET_OK;
}

bool TrainExport::IsInputTensor(const schema::TensorT &t) {
  int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
  return ((t.data.size() == 0) && (total_dims != 0));


@@ -44,23 +44,27 @@ struct tensor_info {
class TrainExport {
 public:
  explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
  explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {}
  virtual ~TrainExport();
  int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
                const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
                const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr);
  int ExportInit(const std::string model_name, std::string version);
  int SaveToFile();
  int SaveToBuffer();
  void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
  int LoadModel(void *buf, size_t buf_size);
  int AddTransformNode();
  int TrainModelFusion();
  int TrainModelDrop();
  int SaveModel(lite::Model *model, const std::string &file_name);
  int SaveModel(lite::Model *model, Buffer *model_buffer);

 protected:
  virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);

 private:
  Buffer *model_buffer_ = nullptr;
  std::string file_name_;
  schema::MetaGraphT *meta_graph_ = nullptr;
  std::vector<size_t> out_idx_;


@@ -1154,9 +1154,17 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
  return RET_OK;
}

template <typename DestType>
int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                              FormatType format, std::vector<std::string> out_put_tensor_name) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
  } else {
    MS_LOG(ERROR) << "Unsupported destination.";
    return RET_ERROR;
  }
  MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
                     "Export model type parameter error");
  MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
@@ -1165,27 +1173,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
  bool orig_train_state = IsTrain();
  Eval();
  TrainExport texport(destination);
  int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
    std::vector<kernel::KernelExec *> export_kernels = {};
    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
  } else {
    if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
        std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
          return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
        })) {
      status = texport.SaveModel(model_.get(), destination);
      if (orig_train_state) Train();
      return status;
    } else {
@@ -1194,31 +1196,41 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
                                 model_.get(), quant_type);
    }
  }
  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
  if (model_type == MT_INFERENCE) {
    status = texport.TrainModelDrop();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
    status = texport.TrainModelFusion();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
  }
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    status = texport.SaveToFile();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "failed to save to " << destination;
      return status;
    }
  } else {
    status = texport.SaveToBuffer();
    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
  }
  if (orig_train_state) Train();
  return status;
}
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
                         FormatType format, std::vector<std::string> out_put_tensor_name) {
  return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
}

int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
                         std::vector<std::string> out_put_tensor_name) {
  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
}
std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
  std::vector<lite::Tensor *> features;
  for (auto cur_tensor : this->tensors_) {

View File

@@ -36,6 +36,14 @@
 +-------------------------------+
 */

#define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \
  do {                                                     \
    if ((value)) {                                         \
      MS_LOG(ERROR) << #msg;                               \
      return errcode;                                      \
    }                                                      \
  } while (0)
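For reference, this macro condenses the repeated check / log / early-return pattern that the old Export code spelled out by hand. A call such as the one below expands roughly into the block that follows; note that #msg stringizes its argument, so a string-literal argument is logged with its surrounding quotes, and the do/while (0) wrapper keeps each use safe as a single statement inside if/else:

  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");

  // roughly expands to:
  do {
    if ((status != RET_OK)) {
      MS_LOG(ERROR) << "\"Fail to init export\"";
      return status;
    }
  } while (0);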
namespace mindspore {
namespace lite {
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
@@ -90,7 +98,8 @@ class TrainSession : virtual public lite::LiteSession {
  }
  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
             std::vector<std::string> out_put_tensor_name = {}) override;
  int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
             std::vector<std::string> out_put_tensor_name = {}) override;

  std::vector<lite::Tensor *> GetFeatureMaps() const override;
  int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) override;
@@ -158,7 +167,9 @@ class TrainSession : virtual public lite::LiteSession {
  size_t GetInplaceTensorOffset(kernel::KernelExec *kernel,
                                const std::unordered_map<lite::Tensor *, size_t> &offset_map,
                                std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
  template <typename DestType>
  int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
                  std::vector<std::string> out_put_tensor_name = {});
  std::map<Tensor *, Tensor *> restored_origin_tensors_;
  int virtual_batch_idx_ = 0;
  int virtual_batch_multiplier_ = 0;


@@ -183,15 +183,24 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
  return map;
}

template <typename DestType>
int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                                 FormatType format, std::vector<std::string> out_put_tensor_name) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
  } else {
    MS_LOG(ERROR) << "Unsupported destination.";
    return RET_ERROR;
  }
  if (format != FT_FLATBUFFERS) {
    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
    return RET_ERROR;
  }
  if (model_type == MT_TRAIN) {
    return TrainSession::Export(destination, model_type, quant_type, format);
  }
  bool orig_train_state = IsTrain();
@@ -199,7 +208,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
    MS_LOG(ERROR) << "eval failed.";
    return RET_ERROR;
  }
  TrainExport texport(destination);
  int status = texport.LoadModel(lite_model_, size_backbone_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "cannot init export";
@@ -231,11 +240,18 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
    MS_LOG(ERROR) << "cannot serialize head";
    return status;
  }
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    status = texport.SaveToFile();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "failed to save to " << destination;
      return status;
    }
  } else {
    status = texport.SaveToBuffer();
    MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
  }
  if (orig_train_state) {
    auto ret = Train();
    if (ret != RET_OK) {
@@ -246,6 +262,16 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
  return status;
}
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
                            FormatType format, std::vector<std::string> out_put_tensor_name) {
  return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
}

int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
                            std::vector<std::string> out_put_tensor_name) {
  return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
}

lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
                                            const char *model_buf_head, size_t size_head,
                                            const std::shared_ptr<InnerContext> &context, bool train_mode,


@@ -65,6 +65,8 @@ class TransferSession : public lite::TrainSession {
  int CompileTransferGraph();
  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
             std::vector<std::string> out_put_tensor_name = {}) override;
  int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
             std::vector<std::string> out_put_tensor_name = {}) override;

 protected:
  LiteSession *backbone_session_ = nullptr;
@@ -74,6 +76,9 @@ class TransferSession : public lite::TrainSession {
  bool is_valid_ = false;

 private:
  template <typename DestType>
  int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
                  std::vector<std::string> out_put_tensor_name = {});
  bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len);
  std::unordered_map<size_t, size_t> ConnectionMap();
  bool nchw2nhwc_ = false;


@@ -1,2 +1,2 @@
Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!!
1124800


@@ -1,2 +1,2 @@
Note: This is the mindspore Lite inference framework size threshold. Modifying this threshold requires meeting review.
1124800


@@ -14,6 +14,8 @@
 * limitations under the License.
 */
#include <memory>
#include <string>
#include <iostream>
#include "common/common_test.h"
#include "include/api/serialization.h"
@@ -41,4 +43,44 @@ TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
  ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kMindIR, "./nets/export.ms") != kSuccess);
}
TEST_F(TestCxxApiLiteSerialization, test_export_to_buffer) {
  auto context = std::make_shared<mindspore::Context>();
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);

  Graph graph;
  std::string file_name = "../../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms";
  auto status = mindspore::Serialization::Load(file_name, mindspore::kMindIR, &graph);
  ASSERT_TRUE(status == mindspore::kSuccess);

  Model model;
  auto cfg = std::make_shared<mindspore::TrainCfg>();
  status = model.Build(mindspore::GraphCell(graph), context, cfg);
  ASSERT_TRUE(status == mindspore::kSuccess);

  std::string exported_file = "./export.ms";
  status = Serialization::ExportModel(model, mindspore::kMindIR, exported_file, mindspore::kNoQuant, false);
  ASSERT_TRUE(status == mindspore::kSuccess);

  mindspore::Buffer modef_buffer_infer;
  status = Serialization::ExportModel(model, mindspore::kMindIR, &modef_buffer_infer, mindspore::kNoQuant, false);
  ASSERT_TRUE(status == mindspore::kSuccess);

  std::ifstream file(exported_file.c_str(), std::ifstream::binary);
  ASSERT_TRUE(file);
  file.seekg(0, std::ifstream::end);
  size_t file_size = file.tellg();
  file.seekg(0);
  const int kMaxSize = 1024 * 1024;
  char buf[kMaxSize] = {0};
  file.read(buf, file_size);
  file.close();
  int result = memcmp(buf, modef_buffer_infer.Data(), file_size);
  ASSERT_EQ(result, 0);
}
}  // namespace mindspore