!46158 Add Lite API for exporting model to buffer
Merge pull request !46158 from zuochuanyong/lite_export_model
This commit is contained in:
commit
3f2a199b81
|
@ -79,10 +79,16 @@ class MS_API Serialization {
|
|||
///
|
||||
/// \param[in] model The model data.
|
||||
/// \param[in] model_type The model file type.
|
||||
/// \param[out] model_data The model parameter data.
|
||||
/// \param[out] model_data The model buffer.
|
||||
/// \param[in] quantization_type The quantification type.
|
||||
/// \param[in] export_inference_only Whether to export a reasoning only model.
|
||||
/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
|
||||
/// empty, and export the complete reasoning model.
|
||||
///
|
||||
/// \return Status.
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
|
||||
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
|
||||
const std::vector<std::string> &output_tensor_name = {});
|
||||
|
||||
/// \brief Export training model from file.
|
||||
///
|
||||
|
@ -111,6 +117,9 @@ class MS_API Serialization {
|
|||
static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
|
||||
QuantizationType quantization_type, bool export_inference_only,
|
||||
const std::vector<std::vector<char>> &output_tensor_name);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
|
||||
QuantizationType quantization_type, bool export_inference_only,
|
||||
const std::vector<std::vector<char>> &output_tensor_name);
|
||||
};
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
|
@ -139,5 +148,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
|
|||
VectorStringToChar(output_tensor_name));
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
|
||||
QuantizationType quantization_type, bool export_inference_only,
|
||||
const std::vector<std::string> &output_tensor_name) {
|
||||
return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
|
||||
VectorStringToChar(output_tensor_name));
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
|
|
@ -333,7 +333,8 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &,
|
|||
return kMEFailed;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
|
||||
Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
|
||||
const std::vector<std::vector<char>> & /* output_tensor_name */) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
}
|
||||
|
|
|
@ -331,7 +331,8 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &,
|
|||
return kMEFailed;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
|
||||
Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
|
||||
const std::vector<std::vector<char>> & /* output_tensor_name */) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
}
|
||||
|
|
|
@ -159,9 +159,34 @@ Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &p
|
|||
return kMEFailed;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
|
||||
QuantizationType quantization_type, bool export_inference_only,
|
||||
const std::vector<std::vector<char>> &output_tensor_name) {
|
||||
if (model.impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model implement is null.";
|
||||
return kLiteUninitializedObj;
|
||||
}
|
||||
if (!model.impl_->IsTrainModel()) {
|
||||
MS_LOG(ERROR) << "Model is not TrainModel.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (model_data == nullptr) {
|
||||
MS_LOG(ERROR) << "model_data is nullptr.";
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
if (model_type != kMindIR && model_type != kMindIR_Lite) {
|
||||
MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
if (model.impl_->session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Model session is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
||||
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS,
|
||||
VectorCharToString(output_tensor_name));
|
||||
|
||||
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
|
||||
|
|
|
@ -101,6 +101,11 @@ class LiteSession {
|
|||
std::vector<std::string> out_put_tensor_name = {}) {
|
||||
return mindspore::lite::RET_ERROR;
|
||||
}
|
||||
virtual int Export(Buffer *model_buffer, lite::ModelType model_type = lite::MT_TRAIN,
|
||||
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
|
||||
std::vector<std::string> out_put_tensor_name = {}) {
|
||||
return mindspore::lite::RET_ERROR;
|
||||
}
|
||||
virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
|
||||
virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
|
||||
std::vector<lite::Tensor *> features;
|
||||
|
|
|
@ -612,8 +612,39 @@ int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
|
|||
return status;
|
||||
}
|
||||
|
||||
int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
|
||||
MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be empty.");
|
||||
MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be empty.");
|
||||
auto *liteModel = reinterpret_cast<LiteModel *>(model);
|
||||
auto size = liteModel->buf_size_;
|
||||
model_buffer->ResizeData(size);
|
||||
|
||||
size_t out_size = model_buffer->DataSize();
|
||||
int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
|
||||
if (out_size != size) {
|
||||
MS_LOG(ERROR) << "model_buffer resize failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
|
||||
|
||||
int TrainExport::SaveToBuffer() {
|
||||
constexpr size_t kFbBuilderInitSize = 1024;
|
||||
flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
size_t size = builder.GetSize();
|
||||
auto content = builder.GetBufferPointer();
|
||||
MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "context cannot be empty.");
|
||||
MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "context cannot be empty.");
|
||||
model_buffer_->SetData(content, size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
|
||||
int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
|
||||
return ((t.data.size() == 0) && (total_dims != 0));
|
||||
|
|
|
@ -44,23 +44,27 @@ struct tensor_info {
|
|||
class TrainExport {
|
||||
public:
|
||||
explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
|
||||
explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {}
|
||||
virtual ~TrainExport();
|
||||
int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
|
||||
const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
|
||||
const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr);
|
||||
int ExportInit(const std::string model_name, std::string version);
|
||||
int SaveToFile();
|
||||
int SaveToBuffer();
|
||||
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
|
||||
int LoadModel(void *buf, size_t buf_size);
|
||||
int AddTransformNode();
|
||||
int TrainModelFusion();
|
||||
int TrainModelDrop();
|
||||
int SaveModel(lite::Model *model, const std::string &file_name);
|
||||
int SaveModel(lite::Model *model, Buffer *model_buffer);
|
||||
|
||||
protected:
|
||||
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
|
||||
|
||||
private:
|
||||
Buffer *model_buffer_ = nullptr;
|
||||
std::string file_name_;
|
||||
schema::MetaGraphT *meta_graph_ = nullptr;
|
||||
std::vector<size_t> out_idx_;
|
||||
|
|
|
@ -1154,9 +1154,17 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
|
||||
template <typename DestType>
|
||||
int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
|
||||
} else if constexpr (std::is_same_v<DestType, Buffer *>) {
|
||||
MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported destination.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR,
|
||||
"Export model type parameter error");
|
||||
MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR,
|
||||
|
@ -1165,27 +1173,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
|||
|
||||
bool orig_train_state = IsTrain();
|
||||
Eval();
|
||||
TrainExport texport(file_name);
|
||||
TrainExport texport(destination);
|
||||
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "cannot init export";
|
||||
return status;
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
|
||||
|
||||
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
|
||||
std::vector<kernel::KernelExec *> export_kernels = {};
|
||||
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FindExportKernels failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
|
||||
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
|
||||
} else {
|
||||
if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
|
||||
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
|
||||
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
|
||||
})) {
|
||||
status = texport.SaveModel(model_.get(), file_name);
|
||||
status = texport.SaveModel(model_.get(), destination);
|
||||
if (orig_train_state) Train();
|
||||
return status;
|
||||
} else {
|
||||
|
@ -1194,31 +1196,41 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
|||
model_.get(), quant_type);
|
||||
}
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
|
||||
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "cannot export Network";
|
||||
return status;
|
||||
}
|
||||
if (model_type == MT_INFERENCE) {
|
||||
status = texport.TrainModelDrop();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "TrainModelDrop failed.";
|
||||
return status;
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
|
||||
|
||||
status = texport.TrainModelFusion();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "TrainModelFusion failed.";
|
||||
MS_LOG(ERROR) << "failed to save to " << destination;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
status = texport.SaveToBuffer();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
|
||||
}
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << file_name;
|
||||
return status;
|
||||
}
|
||||
|
||||
if (orig_train_state) Train();
|
||||
return status;
|
||||
}
|
||||
|
||||
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
return ExportInner<const std::string &>(file_name, model_type, quant_type, format, out_put_tensor_name);
|
||||
}
|
||||
|
||||
int TrainSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
|
||||
std::vector<std::string> out_put_tensor_name) {
|
||||
return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
|
||||
}
|
||||
|
||||
std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
|
||||
std::vector<lite::Tensor *> features;
|
||||
for (auto cur_tensor : this->tensors_) {
|
||||
|
|
|
@ -36,6 +36,14 @@
|
|||
+-------------------------------+
|
||||
*/
|
||||
|
||||
#define TRAIN_SESSION_CHECK_FALSE_MSG(value, errcode, msg) \
|
||||
do { \
|
||||
if ((value)) { \
|
||||
MS_LOG(ERROR) << #msg; \
|
||||
return errcode; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
|
||||
|
@ -90,7 +98,8 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
}
|
||||
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||
|
||||
int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||
std::vector<lite::Tensor *> GetFeatureMaps() const override;
|
||||
|
||||
int UpdateFeatureMaps(const std::vector<lite::Tensor *> &features_map) override;
|
||||
|
@ -158,7 +167,9 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
size_t GetInplaceTensorOffset(kernel::KernelExec *kernel,
|
||||
const std::unordered_map<lite::Tensor *, size_t> &offset_map,
|
||||
std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
|
||||
|
||||
template <typename DestType>
|
||||
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {});
|
||||
std::map<Tensor *, Tensor *> restored_origin_tensors_;
|
||||
int virtual_batch_idx_ = 0;
|
||||
int virtual_batch_multiplier_ = 0;
|
||||
|
|
|
@ -183,15 +183,24 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
|
|||
return map;
|
||||
}
|
||||
|
||||
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
template <typename DestType>
|
||||
int TransferSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
|
||||
} else if constexpr (std::is_same_v<DestType, Buffer *>) {
|
||||
MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported destination.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (format != FT_FLATBUFFERS) {
|
||||
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (model_type == MT_TRAIN) {
|
||||
return TrainSession::Export(filename, model_type, quant_type, format);
|
||||
return TrainSession::Export(destination, model_type, quant_type, format);
|
||||
}
|
||||
|
||||
bool orig_train_state = IsTrain();
|
||||
|
@ -199,7 +208,7 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
|||
MS_LOG(ERROR) << "eval failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
TrainExport texport(filename);
|
||||
TrainExport texport(destination);
|
||||
int status = texport.LoadModel(lite_model_, size_backbone_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "cannot init export";
|
||||
|
@ -231,11 +240,18 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
|||
MS_LOG(ERROR) << "cannot serialize head";
|
||||
return status;
|
||||
}
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << filename;
|
||||
return status;
|
||||
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << destination;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
status = texport.SaveToBuffer();
|
||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
|
||||
}
|
||||
|
||||
if (orig_train_state) {
|
||||
auto ret = Train();
|
||||
if (ret != RET_OK) {
|
||||
|
@ -246,6 +262,16 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
|||
return status;
|
||||
}
|
||||
|
||||
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
return ExportInner<const std::string &>(filename, model_type, quant_type, format, out_put_tensor_name);
|
||||
}
|
||||
|
||||
int TransferSession::Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType format,
|
||||
std::vector<std::string> out_put_tensor_name) {
|
||||
return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
|
||||
}
|
||||
|
||||
lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
|
||||
const char *model_buf_head, size_t size_head,
|
||||
const std::shared_ptr<InnerContext> &context, bool train_mode,
|
||||
|
|
|
@ -65,6 +65,8 @@ class TransferSession : public lite::TrainSession {
|
|||
int CompileTransferGraph();
|
||||
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||
int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||
|
||||
protected:
|
||||
LiteSession *backbone_session_ = nullptr;
|
||||
|
@ -74,6 +76,9 @@ class TransferSession : public lite::TrainSession {
|
|||
bool is_valid_ = false;
|
||||
|
||||
private:
|
||||
template <typename DestType>
|
||||
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {});
|
||||
bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len);
|
||||
std::unordered_map<size_t, size_t> ConnectionMap();
|
||||
bool nchw2nhwc_ = false;
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!!
|
||||
1116516
|
||||
1124800
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
Note: This is the mindspore Lite inference framework size threshold. Modifying this threshold requires meeting review.
|
||||
1116516
|
||||
1124800
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "common/common_test.h"
|
||||
#include "include/api/serialization.h"
|
||||
|
||||
|
@ -41,4 +43,44 @@ TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
|
|||
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kMindIR, "./nets/export.ms") != kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiLiteSerialization, test_export_to_buffer) {
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
|
||||
context->MutableDeviceInfo().push_back(cpu_context);
|
||||
|
||||
Graph graph;
|
||||
std::string file_name = "../../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms";
|
||||
auto status = mindspore::Serialization::Load(file_name, mindspore::kMindIR, &graph);
|
||||
ASSERT_TRUE(status == mindspore::kSuccess);
|
||||
|
||||
Model model;
|
||||
auto cfg = std::make_shared<mindspore::TrainCfg>();
|
||||
|
||||
status = model.Build(mindspore::GraphCell(graph), context, cfg);
|
||||
ASSERT_TRUE(status == mindspore::kSuccess);
|
||||
|
||||
std::string exported_file = "./export.ms";
|
||||
status = Serialization::ExportModel(model, mindspore::kMindIR, exported_file, mindspore::kNoQuant, false);
|
||||
ASSERT_TRUE(status == mindspore::kSuccess);
|
||||
|
||||
mindspore::Buffer modef_buffer_infer;
|
||||
status = Serialization::ExportModel(model, mindspore::kMindIR, &modef_buffer_infer, mindspore::kNoQuant, false);
|
||||
ASSERT_TRUE(status == mindspore::kSuccess);
|
||||
|
||||
std::ifstream file(exported_file.c_str(), std::ifstream::binary);
|
||||
ASSERT_TRUE(file);
|
||||
|
||||
file.seekg(0, std::ifstream::end);
|
||||
size_t file_size = file.tellg();
|
||||
file.seekg(0);
|
||||
|
||||
const int kMaxSize = 1024 * 1024;
|
||||
char buf[kMaxSize] = {0};
|
||||
file.read(buf, file_size);
|
||||
file.close();
|
||||
|
||||
int result = memcmp(buf, modef_buffer_infer.Data(), file_size);
|
||||
ASSERT_EQ(result, 0);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue