!16387 [MS][LITE][TOD] Revert to previous TrainSession API (with additional Config Parameter)

From: @ehaleva
Reviewed-by: @HilbertDavid,@hangangqiang
Signed-off-by: @HilbertDavid
mindspore-ci-bot 2021-05-17 17:28:28 +08:00 committed by Gitee
commit 6dd4881e63
17 changed files with 231 additions and 541 deletions

View File

@@ -100,12 +100,7 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
model_ = mindspore::lite::Model::Import(ms_file_.c_str());
if (model_ == nullptr) {
std::cout << "import model failed" << std::endl;
return;
}
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context, true);
MS_ASSERT(nullptr != session_);
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
@@ -172,7 +167,7 @@ int NetRunner::TrainLoop() {
mindspore::lite::LossMonitor lm(100);
mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
Rescaler rescale(255.0);
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
@@ -190,7 +185,7 @@ int NetRunner::Main() {
if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
Model::Export(model_, trained_fn.c_str());
session_->Export(trained_fn);
}
return 0;
}
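Taken together, the example app no longer imports a Model itself: the session is created directly from the .ms file, checkpointing needs no model pointer, and export goes through the session. A minimal sketch of the resulting flow (the path is a placeholder):

// Hedged sketch of the reverted file-based API; "lenet_tod.ms" is an illustrative path.
mindspore::lite::Context context;
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
// CreateSession imports the flatbuffer internally; no Model::Import step is needed.
auto *session = mindspore::session::TrainSession::CreateSession("lenet_tod.ms", &context, true);
if (session == nullptr) return;
// ... train ...
session->Export("lenet_tod_trained.ms");  // serialize the trained model via the session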

View File

@@ -27,7 +27,6 @@
#include "include/train/accuracy_metrics.h"
#include "include/ms_tensor.h"
#include "include/dataset/datasets.h"
#include "include/model.h"
using mindspore::dataset::Dataset;
using mindspore::lite::AccuracyMetrics;
@@ -37,7 +36,6 @@ class NetRunner {
int Main();
bool ReadArgs(int argc, char *argv[]);
~NetRunner();
mindspore::lite::Model *model_ = nullptr;
private:
void Usage();

View File

@@ -112,6 +112,9 @@ if [ "${TARGET}" == "arm64" ]; then
echo "========Training on Device====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh"
echo "===Evaluating trained Model====="
adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh"
echo
else
cd ${PACKAGE} || exit 1
echo "==Evaluating Untrained Model==="
@@ -120,5 +123,8 @@ else
echo "======Training Locally========="
./train.sh
echo "===Evaluating trained Model====="
./eval.sh
cd ..
fi

View File

@@ -187,7 +187,7 @@ int NetRunner::TrainLoop() {
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
auto cpkt_fn =
ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
session_->Export(cpkt_fn);
}
std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
@@ -213,12 +213,12 @@ int NetRunner::Main() {
float acc = CalculateAccuracy(ds_.val_data(), session_);
std::cout << "accuracy on validation data = " << acc << std::endl;
if (cycles_ > 0 && head_model_ != nullptr) {
if (cycles_ > 0) {
auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
session_->Export(trained_fn);
}
if (!save_inference_.empty()) {
int status = session_->ExportInference(save_inference_);
int status = session_->Export(save_inference_, mindspore::lite::MT_INFERENCE);
if (status != mindspore::lite::RET_OK) {
std::cout << "Failed to save inference file";
return mindspore::lite::RET_ERROR;

View File

@@ -44,8 +44,6 @@ class NetRunner {
DataSet ds_;
mindspore::session::TrainSession *session_ = nullptr;
mindspore::lite::Model *backbone_model_ = nullptr;
mindspore::lite::Model *head_model_ = nullptr;
std::string ms_backbone_file_ = "";
std::string ms_head_file_ = "";

View File

@@ -29,8 +29,8 @@ namespace lite {
class CkptSaver : public session::TrainLoopCallBack {
public:
CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
CkptSaver(int save_every_n, const std::string &filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
~CkptSaver() = default;
@@ -38,7 +38,7 @@ class CkptSaver : public session::TrainLoopCallBack {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
Model::Export(model_, cpkt_fn.c_str());
cb_data.session_->Export(cpkt_fn);
}
return session::RET_CONTINUE;
}
@@ -46,7 +46,6 @@ class CkptSaver : public session::TrainLoopCallBack {
private:
int save_every_n_;
std::string filename_prefix_;
mindspore::lite::Model *model_ = nullptr;
};
} // namespace lite
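With the Model pointer gone from the constructor, the saver exports through the session carried in the callback data. A usage sketch matching the lenet example above (interval and prefix are illustrative):

// Every save_every_n epochs this writes "<prefix>_trained_<epoch>.ms"
// via cb_data.session_->Export().
mindspore::lite::CkptSaver cs(/*save_every_n=*/1000, std::string("lenet"));
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&cs});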

View File

@@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
#include <string>
namespace mindspore {
namespace lite {
/// \brief MixPrecisionCfg holds the mix precision training configuration.
class MixPrecisionCfg {
public:
MixPrecisionCfg() {
this->dynamic_loss_scale_ = false;
this->loss_scale_ = 128.0f;
this->keep_batchnorm_fp32_ = true;
this->num_of_not_nan_iter_th_ = 1000;
}
MixPrecisionCfg(const MixPrecisionCfg &rhs) {
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
this->loss_scale_ = rhs.loss_scale_;
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
}
MixPrecisionCfg &operator=(MixPrecisionCfg const &rhs) {
this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
this->loss_scale_ = rhs.loss_scale_;
this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
return *this;
}
bool dynamic_loss_scale_; /**< Enable/disable dynamic loss scale during mix precision training */
float loss_scale_; /**< Initial loss scale factor */
bool keep_batchnorm_fp32_; /**< Keep batch norm in FP32 while training */
int num_of_not_nan_iter_th_; /**< A threshold for modifying loss scale when dynamic loss scale is enabled */
};
/// \brief TrainCfg holds the training configuration.
class TrainCfg {
public:
TrainCfg() { this->loss_name_ = "_loss_fn"; }
TrainCfg(const TrainCfg &rhs) {
this->loss_name_ = rhs.loss_name_;
this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
}
TrainCfg &operator=(const TrainCfg &rhs) {
this->loss_name_ = rhs.loss_name_;
this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
return *this;
}
std::string loss_name_; /**< Part of the name that identifies a loss kernel */
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
};
typedef enum {
FT_FLATBUFFER, // Flatbuffer format
FT_MINDIR // MINDIR format
} FormatType;
typedef enum {
QT_DEFAULT, // the quantization of the original model is applied
QT_NONE, // apply no quantization
QT_WEIGHT // apply weight quantization
} QuantType;
typedef enum {
MT_TRAIN, // both the train and inference parts of the compiled model are serialized
MT_INFERENCE // only the inference part of the compiled model is serialized
} ModelType;
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
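As a usage sketch, both configuration classes are plain value types filled in before session creation; the values below merely restate the constructor defaults, and the file name and context are placeholders:

mindspore::lite::Context context;
context.thread_num_ = 1;
mindspore::lite::TrainCfg cfg;
cfg.loss_name_ = "_loss_fn";                         // substring that marks loss kernels
cfg.mix_precision_cfg_.dynamic_loss_scale_ = false;  // fixed loss scale of 128.0f
// The config is passed by pointer; nullptr selects these defaults.
auto *session = mindspore::session::TrainSession::CreateSession("net.ms", &context, true, &cfg);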

View File

@@ -20,6 +20,7 @@
#include <tuple>
#include "include/lite_session.h"
#include "include/errorcode.h"
#include "include/train/train_cfg.h"
namespace mindspore {
namespace session {
@@ -32,26 +33,14 @@ class TrainSession : public session::LiteSession {
/// \brief Static method to create a TrainSession object
///
/// \param[in] model A buffer that was read from a MS model file
/// \param[in] filename name of the file that holds the model flatbuffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
/// \param[in] cfg training configuration, set to null for default configuration
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false);
/// \brief Static method to create a TrainSession object with transfer learning support
///
/// \param[in] model_buf_backbone A buffer that was read from a backbone MS model file
/// \param[in] size_backbone Length of the backbone net buffer
/// \param[in] model_buf_head A buffer that was read from a head MS model file
/// \param[in] size_head Length of the head net buffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateTransferSession(const char *model_buf_backbone, size_t size_backbone,
const char *model_buf_head, size_t size_head, lite::Context *context,
bool train_mode = false);
static TrainSession *CreateSession(const std::string &filename, const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
/// \brief Static method to create a TrainSession object
///
@@ -62,7 +51,8 @@ class TrainSession : public session::LiteSession {
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
lite::Context *context, bool train_mode = false);
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
@@ -108,25 +98,21 @@ class TrainSession : public session::LiteSession {
/// \return a vector of output tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetPredictions() const = 0;
/// \brief Set part of the name that identifies a loss kernel
/// \param[in] loss_name Identification name for loss kernels
/// \brief Save model
/// \param[in] file_name pretrained model file name prefix. '.ms' extension is appended if it does not exist
/// \param[in] model_type indicates whether to save the full model or only the inference part
/// \param[in] quant_type indicates whether to quantize the exported model
/// \param[in] format format of the exported file (currently only FT_FLATBUFFER is supported)
/// \return STATUS as an error code of the export operation, STATUS is defined in errorcode.h
virtual int SetLossName(std::string loss_name) {
loss_name_ = loss_name;
return mindspore::lite::RET_OK;
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
lite::QuantType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFER) {
return mindspore::lite::RET_ERROR;
}
/// \brief Save model for inference (LiteSession)
/// \param[in] fb_name pretrained model file name prefix. '.ms' is added as extension.
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int ExportInference(std::string fb_name) { return mindspore::lite::RET_ERROR; }
protected:
bool train_mode_ = false;
std::string get_loss_name() const { return loss_name_; }
private:
std::string loss_name_ = "_loss_fn";
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
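The single Export entry point now covers both the old Model::Export path (MT_TRAIN, the default) and the old ExportInference path (MT_INFERENCE). A hedged sketch of the two calls, with placeholder file names:

int ret = session->Export("model_trained.ms");  // full trainable model (MT_TRAIN by default)
ret = session->Export("model_infer.ms", mindspore::lite::MT_INFERENCE);  // inference graph only
// Per the implementation further down, a format other than FT_FLATBUFFER or a
// quant_type other than QT_DEFAULT currently returns RET_ERROR.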

View File

@@ -255,11 +255,12 @@ int TrainExport::AddTransformNode() {
int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors,
const std::vector<std::string> &output_names, const Model *model) {
const std::vector<std::string> &output_names, const Model *model, QuantType quant_type) {
std::vector<size_t> map_index;
std::set<size_t> out_set;
int offset = meta_graph_->allTensors.size();
int tensor_idx = offset;
quant_type_ = quant_type;
if (meta_graph_ == nullptr) {
int status = ExportInit(model->name_, model->version_);

View File

@@ -23,6 +23,7 @@
#include "schema/inner/model_generated.h"
#include "src/lite_kernel.h"
#include "src/lite_model.h"
#include "include/train/train_cfg.h"
namespace mindspore {
#ifndef _STUB
@@ -40,7 +41,7 @@ class TrainExport {
virtual ~TrainExport();
int ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
const Model *model);
const Model *model, QuantType quant_type);
int ExportInit(const std::string model_name, std::string version);
int SaveToFile();
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
@@ -69,6 +70,7 @@ class TrainExport {
bool NeedQuantization(const mindspore::lite::Tensor *tensor);
virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor);
mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel);
QuantType quant_type_;
};
}; // namespace lite
} // namespace mindspore

View File

@@ -55,6 +55,13 @@ TrainSession::TrainSession() {
}
}
int TrainSession::Init(const Context *context, const TrainCfg *train_cfg) {
if (train_cfg != nullptr) {
train_cfg_ = *train_cfg;
}
return lite::LiteSession::Init(context);
}
std::vector<CreatorOp> TrainSession::ReplaceOps() {
const std::vector<CreatorOp> replace = {
// currently no ops are Hijacked by TrainSession
@@ -421,7 +428,7 @@ bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
kernel->type() == schema::PrimitiveType_SmoothL1LossGrad ||
kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad) ||
kernel->name().find(get_loss_name()) != std::string::npos;
kernel->name().find(train_cfg_.loss_name_) != std::string::npos;
}
bool TrainSession::IsGradKernel(const kernel::LiteKernel *kernel) const {
@@ -442,19 +449,21 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
(kernel->type() == schema::PrimitiveType_FusedBatchNorm));
}
int TrainSession::SetLossName(std::string loss_name) {
session::TrainSession::SetLossName(loss_name);
CompileEvalOutputs();
CompileInferenceKernels();
if (IsEval()) {
output_node_map_ = eval_output_node_map_;
output_tensor_map_ = eval_output_tensor_map_;
output_tensor_names_ = eval_output_tensor_names_;
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantType quant_type, FormatType format) {
if (format != FT_FLATBUFFER) {
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
return RET_ERROR;
}
if (quant_type != QT_DEFAULT) {
MS_LOG(ERROR) << "Currently only QuantType default is supported";
return RET_ERROR;
}
if (model_type == MT_TRAIN) {
return lite::Model::Export(model_, file_name.c_str());
}
return RET_OK;
}
int TrainSession::ExportInference(std::string file_name) {
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(file_name);
@@ -463,7 +472,8 @@ int TrainSession::ExportInference(std::string file_name) {
MS_LOG(ERROR) << "cannot init export";
return status;
}
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
status = texport.ExportNet((model_type == MT_TRAIN) ? kernels_ : inference_kernels_, tensors_, GetOutputTensorNames(),
model_, quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot export Network";
return status;
@@ -479,19 +489,25 @@ int TrainSession::ExportInference(std::string file_name) {
} // namespace lite
session::TrainSession *session::TrainSession::CreateSession(mindspore::lite::Model *model, lite::Context *context,
bool train_mode) {
session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
auto session = new (std::nothrow) lite::TrainSession();
if (session == nullptr) {
delete model;
MS_LOG(ERROR) << "create session failed";
return nullptr;
}
auto ret = session->Init(context);
auto ret = session->Init(context, cfg);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";
delete session;
delete model;
return nullptr;
}
auto *model = mindspore::lite::Model::Import(filename.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
delete session;
return nullptr;
}

View File

@@ -54,12 +54,13 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
int CompileGraph(lite::Model *model) override;
virtual int CompileTrainGraph(lite::Model *model);
virtual int Init(const Context *context, const TrainCfg *train_cfg);
int Train() override;
int Eval() override;
int SetLearningRate(float learning_rate) override;
float GetLearningRate() override;
int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;
int SetLossName(std::string loss_name) override;
void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }
std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
@@ -88,7 +89,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
}
return outputs;
}
int ExportInference(std::string file_name) override;
int Export(const std::string &fb_name, ModelType model_type, QuantType quant_type, FormatType) override;
protected:
void AllocWorkSpace();
@@ -107,6 +108,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
virtual void CompileEvalOutputs();
Model *model_ = nullptr;
TrainCfg train_cfg_;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
std::vector<std::string> orig_output_tensor_names_;

View File

@@ -40,7 +40,7 @@
namespace mindspore {
namespace lite {
TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context)
TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context)
: is_valid_(false) {
lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
size_backbone_ = size_backbone;
@@ -179,10 +179,25 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
return map;
}
int TransferSession::ExportInference(std::string file_name) {
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantType quant_type,
FormatType format) {
if (format != FT_FLATBUFFER) {
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
return RET_ERROR;
}
if (quant_type != QT_DEFAULT) {
MS_LOG(ERROR) << "Currently only QuantType default is supported";
return RET_ERROR;
}
if (model_type == MT_TRAIN) {
return TrainSession::Export(filename, model_type, quant_type, format);
}
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(file_name);
TrainExport texport(filename);
int status = texport.LoadModel(lite_model_, size_backbone_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot init export";
@@ -197,14 +212,14 @@ int TransferSession::ExportInference(std::string file_name) {
return status;
}
}
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_, quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot serialize head";
return status;
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;
MS_LOG(ERROR) << "failed to save to " << filename;
return status;
}
if (orig_train_state) Train();
@@ -213,10 +228,10 @@ int TransferSession::ExportInference(std::string file_name) {
} // namespace lite
session::TrainSession *session::TrainSession::CreateTransferSession(const char *model_buf_backbone,
size_t size_backbone, const char *model_buf_head,
size_t size_head, lite::Context *context,
bool train_mode) {
static session::TrainSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
const char *model_buf_head, size_t size_head,
const lite::Context *context, bool train_mode,
const lite::TrainCfg *cfg) {
auto ValidModelSize = [](size_t size) -> bool {
constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1 GB
return size < MaxModelSize && size > 0;
@@ -240,7 +255,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
return nullptr;
}
auto ret = session->Init(context);
auto ret = session->Init(context, cfg);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "init transfer session failed";
delete session;
@@ -282,9 +297,10 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
session::TrainSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
const std::string &filename_head,
lite::Context *context, bool train_mode) {
size_t size_head = -1;
size_t size_backbone = -1;
const lite::Context *ctxt, bool train_mode,
const lite::TrainCfg *cfg) {
size_t size_head = 0;
size_t size_backbone = 0;
auto buf_head = lite::ReadFileToBuf(filename_head, &size_head);
if (buf_head == nullptr) {
return nullptr;
@@ -293,8 +309,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const std::s
if (buf_backbone == nullptr) {
return nullptr;
}
return session::TrainSession::CreateTransferSession(buf_backbone.get(), size_backbone, buf_head.get(), size_head,
context, train_mode);
return CreateTransferSessionInt(buf_backbone.get(), size_backbone, buf_head.get(), size_head, ctxt, train_mode, cfg);
}
} // namespace mindspore
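The buffer-based overload has become the file-local helper CreateTransferSessionInt; public callers now pass file names, and the optional TrainCfg is forwarded to Init. A sketch with placeholder paths:

mindspore::lite::Context context;
context.thread_num_ = 2;
mindspore::lite::TrainCfg cfg;  // optional; passing nullptr keeps the defaults
auto *session = mindspore::session::TrainSession::CreateTransferSession(
    "backbone.ms", "head.ms", &context, /*train_mode=*/true, &cfg);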

View File

@@ -48,7 +48,7 @@ namespace lite {
class TransferSession : public lite::TrainSession {
public:
explicit TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context);
explicit TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context);
~TransferSession();
@@ -61,7 +61,7 @@ class TransferSession : public lite::TrainSession {
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;
int CompileTransferGraph();
int ExportInference(std::string file_name) override;
int Export(const std::string &fb_name, ModelType model_type, QuantType quant_type, FormatType) override;
protected:
lite::LiteSession *backbone_session_ = nullptr;

View File

@@ -42,422 +42,6 @@ class NetworkTest : public mindspore::CommonTest {
int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
const char *tensor_name, bool debug = false);
// INPUT(0)
// V
// +-------------+
// | ReLU |
// +-------------+
// +---output(1) V
// | V V weights(2) <----+
// | +-------------+ |
// | | MatMul | |
// | +-------------+ |
// | output(3) V |
// | V V weights(4)<-+ |
// | +-------------+ | |
// | | Bias | | |
// | +-------------+ | |
// | output(5) V | |
// | V V LABELS(6) | |
// | +-------------+ | |
// | | CrossEntropy| | |
// | +-------------+ | |
// | +-dy(7) V V------------------------->Loss (14)
// | | V | |
// | | +-------------+ | |
// | | | BiasGrad | | |
// | | +-------------+ | |
// | | V db(8) | |
// | | +--------Update---+ |
// | +-------+ |
// +------V V |
// +-------------+ |
// | MatMul | |
// +-------------+ |
// V dw(9) |
// +-----------Update-----+
TEST_F(NetworkTest, tuning_layer) {
const int BATCH_SIZE = 32;
const int NUM_CLASSES = 10;
const int FEATURE_SIZE = 1000;
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// define nodes
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {0};
node->outputIndex = {1};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Activation;
auto primitive = new schema::ActivationT;
ASSERT_NE(primitive, nullptr);
primitive->activation_type = schema::ActivationType_RELU;
node->primitive->value.value = primitive;
node->name = "ReLU";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {1, 2};
node->outputIndex = {3};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_MatMul;
auto primitive = new schema::MatMulT;
ASSERT_NE(primitive, nullptr);
primitive->transpose_a = false;
primitive->transpose_b = true;
node->primitive->value.value = primitive;
node->name = "MatMul1";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {3, 4};
node->outputIndex = {5};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_BiasAdd;
auto primitive = new schema::BiasAddT;
ASSERT_NE(primitive, nullptr);
node->primitive->value.value = primitive;
node->name = "BiasAdd";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {5, 6};
node->outputIndex = {14, 7};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
auto primitive = new schema::SoftmaxCrossEntropyWithLogitsT;
ASSERT_NE(primitive, nullptr);
node->primitive->value.value = primitive;
node->name = "SoftmaxCrossEntropyWithLogits";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {7};
node->outputIndex = {8};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_BiasAddGrad;
auto primitive = new schema::BiasAddGradT;
ASSERT_NE(primitive, nullptr);
node->primitive->value.value = primitive;
node->name = "BiasGrad";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {7, 1};
node->outputIndex = {9};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_MatMul;
auto primitive = new schema::MatMulT;
ASSERT_NE(primitive, nullptr);
primitive->transpose_a = true;
primitive->transpose_b = false;
node->primitive->value.value = primitive;
node->name = "MatMul2";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {2, 10, 11, 9, 12};
node->outputIndex = {};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_ApplyMomentum;
auto primitive = new schema::ApplyMomentumT;
ASSERT_NE(primitive, nullptr);
node->primitive->value.value = primitive;
node->name = "Momentum";
meta_graph->nodes.emplace_back(std::move(node));
}
{
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {4, 13, 11, 8, 12};
node->outputIndex = {};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_ApplyMomentum;
auto primitive = new schema::ApplyMomentumT;
ASSERT_NE(primitive, nullptr);
node->primitive->value.value = primitive;
node->name = "Momentum";
meta_graph->nodes.emplace_back(std::move(node));
}
meta_graph->inputIndex = {0, 6};
meta_graph->outputIndex = {5, 14};
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = lite::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {BATCH_SIZE, FEATURE_SIZE};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// tensor 1 - relu
auto relu_out = std::make_unique<schema::TensorT>();
relu_out->nodeType = lite::NodeType_Parameter;
relu_out->format = schema::Format_NHWC;
relu_out->dataType = TypeId::kNumberTypeFloat32;
relu_out->dims = {BATCH_SIZE, FEATURE_SIZE};
relu_out->offset = -1;
meta_graph->allTensors.emplace_back(std::move(relu_out));
// tensor 2 - matmul weights
auto weight = std::make_unique<schema::TensorT>();
weight->nodeType = lite::NodeType_ValueNode;
weight->format = schema::Format_KHWC;
weight->dataType = TypeId::kNumberTypeFloat32;
weight->dims = {NUM_CLASSES, FEATURE_SIZE};
size_t weight_size;
char *buf;
std::string weight_path = "./test_data/train/train_weight_10_1000.bin";
ReadFile(weight_path.c_str(), &weight_size, &buf);
ASSERT_NE(nullptr, buf);
weight->data.resize(weight_size);
std::copy(buf, buf + weight_size, weight->data.data());
meta_graph->allTensors.emplace_back(std::move(weight));
delete[] buf;
// tensor 3 - matmul
auto input3 = std::make_unique<schema::TensorT>();
input3->nodeType = lite::NodeType_Parameter;
input3->format = schema::Format_NHWC;
input3->dataType = TypeId::kNumberTypeFloat32;
input3->dims = {BATCH_SIZE, NUM_CLASSES};
input3->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input3));
// tensor 4 - fc bias
auto bias = std::make_unique<schema::TensorT>();
bias->nodeType = lite::NodeType_ValueNode;
bias->format = schema::Format_NHWC;
bias->dataType = TypeId::kNumberTypeFloat32;
bias->dims = {NUM_CLASSES};
bias->offset = -1;
std::string bias_path = "./test_data/train/train_bias_10.bin";
size_t bias_size;
ReadFile(bias_path.c_str(), &bias_size, &buf);
ASSERT_NE(nullptr, buf);
bias->data.resize(bias_size);
std::copy(buf, buf + bias_size, bias->data.data());
meta_graph->allTensors.emplace_back(std::move(bias));
delete[] buf;
// tensor 5 - bias_add
auto input5 = std::make_unique<schema::TensorT>();
input5->nodeType = lite::NodeType_Parameter;
input5->format = schema::Format_NHWC;
input5->dataType = TypeId::kNumberTypeFloat32;
input5->dims = {BATCH_SIZE, NUM_CLASSES};
input5->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input5));
// tensor 6 - Label
{
auto label = std::make_unique<schema::TensorT>();
label->nodeType = lite::NodeType_ValueNode;
label->format = schema::Format_NHWC;
label->dataType = TypeId::kNumberTypeFloat32;
label->dims = {BATCH_SIZE * NUM_CLASSES};
label->offset = -1;
meta_graph->allTensors.emplace_back(std::move(label));
}
// tensor 7 - Softmaxentropy
auto input7 = std::make_unique<schema::TensorT>();
input7->nodeType = lite::NodeType_Parameter;
input7->format = schema::Format_NHWC;
input7->dataType = TypeId::kNumberTypeFloat32;
input7->dims = {BATCH_SIZE, NUM_CLASSES};
input7->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input7));
// tensor 8 - biasGrad
auto input8 = std::make_unique<schema::TensorT>();
input8->nodeType = lite::NodeType_Parameter;
input8->format = schema::Format_NHWC;
input8->dataType = TypeId::kNumberTypeFloat32;
input8->dims = {NUM_CLASSES};
input8->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input8));
// tensor 9 - matmul2
auto input9 = std::make_unique<schema::TensorT>();
input9->nodeType = lite::NodeType_Parameter;
input9->format = schema::Format_NHWC;
input9->dataType = TypeId::kNumberTypeFloat32;
input9->dims = {NUM_CLASSES, FEATURE_SIZE};
input9->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input9));
// tensor 10 weights accumulate
auto input10 = std::make_unique<schema::TensorT>();
input10->nodeType = lite::NodeType_ValueNode;
input10->format = schema::Format_NHWC;
input10->dataType = TypeId::kNumberTypeFloat32;
input10->dims = {NUM_CLASSES, FEATURE_SIZE};
input10->offset = -1;
size_t input10_size = NUM_CLASSES * FEATURE_SIZE * sizeof(float);
input10->data.resize(input10_size);
std::fill(input10->data.data(), input10->data.data() + input10_size, 0.f);
meta_graph->allTensors.emplace_back(std::move(input10));
// tensor 11 - lr
{
auto lr = std::make_unique<schema::TensorT>();
lr->nodeType = lite::NodeType_ValueNode;
lr->format = schema::Format_NHWC;
lr->dataType = TypeId::kNumberTypeFloat32;
lr->dims = {1};
lr->offset = -1;
lr->data.resize(sizeof(float));
float *data = reinterpret_cast<float *>(lr->data.data());
*data = 0.01f;
meta_graph->allTensors.emplace_back(std::move(lr));
}
// tensor 12 - momentum
{
auto input12 = std::make_unique<schema::TensorT>();
input12->nodeType = lite::NodeType_ValueNode;
input12->format = schema::Format_NHWC;
input12->dataType = TypeId::kNumberTypeFloat32;
input12->dims = {1};
input12->offset = -1;
input12->data.resize(sizeof(float));
float *data = reinterpret_cast<float *>(input12->data.data());
*data = 0.f;
meta_graph->allTensors.emplace_back(std::move(input12));
}
// tensor 13 - bias accumulate
auto input13 = std::make_unique<schema::TensorT>();
input13->nodeType = lite::NodeType_ValueNode;
input13->format = schema::Format_NHWC;
input13->dataType = TypeId::kNumberTypeFloat32;
input13->dims = {NUM_CLASSES};
input13->offset = -1;
size_t input13_size = NUM_CLASSES * sizeof(float);
input13->data.resize(input13_size);
std::fill(input13->data.data(), input13->data.data() + input13_size, 0.f);
meta_graph->allTensors.emplace_back(std::move(input13));
// tensor 14 - loss
{
auto loss14 = std::make_unique<schema::TensorT>();
loss14->nodeType = lite::NodeType_ValueNode;
loss14->format = schema::Format_NHWC;
loss14->dataType = TypeId::kNumberTypeFloat32;
loss14->dims = {1};
loss14->offset = -1;
loss14->data.resize(sizeof(float));
float *data = reinterpret_cast<float *>(loss14->data.data());
*data = 0.0f;
meta_graph->allTensors.emplace_back(std::move(loss14));
}
//================================================================
buf = nullptr;
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize();
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
std::cout << "build fb size= " << size << std::endl;
auto *model = mindspore::lite::Model::Import(content, size);
ASSERT_NE(nullptr, model);
meta_graph.reset();
content = nullptr;
lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context.thread_num_ = 1;
auto session = session::TrainSession::CreateSession(model, &context);
ASSERT_NE(nullptr, session);
session->Train();
session->Train(); // Just double check that calling Train twice does not cause a problem
auto inputs = session->GetInputs();
ASSERT_EQ(inputs.size(), 2);
auto inTensor = inputs.at(0);
ASSERT_NE(nullptr, inTensor);
auto data = inTensor->MutableData();
//===================================================
size_t input_size;
std::string input_path = "./test_data/train/train_input_32_1000.bin";
ReadFile(input_path.c_str(), &input_size, &buf);
ASSERT_NE(nullptr, buf);
auto input_data = reinterpret_cast<float *>(buf);
ASSERT_NE(nullptr, input_data);
//===================================================
ASSERT_EQ(input_size, inTensor->Size());
memcpy(data, input_data, input_size);
delete[] buf;
auto labelTensor = inputs.at(1);
ASSERT_NE(nullptr, labelTensor);
ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum());
auto labels = reinterpret_cast<float *>(labelTensor->MutableData());
std::fill(labels, labels + labelTensor->ElementsNum(), 0.f);
for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0;
auto ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputsByNodeName("SoftmaxCrossEntropyWithLogits");
ASSERT_EQ(outputs.size(), 1);
auto outTensor = (outputs.at(0));
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
ASSERT_NE(nullptr, outData);
std::cout << "==============Initial=Loss=====================" << std::endl;
std::cout << outData[0] << ", " << std::endl;
session->Eval();
session->Eval(); // Just double check that calling eval twice does not cause a problem
ret = session->RunGraph();
outputs = session->GetOutputsByNodeName("BiasAdd");
ASSERT_EQ(outputs.size(), 1);
outTensor = (outputs.at(0));
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
outData = reinterpret_cast<float *>(outTensor->MutableData());
ASSERT_NE(nullptr, outData);
std::cout << "==============Scores=after-single=train========" << std::endl;
for (int i = 0; i < 10; i++) {
std::cout << outData[i] << ", ";
}
std::cout << std::endl;
std::string output_path = "./test_data/train/train_output_32_10.bin";
auto error = RelativeOutputError(outData, output_path);
EXPECT_LT(error, 2e-3);
ret = session->RunGraph();
auto all_output_tensors = session->GetOutputs();
outTensor = (all_output_tensors["5"]);
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
outData = reinterpret_cast<float *>(outTensor->MutableData());
ASSERT_NE(nullptr, outData);
std::cout << "==============Scores=eval-second-time==========" << std::endl;
for (int i = 0; i < 10; i++) {
std::cout << outData[i] << ", ";
}
std::cout << std::endl;
error = RelativeOutputError(outData, output_path);
EXPECT_LT(error, 2e-3);
session->Train();
session->Eval(); // do some more zig-zags
ret = session->RunGraph();
outTensor = session->GetOutputByTensorName("5");
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
outData = reinterpret_cast<float *>(outTensor->MutableData());
ASSERT_NE(nullptr, outData);
std::cout << "==============Scores=Just Checking 3rd time====" << std::endl;
for (int i = 0; i < 10; i++) {
std::cout << outData[i] << ", ";
}
std::cout << std::endl;
error = RelativeOutputError(outData, output_path);
EXPECT_LT(error, 2e-3);
}
int32_t fileIterator(mindspore::session::TrainSession *session, const std::string &path,
std::function<int32_t(mindspore::session::TrainSession *session, const std::string &)> cb) {
int32_t res = 0;
@@ -517,10 +101,7 @@ TEST_F(NetworkTest, efficient_net) {
context->thread_num_ = 1;
std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
auto *model = mindspore::lite::Model::Import(net.c_str());
ASSERT_NE(model, nullptr);
auto session = session::TrainSession::CreateSession(model, context, false);
auto session = session::TrainSession::CreateSession(net, context, false);
ASSERT_NE(session, nullptr);
std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
@@ -560,15 +141,13 @@ TEST_F(NetworkTest, mobileface_net) {
delete context;
}
TEST_F(NetworkTest, setname) {
TEST_F(NetworkTest, noname) {
std::string net = "./test_data/nets/lenet_train.ms";
lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context.thread_num_ = 1;
auto *model = mindspore::lite::Model::Import(net.c_str());
ASSERT_NE(model, nullptr);
auto session = mindspore::session::TrainSession::CreateSession(model, &context);
auto session = mindspore::session::TrainSession::CreateSession(net, &context);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();
@@ -577,20 +156,25 @@ TEST_F(NetworkTest, setname) {
EXPECT_EQ(tensors_map.begin()->first, "24");
EXPECT_EQ(tensor_names.size(), 1);
EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op107");
delete session;
}
auto res = session->SetLossName("nhwc");
EXPECT_EQ(res, RET_OK);
tensors_map = session->GetOutputs();
tensor_names = session->GetOutputTensorNames();
TEST_F(NetworkTest, setname) {
std::string net = "./test_data/nets/lenet_train.ms";
lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
context.thread_num_ = 1;
lite::TrainCfg train_cfg;
train_cfg.loss_name_ = "nhwc";
auto session = mindspore::session::TrainSession::CreateSession(net, &context, true, &train_cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();
auto tensor_names = session->GetOutputTensorNames();
EXPECT_EQ(tensors_map.begin()->first, "8");
EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/MaxPool-op88");
res = session->SetLossName("loss");
EXPECT_EQ(res, RET_OK);
tensors_map = session->GetOutputs();
tensor_names = session->GetOutputTensorNames();
EXPECT_EQ(tensors_map.begin()->first, "24");
EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op107");
delete session;
}

View File

@@ -317,33 +317,35 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = flags_->num_threads_;
MS_LOG(INFO) << "start reading model file" << filename.c_str();
std::cout << "start reading model file " << filename.c_str() << std::endl;
auto *model = mindspore::lite::Model::Import(filename.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return RET_ERROR;
TrainCfg train_cfg;
if (flags_->loss_name_ != "") {
train_cfg.loss_name_ = flags_->loss_name_;
}
session::LiteSession *session = nullptr;
session::TrainSession *t_session = nullptr;
if (train_session) {
t_session = session::TrainSession::CreateSession(model, &context);
MS_LOG(INFO) << "CreateSession from model file" << filename.c_str();
std::cout << "CreateSession from model file " << filename.c_str() << std::endl;
t_session = session::TrainSession::CreateSession(filename, &context, true, &train_cfg);
if (t_session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
delete model;
return RET_ERROR;
}
if (flags_->loss_name_ != "") {
t_session->SetLossName(flags_->loss_name_);
}
if (epochs > 0) {
t_session->Train();
}
session = t_session;
} else {
MS_LOG(INFO) << "start reading model file" << filename.c_str();
std::cout << "start reading model file " << filename.c_str() << std::endl;
auto *model = mindspore::lite::Model::Import(filename.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "create model for train session failed";
return RET_ERROR;
}
session = session::LiteSession::CreateSession(&context);
if (session == nullptr) {
MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
@@ -378,7 +380,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session
std::cout << "Run MarkPerformance error: " << status << std::endl;
return status;
}
SaveModels(t_session, model); // save file if flags are on
SaveModels(t_session); // save file if flags are on
}
if (!flags_->data_file_.empty()) {
if (t_session != nullptr) {
@@ -406,10 +408,10 @@ int NetTrain::RunNetTrain() {
return RET_OK;
}
int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model *model) {
int NetTrain::SaveModels(session::TrainSession *session) {
if (!flags_->export_file_.empty()) {
auto ret = Model::Export(model, flags_->export_file_.c_str());
if (ret != RET_OK) {
auto status = session->Export(flags_->export_file_);
if (status != RET_OK) {
MS_LOG(ERROR) << "SaveToFile error";
std::cout << "Run SaveToFile error";
return RET_ERROR;
@@ -417,7 +419,7 @@ int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model
}
if (!flags_->inference_file_.empty()) {
auto tick = GetTimeUs();
auto status = session->ExportInference(flags_->inference_file_);
auto status = session->Export(flags_->inference_file_, lite::MT_INFERENCE);
if (status != RET_OK) {
MS_LOG(ERROR) << "Save model error: " << status;
std::cout << "Save model error: " << status << std::endl;

View File

@@ -196,7 +196,7 @@ class MS_API NetTrain {
int MarkAccuracy(session::LiteSession *lite_session);
int CompareOutput(const session::LiteSession &lite_session);
int SaveModels(session::TrainSession *session, mindspore::lite::Model *model);
int SaveModels(session::TrainSession *session);
int CheckExecutionOfSavedModels();
NetTrainFlags *flags_;