add warning/error logs for the ExportModel API in lite-training

zhangyanhui 2023-01-10 14:34:43 +08:00
parent 5c709121af
commit 2f6b023d85
5 changed files with 79 additions and 47 deletions


@@ -94,7 +94,7 @@ class MS_API Serialization {
///
/// \param[in] model The model data.
/// \param[in] model_type The model file type.
/// \param[in] model_file The exported model file.
/// \param[in] model_file The path of the exported model file.
/// \param[in] quantization_type The quantization type.
/// \param[in] export_inference_only Whether to export an inference-only model.
/// \param[in] output_tensor_name The names of the output tensors of the exported inference model, default as
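
For context, a minimal usage sketch of this API follows. The overload and default arguments are inferred from the parameter list above; the model type, file path, and error handling are illustrative assumptions, not part of this commit:

// Hedged usage sketch of Serialization::ExportModel; the model type,
// path, and argument values below are assumptions for illustration.
#include "include/api/serialization.h"

int ExportForInference(mindspore::Model *model) {
  // model_file must be a file path; with this commit the API logs a
  // warning/error for unsupported arguments instead of failing silently.
  auto status = mindspore::Serialization::ExportModel(
      *model, mindspore::kMindIR, "lenet_infer.ms",
      mindspore::kNoQuant, /*export_inference_only=*/true,
      /*output_tensor_name=*/{});
  return status == mindspore::kSuccess ? 0 : -1;
}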


@@ -24,6 +24,7 @@
#include "include/api/cfg.h"
#include "include/train/train_cfg.h"
#include "src/litert/inner_context.h"
#include "src/common/log_adapter.h"
namespace mindspore {
class ContextUtils {
@@ -56,6 +57,9 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
if (qt == kWeightQuant) {
return lite::QT_WEIGHT;
}
if (qt == kFullQuant || qt == kUnknownQuantType) {
MS_LOG(WARNING) << "QuantizationType " << qt << " does not support, set the quantizationType to default.";
}
return lite::QT_DEFAULT;
}
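
The net effect of the new branch: kNoQuant and kWeightQuant map directly, while every other value now warns before falling back. A standalone sketch with the enums and logging stubbed out (not the real headers):

#include <iostream>

enum QuantizationType { kNoQuant, kWeightQuant, kFullQuant, kUnknownQuantType };
enum LiteQT { QT_NONE, QT_WEIGHT, QT_DEFAULT };

LiteQT ConvertQT(QuantizationType qt) {
  if (qt == kNoQuant) return QT_NONE;
  if (qt == kWeightQuant) return QT_WEIGHT;
  // Previously this fell through silently; now the caller is warned.
  std::cerr << "WARNING: QuantizationType " << qt
            << " is not supported; falling back to the default.\n";
  return QT_DEFAULT;
}

int main() {
  ConvertQT(kWeightQuant);       // maps directly, no warning
  ConvertQT(kFullQuant);         // warns, returns QT_DEFAULT
  ConvertQT(kUnknownQuantType);  // warns, returns QT_DEFAULT
  return 0;
}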


@@ -61,7 +61,7 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
return lite::QT_WEIGHT;
}
if (qt == kFullQuant || qt == kUnknownQuantType) {
MS_LOG(WARNING) << qt << " does not support, set the quantizationType to default.";
MS_LOG(WARNING) << "QuantizationType " << qt << " does not support, set the quantizationType to default.";
}
return lite::QT_DEFAULT;
}


@@ -1204,11 +1204,71 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
return RET_OK;
}
template <typename DestType>
int TrainSession::ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
bool orig_train_state, std::vector<std::string> output_tensor_name) {
TrainExport texport(destination);
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
if (!output_tensor_name.empty() && model_type == MT_INFERENCE) {
std::vector<kernel::KernelExec *> export_kernels = {};
status = FindExportKernels(&export_kernels, output_tensor_name, inference_kernels_);
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
status = texport.ExportNet(export_kernels, tensors_, output_tensor_name, model_.get(), quant_type);
} else {
if (!output_tensor_name.empty() && model_type == MT_TRAIN) {
MS_LOG(WARNING) << "Train model does not support to export selected output tensor, and all of the train kernels "
"tensors will be exported";
}
if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
})) {
status = texport.SaveModel(model_.get(), destination);
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to save model");
if (orig_train_state) {
status = Train();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed.");
}
return status;
} else {
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
model_.get(), quant_type);
}
}
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
if (model_type == MT_INFERENCE) {
status = texport.TrainModelDrop();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
status = texport.TrainModelFusion();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
}
if constexpr (std::is_same_v<DestType, const std::string &>) {
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << destination;
return status;
}
} else {
status = texport.SaveToBuffer();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
}
return RET_OK;
}
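
The if constexpr branch at the end is what lets one template body serve both destinations: SaveToFile is only compiled for the string instantiation and SaveToBuffer only for the Buffer* one. A self-contained sketch of that dispatch pattern, with the TrainExport calls replaced by stubs:

#include <cstdio>
#include <string>
#include <type_traits>

struct Buffer {};  // stand-in for mindspore::Buffer

int SaveToFile(const std::string &path) { std::printf("save to %s\n", path.c_str()); return 0; }
int SaveToBuffer(Buffer *) { std::printf("save to buffer\n"); return 0; }

template <typename DestType>
int Save(DestType destination) {
  if constexpr (std::is_same_v<DestType, const std::string &>) {
    return SaveToFile(destination);    // compiled only for the string instantiation
  } else {
    return SaveToBuffer(destination);  // compiled only for the Buffer* instantiation
  }
}

int main() {
  std::string path = "model.ms";
  Buffer buf;
  Save<const std::string &>(path);  // explicit type argument, as in ExportByDifferentType<DestType>
  Save<Buffer *>(&buf);
  return 0;
}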
template <typename DestType>
int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
FormatType format, std::vector<std::string> out_put_tensor_name) {
if constexpr (std::is_same_v<DestType, const std::string &>) {
MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
struct stat path_type;
if (stat(destination.c_str(), &path_type) == RET_OK) {
if (path_type.st_mode & S_IFDIR) {
MS_LOG(ERROR) << "Destination must be path, now is a directory";
return RET_ERROR;
}
}
} else if constexpr (std::is_same_v<DestType, Buffer *>) {
MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
} else {
@@ -1222,53 +1282,18 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Only FT_FLATBUFFERS format is supported");
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(destination);
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
std::vector<kernel::KernelExec *> export_kernels = {};
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
} else {
if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
})) {
status = texport.SaveModel(model_.get(), destination);
if (orig_train_state) Train();
return status;
} else {
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
model_.get(), quant_type);
}
int status = Eval();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed");
status = ExportByDifferentType<DestType>(destination, model_type, quant_type, orig_train_state, out_put_tensor_name);
if (status != RET_OK) {
MS_LOG(ERROR) << "Fail to export by different type";
return status;
}
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
if (model_type == MT_INFERENCE) {
status = texport.TrainModelDrop();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
status = texport.TrainModelFusion();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
if (orig_train_state) {
status = Train();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed");
}
if constexpr (std::is_same_v<DestType, const std::string &>) {
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << destination;
return status;
}
} else {
status = texport.SaveToBuffer();
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
}
if (orig_train_state) Train();
return status;
return RET_OK;
}
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
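
The destination check added at the top of ExportInner rejects an existing directory before any export work happens. A simplified standalone version of that validation (POSIX stat(), error reporting reduced to stderr):

#include <sys/stat.h>
#include <cstdio>
#include <string>

bool IsValidExportPath(const std::string &destination) {
  if (destination.empty()) {
    std::fprintf(stderr, "File name cannot be empty\n");
    return false;
  }
  struct stat path_type {};
  if (stat(destination.c_str(), &path_type) == 0 && (path_type.st_mode & S_IFDIR) != 0) {
    std::fprintf(stderr, "Destination must be a file path, but %s is a directory\n",
                 destination.c_str());
    return false;
  }
  return true;  // a nonexistent path is fine: the export creates the file
}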


@@ -170,6 +170,9 @@ class TrainSession : virtual public lite::LiteSession {
const std::unordered_map<lite::Tensor *, size_t> &offset_map,
std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
template <typename DestType>
int ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
bool orig_train_state, std::vector<std::string> output_tensor_name = {});
template <typename DestType>
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
std::vector<std::string> out_put_tensor_name = {});
std::map<Tensor *, Tensor *> restored_origin_tensors_;
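
Given these declarations, the two public Export overloads presumably do nothing more than instantiate the shared template, one per destination type. Their bodies are outside this diff, so the following forwarding sketch is an assumption:

// Hypothetical forwarding bodies; the real implementations are not
// shown in this commit.
int TrainSession::Export(const std::string &file_name, ModelType model_type,
                         QuantizationType quant_type, FormatType format,
                         std::vector<std::string> out_put_tensor_name) {
  return ExportInner<const std::string &>(file_name, model_type, quant_type,
                                          format, out_put_tensor_name);
}

int TrainSession::Export(Buffer *model_buffer, ModelType model_type,
                         QuantizationType quant_type, FormatType format,
                         std::vector<std::string> out_put_tensor_name) {
  return ExportInner<Buffer *>(model_buffer, model_type, quant_type,
                               format, out_put_tensor_name);
}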