diff --git a/include/api/serialization.h b/include/api/serialization.h
index b78dba95a77..a92ff3eb31e 100644
--- a/include/api/serialization.h
+++ b/include/api/serialization.h
@@ -94,7 +94,7 @@ class MS_API Serialization {
   ///
   /// \param[in] model The model data.
   /// \param[in] model_type The model file type.
-  /// \param[in] model_file The exported model file.
+  /// \param[in] model_file The path of the exported model file.
   /// \param[in] quantization_type The quantification type.
   /// \param[in] export_inference_only Whether to export a reasoning only model.
   /// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
diff --git a/mindspore/lite/src/extendrt/mock/lite_runtime/converters.h b/mindspore/lite/src/extendrt/mock/lite_runtime/converters.h
index 8a7e67f8d36..8062d478753 100644
--- a/mindspore/lite/src/extendrt/mock/lite_runtime/converters.h
+++ b/mindspore/lite/src/extendrt/mock/lite_runtime/converters.h
@@ -24,6 +24,7 @@
 #include "include/api/cfg.h"
 #include "include/train/train_cfg.h"
 #include "src/litert/inner_context.h"
+#include "src/common/log_adapter.h"
 
 namespace mindspore {
 class ContextUtils {
@@ -56,6 +57,9 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
   if (qt == kWeightQuant) {
     return lite::QT_WEIGHT;
   }
+  if (qt == kFullQuant || qt == kUnknownQuantType) {
+    MS_LOG(WARNING) << "QuantizationType " << qt << " is not supported, setting the quantization type to default.";
+  }
   return lite::QT_DEFAULT;
 }
 
diff --git a/mindspore/lite/src/litert/cxx_api/converters.h b/mindspore/lite/src/litert/cxx_api/converters.h
index 3d1131b1de5..aa5bd309f06 100644
--- a/mindspore/lite/src/litert/cxx_api/converters.h
+++ b/mindspore/lite/src/litert/cxx_api/converters.h
@@ -61,7 +61,7 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
     return lite::QT_WEIGHT;
   }
   if (qt == kFullQuant || qt == kUnknownQuantType) {
-    MS_LOG(WARNING) << qt << " does not support, set the quantizationType to default.";
+    MS_LOG(WARNING) << "QuantizationType " << qt << " is not supported, setting the quantization type to default.";
   }
   return lite::QT_DEFAULT;
 }
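
Both A2L_ConvertQT overloads now warn instead of silently mapping unsupported quantization types to the default. The sketch below mirrors that fall-through behavior as a standalone program; the enum definitions and std::cerr are stand-ins for the real MindSpore headers and MS_LOG, not the actual API.

#include <iostream>

// Stand-ins for mindspore::QuantizationType and lite::QuantizationType.
enum QuantizationType { kNoQuant, kWeightQuant, kFullQuant, kUnknownQuantType };
enum LiteQuantType { QT_NONE, QT_WEIGHT, QT_DEFAULT };

// Mirrors A2L_ConvertQT: unsupported types still fall through to QT_DEFAULT,
// but the caller now gets a diagnostic instead of a silent default.
LiteQuantType ConvertQT(QuantizationType qt) {
  if (qt == kNoQuant) {
    return QT_NONE;
  }
  if (qt == kWeightQuant) {
    return QT_WEIGHT;
  }
  if (qt == kFullQuant || qt == kUnknownQuantType) {
    std::cerr << "QuantizationType " << qt << " is not supported, setting the quantization type to default.\n";
  }
  return QT_DEFAULT;
}

int main() {
  ConvertQT(kFullQuant);  // prints the warning and returns QT_DEFAULT
  return 0;
}

Keeping the conversion total (every input still yields a valid lite::QuantizationType) preserves existing callers, while the warning makes a misconfigured quantization setting visible in the log.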
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index e09898dba8e..f9dc717efb7 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -1204,11 +1204,71 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
   return RET_OK;
 }
 
+template <typename DestType>
+int TrainSession::ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
+                                        bool orig_train_state, std::vector<std::string> output_tensor_name) {
+  TrainExport texport(destination);
+  int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to init export");
+  if (!output_tensor_name.empty() && model_type == MT_INFERENCE) {
+    std::vector<kernel::KernelExec *> export_kernels = {};
+    status = FindExportKernels(&export_kernels, output_tensor_name, inference_kernels_);
+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
+    status = texport.ExportNet(export_kernels, tensors_, output_tensor_name, model_.get(), quant_type);
+  } else {
+    if (!output_tensor_name.empty() && model_type == MT_TRAIN) {
+      MS_LOG(WARNING) << "Exporting selected output tensors is not supported for a train model; all train kernel "
+                         "tensors will be exported";
+    }
+    if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
+        std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
+          return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
+        })) {
+      status = texport.SaveModel(model_.get(), destination);
+      TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to save model");
+      if (orig_train_state) {
+        status = Train();
+        TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed.");
+      }
+      return status;
+    } else {
+      status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
+                                 (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
+                                 model_.get(), quant_type);
+    }
+  }
+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to export network.");
+  if (model_type == MT_INFERENCE) {
+    status = texport.TrainModelDrop();
+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
+    status = texport.TrainModelFusion();
+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
+  }
+  if constexpr (std::is_same_v<DestType, const std::string &>) {
+    status = texport.SaveToFile();
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "Failed to save to " << destination;
+      return status;
+    }
+  } else {
+    status = texport.SaveToBuffer();
+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to save to model buffer.");
+  }
+  return RET_OK;
+}
+
 template <typename DestType>
 int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
                               FormatType format, std::vector<std::string> out_put_tensor_name) {
   if constexpr (std::is_same_v<DestType, const std::string &>) {
     MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
+    struct stat path_type;
+    if (stat(destination.c_str(), &path_type) == RET_OK) {
+      if (path_type.st_mode & S_IFDIR) {
+        MS_LOG(ERROR) << "Destination must be a file path, but it is a directory";
+        return RET_ERROR;
+      }
+    }
   } else if constexpr (std::is_same_v<DestType, Buffer *>) {
     MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
   } else {
@@ -1222,53 +1282,18 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
   MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty");
 
   bool orig_train_state = IsTrain();
-  Eval();
-  TrainExport texport(destination);
-  int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
-  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
-
-  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
-    std::vector<kernel::KernelExec *> export_kernels = {};
-    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
-    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
-    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
-  } else {
-    if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
-        std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
-          return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
-        })) {
-      status = texport.SaveModel(model_.get(), destination);
-      if (orig_train_state) Train();
-      return status;
-    } else {
-      status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
-                                 (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
-                                 model_.get(), quant_type);
-    }
+  int status = Eval();
+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed");
+  status = ExportByDifferentType(destination, model_type, quant_type, orig_train_state, out_put_tensor_name);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "ExportByDifferentType failed.";
+    return status;
   }
-  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
-
-  if (model_type == MT_INFERENCE) {
-    status = texport.TrainModelDrop();
-    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
-
-    status = texport.TrainModelFusion();
-    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
+  if (orig_train_state) {
+    status = Train();
+    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed");
   }
-
-  if constexpr (std::is_same_v<DestType, const std::string &>) {
-    status = texport.SaveToFile();
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "failed to save to " << destination;
-      return status;
-    }
-  } else {
-    status = texport.SaveToBuffer();
-    TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
-  }
-
-  if (orig_train_state) Train();
-  return status;
+  return RET_OK;
 }
 
 int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
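
ExportInner keeps a single template body that is dispatched at compile time on DestType: a file path gets the new stat() directory check, while a buffer destination is only checked for null. The self-contained sketch below reproduces that validation prologue under invented names; ValidateDestination and the stub Buffer are hypothetical stand-ins, not the MindSpore types.

#include <sys/stat.h>

#include <cstdio>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for the real buffer destination type.
struct Buffer {
  std::vector<char> data;
};

// Compile-time dispatch in the style of ExportInner: one template body,
// different validation depending on the destination type.
template <typename DestType>
int ValidateDestination(DestType destination) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    if (destination.empty()) {
      std::fprintf(stderr, "File name cannot be empty\n");
      return -1;
    }
    struct stat path_type;
    // Reject directories up front instead of failing later in the save step.
    if (stat(destination.c_str(), &path_type) == 0 && (path_type.st_mode & S_IFDIR)) {
      std::fprintf(stderr, "Destination must be a file path, but it is a directory\n");
      return -1;
    }
  } else if constexpr (std::is_same_v<DestType, Buffer *>) {
    if (destination == nullptr) {
      std::fprintf(stderr, "model buffer cannot be nullptr\n");
      return -1;
    }
  } else {
    static_assert(!std::is_same_v<DestType, DestType>, "Unsupported destination type");
  }
  return 0;
}

int main() {
  const std::string path = "/tmp";  // a directory, so path validation fails
  Buffer buf;
  int file_result = ValidateDestination<const std::string &>(path);
  int buffer_result = ValidateDestination(&buf);
  return (file_result != 0 && buffer_result == 0) ? 0 : 1;
}

Validating the destination before any graph work is done matches the early-return style of the refactor: by the time SaveToFile() or SaveToBuffer() runs, the destination is already known to be usable.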
diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
index 57f431e897c..ca98953f089 100644
--- a/mindspore/lite/src/train/train_session.h
+++ b/mindspore/lite/src/train/train_session.h
@@ -170,6 +170,9 @@ class TrainSession : virtual public lite::LiteSession {
                     const std::unordered_map<lite::Tensor *, size_t> &offset_map,
                     std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
   template <typename DestType>
+  int ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
+                            bool orig_train_state, std::vector<std::string> output_tensor_name = {});
+  template <typename DestType>
   int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
                   std::vector<std::string> out_put_tensor_name = {});
   std::map<Tensor *, Tensor *> restored_origin_tensors_;
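
The header declares ExportByDifferentType with the same shape as ExportInner: a member template with a defaulted std::vector<std::string> argument, defined out of line in train_session.cc. That links because the .cc file itself calls the template for every DestType it supports, which triggers implicit instantiation. A minimal single-file sketch of the same pattern, with invented names:

#include <string>
#include <vector>

class Exporter {
 public:
  // Non-template entry points that pin down the two instantiations.
  int ExportToFile(const std::string &path);
  int ExportToBuffer(std::vector<char> *buffer);

 private:
  // The default argument lives on the declaration only; repeating it at the
  // out-of-line definition would be ill-formed.
  template <typename DestType>
  int ExportInner(DestType destination, std::vector<std::string> outputs = {});
};

template <typename DestType>
int Exporter::ExportInner(DestType destination, std::vector<std::string> outputs) {
  (void)destination;
  (void)outputs;
  return 0;
}

// Calling the template from the same translation unit implicitly instantiates
// ExportInner<const std::string &> and ExportInner<std::vector<char> *>.
int Exporter::ExportToFile(const std::string &path) { return ExportInner<const std::string &>(path); }
int Exporter::ExportToBuffer(std::vector<char> *buffer) { return ExportInner(buffer); }

int main() {
  Exporter exporter;
  std::vector<char> buffer;
  return exporter.ExportToFile("model.ms") + exporter.ExportToBuffer(&buffer);
}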