forked from mindspore-Ecosystem/mindspore
add warning/error logs of ExportModel API in lite-training
This commit is contained in:
parent
5c709121af
commit
2f6b023d85
|
@ -94,7 +94,7 @@ class MS_API Serialization {
|
|||
///
|
||||
/// \param[in] model The model data.
|
||||
/// \param[in] model_type The model file type.
|
||||
/// \param[in] model_file The exported model file.
|
||||
/// \param[in] model_file The path of exported model file.
|
||||
/// \param[in] quantization_type The quantification type.
|
||||
/// \param[in] export_inference_only Whether to export a reasoning only model.
|
||||
/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "include/api/cfg.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "src/litert/inner_context.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ContextUtils {
|
||||
|
@ -56,6 +57,9 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
|
|||
if (qt == kWeightQuant) {
|
||||
return lite::QT_WEIGHT;
|
||||
}
|
||||
if (qt == kFullQuant || qt == kUnknownQuantType) {
|
||||
MS_LOG(WARNING) << "QuantizationType " << qt << " does not support, set the quantizationType to default.";
|
||||
}
|
||||
return lite::QT_DEFAULT;
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
|
|||
return lite::QT_WEIGHT;
|
||||
}
|
||||
if (qt == kFullQuant || qt == kUnknownQuantType) {
|
||||
MS_LOG(WARNING) << qt << " does not support, set the quantizationType to default.";
|
||||
MS_LOG(WARNING) << "QuantizationType " << qt << " does not support, set the quantizationType to default.";
|
||||
}
|
||||
return lite::QT_DEFAULT;
|
||||
}
|
||||
|
|
|
@ -1204,11 +1204,71 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename DestType>
|
||||
int TrainSession::ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
|
||||
bool orig_train_state, std::vector<std::string> output_tensor_name) {
|
||||
TrainExport texport(destination);
|
||||
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
|
||||
if (!output_tensor_name.empty() && model_type == MT_INFERENCE) {
|
||||
std::vector<kernel::KernelExec *> export_kernels = {};
|
||||
status = FindExportKernels(&export_kernels, output_tensor_name, inference_kernels_);
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
|
||||
status = texport.ExportNet(export_kernels, tensors_, output_tensor_name, model_.get(), quant_type);
|
||||
} else {
|
||||
if (!output_tensor_name.empty() && model_type == MT_TRAIN) {
|
||||
MS_LOG(WARNING) << "Train model does not support to export selected output tensor, and all of the train kernels "
|
||||
"tensors will be exported";
|
||||
}
|
||||
if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
|
||||
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
|
||||
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
|
||||
})) {
|
||||
status = texport.SaveModel(model_.get(), destination);
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to save model");
|
||||
if (orig_train_state) {
|
||||
status = Train();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed.");
|
||||
}
|
||||
return status;
|
||||
} else {
|
||||
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
||||
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
|
||||
model_.get(), quant_type);
|
||||
}
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
|
||||
if (model_type == MT_INFERENCE) {
|
||||
status = texport.TrainModelDrop();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
|
||||
status = texport.TrainModelFusion();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
|
||||
}
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << destination;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
status = texport.SaveToBuffer();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename DestType>
|
||||
int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty");
|
||||
struct stat path_type;
|
||||
if (stat(destination.c_str(), &path_type) == RET_OK) {
|
||||
if (path_type.st_mode & S_IFDIR) {
|
||||
MS_LOG(ERROR) << "Destination must be path, now is a directory";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<DestType, Buffer *>) {
|
||||
MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr");
|
||||
} else {
|
||||
|
@ -1222,53 +1282,18 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
|
|||
MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty");
|
||||
|
||||
bool orig_train_state = IsTrain();
|
||||
Eval();
|
||||
TrainExport texport(destination);
|
||||
int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
|
||||
|
||||
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
|
||||
std::vector<kernel::KernelExec *> export_kernels = {};
|
||||
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
|
||||
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
|
||||
} else {
|
||||
if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
|
||||
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
|
||||
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
|
||||
})) {
|
||||
status = texport.SaveModel(model_.get(), destination);
|
||||
if (orig_train_state) Train();
|
||||
return status;
|
||||
} else {
|
||||
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
||||
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
|
||||
model_.get(), quant_type);
|
||||
}
|
||||
int status = Eval();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed");
|
||||
status = ExportByDifferentType<DestType>(destination, model_type, quant_type, orig_train_state, out_put_tensor_name);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Fail to export by different type";
|
||||
return status;
|
||||
}
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
|
||||
|
||||
if (model_type == MT_INFERENCE) {
|
||||
status = texport.TrainModelDrop();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
|
||||
|
||||
status = texport.TrainModelFusion();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
|
||||
if (orig_train_state) {
|
||||
status = Train();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed");
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<DestType, const std::string &>) {
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << destination;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
status = texport.SaveToBuffer();
|
||||
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer.");
|
||||
}
|
||||
|
||||
if (orig_train_state) Train();
|
||||
return status;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
||||
|
|
|
@ -170,6 +170,9 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
const std::unordered_map<lite::Tensor *, size_t> &offset_map,
|
||||
std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx);
|
||||
template <typename DestType>
|
||||
int ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type,
|
||||
bool orig_train_state, std::vector<std::string> output_tensor_name = {});
|
||||
template <typename DestType>
|
||||
int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||
std::vector<std::string> out_put_tensor_name = {});
|
||||
std::map<Tensor *, Tensor *> restored_origin_tensors_;
|
||||
|
|
Loading…
Reference in New Issue