forked from mindspore-Ecosystem/mindspore
!16387 [MS][LITE][TOD] Revert to previous TrainSession API (with additional Config Parameter)
From: @ehaleva Reviewed-by: @HilbertDavid,@hangangqiang Signed-off-by: @HilbertDavid
commit 6dd4881e63
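In short, this PR restores the file-name based TrainSession factory methods, adds an optional lite::TrainCfg parameter, and folds checkpoint/inference saving into a single session Export() call. A minimal sketch of the restored API as exercised by the examples below (the file name, thread count, and loss name are illustrative, not part of the change):

#include "include/context.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"

int main() {
  mindspore::lite::Context context;
  context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
  context.thread_num_ = 2;

  mindspore::lite::TrainCfg cfg;  // optional; passing nullptr selects the defaults
  cfg.loss_name_ = "_loss_fn";    // substring that identifies loss kernels

  // The session imports the .ms file itself; there is no separate Model::Import step.
  auto *session = mindspore::session::TrainSession::CreateSession("lenet_train.ms", &context, true, &cfg);
  if (session == nullptr) {
    return -1;
  }
  session->Train();
  // ... feed inputs and call RunGraph() in a loop ...
  session->Export("lenet_train_trained.ms");  // defaults: MT_TRAIN, QT_DEFAULT, FT_FLATBUFFER
  delete session;
  return 0;
}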
@@ -100,12 +100,7 @@ void NetRunner::InitAndFigureInputs() {
   context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
   context.thread_num_ = 2;
 
-  model_ = mindspore::lite::Model::Import(ms_file_.c_str());
-  if (model_ == nullptr) {
-    std::cout << "import model failed" << std::endl;
-    return;
-  }
-  session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
+  session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context, true);
 
   MS_ASSERT(nullptr != session_);
   loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_);
@@ -172,7 +167,7 @@ int NetRunner::TrainLoop() {
 
   mindspore::lite::LossMonitor lm(100);
   mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
-  mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
+  mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
   Rescaler rescale(255.0);
 
   loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
@@ -190,7 +185,7 @@ int NetRunner::Main() {
 
   if (epochs_ > 0) {
     auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
-    Model::Export(model_, trained_fn.c_str());
+    session_->Export(trained_fn);
   }
   return 0;
 }
@@ -27,7 +27,6 @@
 #include "include/train/accuracy_metrics.h"
 #include "include/ms_tensor.h"
 #include "include/dataset/datasets.h"
-#include "include/model.h"
 
 using mindspore::dataset::Dataset;
 using mindspore::lite::AccuracyMetrics;
@@ -37,7 +36,6 @@ class NetRunner {
   int Main();
   bool ReadArgs(int argc, char *argv[]);
   ~NetRunner();
-  mindspore::lite::Model *model_ = nullptr;
 
  private:
   void Usage();
@@ -112,6 +112,9 @@ if [ "${TARGET}" == "arm64" ]; then
   echo "========Training on Device====="
   adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh train.sh"
 
+  echo "===Evaluating trained Model====="
+  adb shell "cd /data/local/tmp/package-arm64 && /system/bin/sh eval.sh"
+  echo
 else
   cd ${PACKAGE} || exit 1
   echo "==Evaluating Untrained Model==="
@@ -120,5 +123,8 @@ else
   echo "======Training Locally========="
   ./train.sh
 
+  echo "===Evaluating trained Model====="
+  ./eval.sh
   cd ..
 fi
@@ -187,7 +187,7 @@ int NetRunner::TrainLoop() {
     if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
       auto cpkt_fn =
         ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
-      mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
+      session_->Export(cpkt_fn);
     }
 
     std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
@@ -213,12 +213,12 @@ int NetRunner::Main() {
   float acc = CalculateAccuracy(ds_.val_data(), session_);
   std::cout << "accuracy on validation data = " << acc << std::endl;
 
-  if (cycles_ > 0 && head_model_ != nullptr) {
+  if (cycles_ > 0) {
     auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
-    mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
+    session_->Export(trained_fn);
   }
   if (!save_inference_.empty()) {
-    int status = session_->ExportInference(save_inference_);
+    int status = session_->Export(save_inference_, mindspore::lite::MT_INFERENCE);
     if (status != mindspore::lite::RET_OK) {
       std::cout << "Failed to save inference file";
       return mindspore::lite::RET_ERROR;
@@ -44,8 +44,6 @@ class NetRunner {
 
   DataSet ds_;
   mindspore::session::TrainSession *session_ = nullptr;
-  mindspore::lite::Model *backbone_model_ = nullptr;
-  mindspore::lite::Model *head_model_ = nullptr;
 
   std::string ms_backbone_file_ = "";
   std::string ms_head_file_ = "";
@@ -29,8 +29,8 @@ namespace lite {
 
 class CkptSaver : public session::TrainLoopCallBack {
  public:
-  CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
-      : save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
+  CkptSaver(int save_every_n, const std::string &filename_prefix)
+      : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
 
   ~CkptSaver() = default;
@@ -38,7 +38,7 @@ class CkptSaver : public session::TrainLoopCallBack {
     if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
       auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
       remove(cpkt_fn.c_str());
-      Model::Export(model_, cpkt_fn.c_str());
+      cb_data.session_->Export(cpkt_fn);
     }
     return session::RET_CONTINUE;
   }
@@ -46,7 +46,6 @@ class CkptSaver : public session::TrainLoopCallBack {
  private:
  int save_every_n_;
  std::string filename_prefix_;
-  mindspore::lite::Model *model_ = nullptr;
 };
 
 }  // namespace lite
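Since the session now arrives through the callback data, a CkptSaver no longer needs a Model pointer. A sketch of the updated construction (the "lenet" prefix is illustrative):

// Old: mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
// Every 1000th epoch the callback saves "lenet_trained_<epoch>.ms" via cb_data.session_->Export().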
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
+#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
+#include <string>
+
+namespace mindspore {
+namespace lite {
+
+/// \brief MixPrecisionCfg defined for holding mix precision training configuration.
+class MixPrecisionCfg {
+ public:
+  MixPrecisionCfg() {
+    this->dynamic_loss_scale_ = false;
+    this->loss_scale_ = 128.0f;
+    this->keep_batchnorm_fp32_ = true;
+    this->num_of_not_nan_iter_th_ = 1000;
+  }
+  MixPrecisionCfg(const MixPrecisionCfg &rhs) {
+    this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
+    this->loss_scale_ = rhs.loss_scale_;
+    this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
+    this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
+  }
+  MixPrecisionCfg &operator=(MixPrecisionCfg const &rhs) {
+    this->dynamic_loss_scale_ = rhs.dynamic_loss_scale_;
+    this->loss_scale_ = rhs.loss_scale_;
+    this->keep_batchnorm_fp32_ = rhs.keep_batchnorm_fp32_;
+    this->num_of_not_nan_iter_th_ = rhs.num_of_not_nan_iter_th_;
+    return *this;
+  }
+  bool dynamic_loss_scale_;    /**< Enable\disable dynamic loss scale during mix precision training */
+  float loss_scale_;           /**< Initial loss scale factor */
+  bool keep_batchnorm_fp32_;   /**< Keep batch norm in FP32 while training */
+  int num_of_not_nan_iter_th_; /**< a threshold for modifying loss scale when dynamic loss scale is enabled */
+};
+
+/// \brief TrainCfg defined for holding train configuration.
+class TrainCfg {
+ public:
+  TrainCfg() { this->loss_name_ = "_loss_fn"; }
+  TrainCfg(const TrainCfg &rhs) {
+    this->loss_name_ = rhs.loss_name_;
+    this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
+  }
+  TrainCfg &operator=(const TrainCfg &rhs) {
+    this->loss_name_ = rhs.loss_name_;
+    this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
+    return *this;
+  }
+  std::string loss_name_;             /**< Set part of the name that identify a loss kernel */
+  MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
+};
+
+typedef enum {
+  FT_FLATBUFFER,  // Flatbuffer format
+  FT_MIBDIR       // MINDIR format
+} FormatType;
+
+typedef enum {
+  QT_DEFAULT,  // the quantization of the original model will apply
+  QT_NONE,     // apply no quantization
+  QT_WEIGHT    // apply weight quantization
+} QuantType;
+
+typedef enum {
+  MT_TRAIN,     // Both Train and Inference part of the compiled model are serialized
+  MT_INFERENCE  // Only the Inference part of the compiled model is serialized
+} ModelType;
+
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_CFG_H_
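The new header is self-contained; a sketch of filling in a configuration (the values shown are the constructor defaults, except the loss name, which is illustrative):

#include "include/train/train_cfg.h"

mindspore::lite::TrainCfg cfg;
cfg.loss_name_ = "my_loss";  // kernels whose name contains this string are treated as loss kernels
cfg.mix_precision_cfg_.dynamic_loss_scale_ = false;     // fixed loss scale
cfg.mix_precision_cfg_.loss_scale_ = 128.0f;            // initial loss scale factor
cfg.mix_precision_cfg_.keep_batchnorm_fp32_ = true;     // keep batch norm in FP32
cfg.mix_precision_cfg_.num_of_not_nan_iter_th_ = 1000;  // threshold used when dynamic loss scale is on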
@@ -20,6 +20,7 @@
 #include <tuple>
 #include "include/lite_session.h"
 #include "include/errorcode.h"
+#include "include/train/train_cfg.h"
 
 namespace mindspore {
 namespace session {
@@ -32,26 +33,14 @@ class TrainSession : public session::LiteSession {
 
   /// \brief Static method to create a TrainSession object
   ///
-  /// \param[in] model A buffer that was read from a MS model file
+  /// \param[in] filename name of flatbuffer that holds the flatbuffer
   /// \param[in] context Defines the context of the session to be created
   /// \param[in] train_mode training mode to initialize Session with
+  /// \param[in] cfg training configuration, set to null for default configuration
   ///
   /// \return Pointer of MindSpore Lite TrainSession
-  static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false);
-
-  /// \brief Static method to create a transfer lernning support TrainSession object
-  ///
-  /// \param[in] model_buf_backbone A buffer that was read from a backbone MS model file
-  /// \param[in] size_backbone Length of the backbone net buffer
-  /// \param[in] model_buf_head A buffer that was read from a head MS model file
-  /// \param[in] size_head Length of the head net buffer
-  /// \param[in] context Defines the context of the session to be created
-  /// \param[in] train_mode training mode to initialize Session with
-  ///
-  /// \return Pointer of MindSpore Lite TrainSession
-  static TrainSession *CreateTransferSession(const char *model_buf_backbone, size_t size_backbone,
-                                             const char *model_buf_head, size_t size_head, lite::Context *context,
-                                             bool train_mode = false);
+  static TrainSession *CreateSession(const std::string &filename, const lite::Context *context, bool train_mode = false,
+                                     const lite::TrainCfg *cfg = nullptr);
 
   /// \brief Static method to create a TrainSession object
   ///
@@ -62,7 +51,8 @@ class TrainSession : public session::LiteSession {
   ///
   /// \return Pointer of MindSpore Lite TrainSession
   static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
-                                             lite::Context *context, bool train_mode = false);
+                                             const lite::Context *context, bool train_mode = false,
+                                             const lite::TrainCfg *cfg = nullptr);
 
   /// \brief Set model to train mode
   /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
@@ -108,25 +98,21 @@ class TrainSession : public session::LiteSession {
   /// \return a vector of output tensors (MindSpore Lite MSTensor).
   virtual std::vector<tensor::MSTensor *> GetPredictions() const = 0;
 
-  /// \brief Set part of the name that identify a loss kernel
-  /// \param[in] loss_name Identifucation name for loss kernels
+  /// \brief Save model
+  /// \param[in] file_name pretrained model file name prefix. '.ms' extenension is added if does not exist
+  /// \param[in] model_type indication whether to save full model or only the inference part
+  /// \param[in] quant_type indication whether to quantize exported model
+  /// \param[in] format of exported file (currently only FT_FLATBUFFER is supported)
   /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
-  virtual int SetLossName(std::string loss_name) {
-    loss_name_ = loss_name;
-    return mindspore::lite::RET_OK;
+  virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
+                     lite::QuantType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFER) {
+    return mindspore::lite::RET_ERROR;
   }
-  /// \brief Save model for inference (LiteSession)
-  /// \param[in] fb_name pretrained model file name prefix. '.ms' is added as extension.
-  /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
-  virtual int ExportInference(std::string fb_name) { return mindspore::lite::RET_ERROR; }
 
  protected:
   bool train_mode_ = false;
-  std::string get_loss_name() const { return loss_name_; }
-
- private:
-  std::string loss_name_ = "_loss_fn";
 };
 
 }  // namespace session
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
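The single virtual Export() replaces both SetLossName() (now TrainCfg::loss_name_) and ExportInference(). A sketch of saving a full checkpoint and an inference-only model through the one entry point (file names are illustrative):

// Full checkpoint; MT_TRAIN, QT_DEFAULT and FT_FLATBUFFER are the defaults.
int status = session->Export("net_trained.ms");
// Inference-only model, the replacement for the old ExportInference() call.
if (status == mindspore::lite::RET_OK) {
  status = session->Export("net_infer.ms", mindspore::lite::MT_INFERENCE);
}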
@@ -255,11 +255,12 @@ int TrainExport::AddTransformNode() {
 
 int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
                            const std::vector<mindspore::lite::Tensor *> &tensors,
-                           const std::vector<std::string> &output_names, const Model *model) {
+                           const std::vector<std::string> &output_names, const Model *model, QuantType quant_type) {
   std::vector<size_t> map_index;
   std::set<size_t> out_set;
   int offset = meta_graph_->allTensors.size();
   int tensor_idx = offset;
+  quant_type_ = quant_type;
 
   if (meta_graph_ == nullptr) {
     int status = ExportInit(model->name_, model->version_);
@@ -23,6 +23,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/lite_kernel.h"
 #include "src/lite_model.h"
+#include "include/train/train_cfg.h"
 
 namespace mindspore {
 #ifndef _STUB
@@ -40,7 +41,7 @@ class TrainExport {
   virtual ~TrainExport();
   int ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
                 const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
-                const Model *model);
+                const Model *model, QuantType quant_type);
   int ExportInit(const std::string model_name, std::string version);
   int SaveToFile();
   void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
@@ -69,6 +70,7 @@ class TrainExport {
   bool NeedQuantization(const mindspore::lite::Tensor *tensor);
   virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor);
   mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel);
+  QuantType quant_type_;
 };
 };  // namespace lite
 }  // namespace mindspore
@@ -55,6 +55,13 @@ TrainSession::TrainSession() {
   }
 }
 
+int TrainSession::Init(const Context *context, const TrainCfg *train_cfg) {
+  if (train_cfg != nullptr) {
+    train_cfg_ = *train_cfg;
+  }
+  return lite::LiteSession::Init(context);
+}
+
 std::vector<CreatorOp> TrainSession::ReplaceOps() {
   const std::vector<CreatorOp> replace = {
     // currently no ops are Hijacked by TrainSession
@@ -421,7 +428,7 @@ bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
          kernel->type() == schema::PrimitiveType_SmoothL1LossGrad ||
          kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
          kernel->type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad) ||
-         kernel->name().find(get_loss_name()) != std::string::npos;
+         kernel->name().find(train_cfg_.loss_name_) != std::string::npos;
 }
 
 bool TrainSession::IsGradKernel(const kernel::LiteKernel *kernel) const {
@@ -442,19 +449,21 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
           (kernel->type() == schema::PrimitiveType_FusedBatchNorm));
 }
 
-int TrainSession::SetLossName(std::string loss_name) {
-  session::TrainSession::SetLossName(loss_name);
-  CompileEvalOutputs();
-  CompileInferenceKernels();
-  if (IsEval()) {
-    output_node_map_ = eval_output_node_map_;
-    output_tensor_map_ = eval_output_tensor_map_;
-    output_tensor_names_ = eval_output_tensor_names_;
-  }
-  return RET_OK;
-}
-
-int TrainSession::ExportInference(std::string file_name) {
+int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantType quant_type, FormatType format) {
+  if (format != FT_FLATBUFFER) {
+    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
+    return RET_ERROR;
+  }
+
+  if (quant_type != QT_DEFAULT) {
+    MS_LOG(ERROR) << "Currently only QuantType default is supported";
+    return RET_ERROR;
+  }
+
+  if (model_type == MT_TRAIN) {
+    return lite::Model::Export(model_, file_name.c_str());
+  }
+
   bool orig_train_state = IsTrain();
   Eval();
   TrainExport texport(file_name);
@@ -463,7 +472,8 @@ int TrainSession::ExportInference(std::string file_name) {
     MS_LOG(ERROR) << "cannot init export";
     return status;
   }
-  status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
+  status = texport.ExportNet((model_type == MT_TRAIN) ? kernels_ : inference_kernels_, tensors_, GetOutputTensorNames(),
+                             model_, quant_type);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "cannot export Network";
     return status;
@@ -479,19 +489,25 @@ int TrainSession::ExportInference(std::string file_name) {
 
 }  // namespace lite
 
-session::TrainSession *session::TrainSession::CreateSession(mindspore::lite::Model *model, lite::Context *context,
-                                                            bool train_mode) {
+session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, const lite::Context *context,
+                                                            bool train_mode, const lite::TrainCfg *cfg) {
   auto session = new (std::nothrow) lite::TrainSession();
   if (session == nullptr) {
-    delete model;
     MS_LOG(ERROR) << "create session failed";
     return nullptr;
   }
-  auto ret = session->Init(context);
+
+  auto ret = session->Init(context, cfg);
   if (ret != mindspore::lite::RET_OK) {
     MS_LOG(ERROR) << "init session failed";
     delete session;
-    delete model;
     return nullptr;
   }
+
+  auto *model = mindspore::lite::Model::Import(filename.c_str());
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "create model for train session failed";
+    delete session;
+    return nullptr;
+  }
+
@@ -54,12 +54,13 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
   int CompileGraph(lite::Model *model) override;
   virtual int CompileTrainGraph(lite::Model *model);
 
+  virtual int Init(const Context *context, const TrainCfg *train_cfg);
+
   int Train() override;
   int Eval() override;
   int SetLearningRate(float learning_rate) override;
   float GetLearningRate() override;
   int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;
-  int SetLossName(std::string loss_name) override;
 
   void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }
   std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); }
@@ -88,7 +89,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
     }
     return outputs;
   }
-  int ExportInference(std::string file_name) override;
+  int Export(const std::string &fb_name, ModelType model_type, QuantType quant_type, FormatType) override;
 
  protected:
   void AllocWorkSpace();
@@ -107,6 +108,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
   virtual void CompileEvalOutputs();
 
   Model *model_ = nullptr;
+  TrainCfg train_cfg_;
   std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_node_map_;
   std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
   std::vector<std::string> orig_output_tensor_names_;
@@ -40,7 +40,7 @@
 namespace mindspore {
 namespace lite {
 
-TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context)
+TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context)
     : is_valid_(false) {
   lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
   size_backbone_ = size_backbone;
@@ -179,10 +179,25 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
   return map;
 }
 
-int TransferSession::ExportInference(std::string file_name) {
+int TransferSession::Export(const std::string &filename, ModelType model_type, QuantType quant_type,
+                            FormatType format) {
+  if (format != FT_FLATBUFFER) {
+    MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
+    return RET_ERROR;
+  }
+
+  if (quant_type != QT_DEFAULT) {
+    MS_LOG(ERROR) << "Currently only QuantType default is supported";
+    return RET_ERROR;
+  }
+
+  if (model_type == MT_TRAIN) {
+    return TrainSession::Export(filename, model_type, quant_type, format);
+  }
+
   bool orig_train_state = IsTrain();
   Eval();
-  TrainExport texport(file_name);
+  TrainExport texport(filename);
   int status = texport.LoadModel(lite_model_, size_backbone_);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "cannot init export";
@@ -197,14 +212,14 @@ int TransferSession::ExportInference(std::string file_name) {
       return status;
     }
   }
-  status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
+  status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_, quant_type);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "cannot serialize head";
     return status;
   }
   status = texport.SaveToFile();
   if (status != RET_OK) {
-    MS_LOG(ERROR) << "failed to save to " << file_name;
+    MS_LOG(ERROR) << "failed to save to " << filename;
     return status;
   }
   if (orig_train_state) Train();
@@ -213,10 +228,10 @@ int TransferSession::ExportInference(std::string file_name) {
 
 }  // namespace lite
 
-session::TrainSession *session::TrainSession::CreateTransferSession(const char *model_buf_backbone,
-                                                                    size_t size_backbone, const char *model_buf_head,
-                                                                    size_t size_head, lite::Context *context,
-                                                                    bool train_mode) {
+static session::TrainSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
+                                                       const char *model_buf_head, size_t size_head,
+                                                       const lite::Context *context, bool train_mode,
+                                                       const lite::TrainCfg *cfg) {
   auto ValidModelSize = [](size_t size) -> bool {
     constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL;  // 1G B
     return size < MaxModelSize && size > 0;
@@ -240,7 +255,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
     return nullptr;
   }
 
-  auto ret = session->Init(context);
+  auto ret = session->Init(context, cfg);
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "init transfer session failed";
     delete session;
@@ -282,9 +297,10 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
 
 session::TrainSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
                                                                     const std::string &filename_head,
-                                                                    lite::Context *context, bool train_mode) {
-  size_t size_head = -1;
-  size_t size_backbone = -1;
+                                                                    const lite::Context *ctxt, bool train_mode,
+                                                                    const lite::TrainCfg *cfg) {
+  size_t size_head = 0;
+  size_t size_backbone = 0;
   auto buf_head = lite::ReadFileToBuf(filename_head, &size_head);
   if (buf_head == nullptr) {
     return nullptr;
@@ -293,8 +309,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const std::s
   if (buf_backbone == nullptr) {
     return nullptr;
   }
-  return session::TrainSession::CreateTransferSession(buf_backbone.get(), size_backbone, buf_head.get(), size_head,
-                                                      context, train_mode);
+  return CreateTransferSessionInt(buf_backbone.get(), size_backbone, buf_head.get(), size_head, ctxt, train_mode, cfg);
 }
 
 }  // namespace mindspore
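The transfer-learning factory follows the same pattern as CreateSession; a sketch with illustrative backbone/head file names:

mindspore::lite::Context context;
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;

auto *session = mindspore::session::TrainSession::CreateTransferSession(
    "effnet_backbone.ms", "classifier_head.ms", &context, true, /*cfg=*/nullptr);
if (session != nullptr) {
  session->Train();
  // ... fine-tune the head ...
  // Per the Export implementation above, MT_INFERENCE serializes the combined inference graph.
  session->Export("transfer_trained.ms", mindspore::lite::MT_INFERENCE);
}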
@@ -48,7 +48,7 @@ namespace lite {
 
 class TransferSession : public lite::TrainSession {
  public:
-  explicit TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context);
+  explicit TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context);
 
   ~TransferSession();
 
@@ -61,7 +61,7 @@ class TransferSession : public lite::TrainSession {
   mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;
 
   int CompileTransferGraph();
-  int ExportInference(std::string file_name) override;
+  int Export(const std::string &fb_name, ModelType model_type, QuantType quant_type, FormatType) override;
 
  protected:
   lite::LiteSession *backbone_session_ = nullptr;
@@ -42,422 +42,6 @@ class NetworkTest : public mindspore::CommonTest {
   int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out,
                  const char *tensor_name, bool debug = false);
 
-// INPUT(0)
-// V
-// +-------------+
-// | ReLU |
-// +-------------+
-// +---output(1) V
-// | V V weights(2) <----+
-// | +-------------+ |
-// | | MatMul | |
-// | +-------------+ |
-// | output(3) V |
-// | V V weights(4)<-+ |
-// | +-------------+ | |
-// | | Bias | | |
-// | +-------------+ | |
-// | output(5) V | |
-// | V V LABELS(6) | |
-// | +-------------+ | |
-// | | CrossEntropy| | |
-// | +-------------+ | |
-// | +-dy(7) V V------------------------->Loss (14)
-// | | V | |
-// | | +-------------+ | |
-// | | | BiasGrad | | |
-// | | +-------------+ | |
-// | | V db(8) | |
-// | | +--------Update---+ |
-// | +-------+ |
-// +------V V |
-// +-------------+ |
-// | MatMul | |
-// +-------------+ |
-// V dw(9) |
-// +-----------Update-----+
-TEST_F(NetworkTest, tuning_layer) {
-  const int BATCH_SIZE = 32;
-  const int NUM_CLASSES = 10;
-  const int FEATURE_SIZE = 1000;
-  auto meta_graph = std::make_shared<schema::MetaGraphT>();
-  meta_graph->name = "graph";
-  // define nodes
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {0};
-    node->outputIndex = {1};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_Activation;
-    auto primitive = new schema::ActivationT;
-    ASSERT_NE(primitive, nullptr);
-    primitive->activation_type = schema::ActivationType_RELU;
-    node->primitive->value.value = primitive;
-    node->name = "ReLU";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {1, 2};
-    node->outputIndex = {3};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_MatMul;
-    auto primitive = new schema::MatMulT;
-    ASSERT_NE(primitive, nullptr);
-    primitive->transpose_a = false;
-    primitive->transpose_b = true;
-    node->primitive->value.value = primitive;
-    node->name = "MatMul1";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {3, 4};
-    node->outputIndex = {5};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_BiasAdd;
-    auto primitive = new schema::BiasAddT;
-    ASSERT_NE(primitive, nullptr);
-    node->primitive->value.value = primitive;
-    node->name = "BiasAdd";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {5, 6};
-    node->outputIndex = {14, 7};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
-    auto primitive = new schema::SoftmaxCrossEntropyWithLogitsT;
-    ASSERT_NE(primitive, nullptr);
-    node->primitive->value.value = primitive;
-    node->name = "SoftmaxCrossEntropyWithLogits";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {7};
-    node->outputIndex = {8};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_BiasAddGrad;
-    auto primitive = new schema::BiasAddGradT;
-    ASSERT_NE(primitive, nullptr);
-    node->primitive->value.value = primitive;
-    node->name = "BiasGrad";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {7, 1};
-    node->outputIndex = {9};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_MatMul;
-    auto primitive = new schema::MatMulT;
-    ASSERT_NE(primitive, nullptr);
-    primitive->transpose_a = true;
-    primitive->transpose_b = false;
-    node->primitive->value.value = primitive;
-    node->name = "MatMul2";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {2, 10, 11, 9, 12};
-    node->outputIndex = {};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_ApplyMomentum;
-    auto primitive = new schema::ApplyMomentumT;
-    ASSERT_NE(primitive, nullptr);
-    node->primitive->value.value = primitive;
-    node->name = "Momentum";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  {
-    auto node = std::make_unique<schema::CNodeT>();
-    node->inputIndex = {4, 13, 11, 8, 12};
-    node->outputIndex = {};
-    node->primitive = std::make_unique<schema::PrimitiveT>();
-    node->primitive->value.type = schema::PrimitiveType_ApplyMomentum;
-    auto primitive = new schema::ApplyMomentumT;
-    ASSERT_NE(primitive, nullptr);
-    node->primitive->value.value = primitive;
-    node->name = "Momentum";
-    meta_graph->nodes.emplace_back(std::move(node));
-  }
-  meta_graph->inputIndex = {0, 6};
-  meta_graph->outputIndex = {5, 14};
-
-  auto input0 = std::make_unique<schema::TensorT>();
-  input0->nodeType = lite::NodeType_ValueNode;
-  input0->format = schema::Format_NHWC;
-  input0->dataType = TypeId::kNumberTypeFloat32;
-  input0->dims = {BATCH_SIZE, FEATURE_SIZE};
-  input0->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input0));
-  // tensor 1 - relu
-  auto relu_out = std::make_unique<schema::TensorT>();
-  relu_out->nodeType = lite::NodeType_Parameter;
-  relu_out->format = schema::Format_NHWC;
-  relu_out->dataType = TypeId::kNumberTypeFloat32;
-  relu_out->dims = {BATCH_SIZE, FEATURE_SIZE};
-  relu_out->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(relu_out));
-  // tensor 2 - matmul weights
-  auto weight = std::make_unique<schema::TensorT>();
-  weight->nodeType = lite::NodeType_ValueNode;
-  weight->format = schema::Format_KHWC;
-  weight->dataType = TypeId::kNumberTypeFloat32;
-  weight->dims = {NUM_CLASSES, FEATURE_SIZE};
-  size_t weight_size;
-  char *buf;
-  std::string weight_path = "./test_data/train/train_weight_10_1000.bin";
-  ReadFile(weight_path.c_str(), &weight_size, &buf);
-  ASSERT_NE(nullptr, buf);
-  weight->data.resize(weight_size);
-  std::copy(buf, buf + weight_size, weight->data.data());
-  meta_graph->allTensors.emplace_back(std::move(weight));
-  delete[] buf;
-  // tensor 3 - matmul
-  auto input3 = std::make_unique<schema::TensorT>();
-  input3->nodeType = lite::NodeType_Parameter;
-  input3->format = schema::Format_NHWC;
-  input3->dataType = TypeId::kNumberTypeFloat32;
-  input3->dims = {BATCH_SIZE, NUM_CLASSES};
-  input3->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input3));
-  // tensor 4 - fc bias
-  auto bias = std::make_unique<schema::TensorT>();
-  bias->nodeType = lite::NodeType_ValueNode;
-  bias->format = schema::Format_NHWC;
-  bias->dataType = TypeId::kNumberTypeFloat32;
-  bias->dims = {NUM_CLASSES};
-  bias->offset = -1;
-  std::string bias_path = "./test_data/train/train_bias_10.bin";
-  size_t bias_size;
-  ReadFile(bias_path.c_str(), &bias_size, &buf);
-  ASSERT_NE(nullptr, buf);
-  bias->data.resize(bias_size);
-  std::copy(buf, buf + bias_size, bias->data.data());
-  meta_graph->allTensors.emplace_back(std::move(bias));
-  delete[] buf;
-
-  // tensor 5 - bias_add
-  auto input5 = std::make_unique<schema::TensorT>();
-  input5->nodeType = lite::NodeType_Parameter;
-  input5->format = schema::Format_NHWC;
-  input5->dataType = TypeId::kNumberTypeFloat32;
-  input5->dims = {BATCH_SIZE, NUM_CLASSES};
-  input5->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input5));
-  // tensor 6 - Label
-  {
-    auto label = std::make_unique<schema::TensorT>();
-    label->nodeType = lite::NodeType_ValueNode;
-    label->format = schema::Format_NHWC;
-    label->dataType = TypeId::kNumberTypeFloat32;
-    label->dims = {BATCH_SIZE * NUM_CLASSES};
-    label->offset = -1;
-    meta_graph->allTensors.emplace_back(std::move(label));
-  }
-  // tensor 7 - Softmaxentropy
-  auto input7 = std::make_unique<schema::TensorT>();
-  input7->nodeType = lite::NodeType_Parameter;
-  input7->format = schema::Format_NHWC;
-  input7->dataType = TypeId::kNumberTypeFloat32;
-  input7->dims = {BATCH_SIZE, NUM_CLASSES};
-  input7->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input7));
-  // tensor 8 - biasGrad
-  auto input8 = std::make_unique<schema::TensorT>();
-  input8->nodeType = lite::NodeType_Parameter;
-  input8->format = schema::Format_NHWC;
-  input8->dataType = TypeId::kNumberTypeFloat32;
-  input8->dims = {NUM_CLASSES};
-  input8->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input8));
-  // tensor 9 - matmul2
-  auto input9 = std::make_unique<schema::TensorT>();
-  input9->nodeType = lite::NodeType_Parameter;
-  input9->format = schema::Format_NHWC;
-  input9->dataType = TypeId::kNumberTypeFloat32;
-  input9->dims = {NUM_CLASSES, FEATURE_SIZE};
-  input9->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input9));
-  // tensor 10 weights accumulate
-  auto input10 = std::make_unique<schema::TensorT>();
-  input10->nodeType = lite::NodeType_ValueNode;
-  input10->format = schema::Format_NHWC;
-  input10->dataType = TypeId::kNumberTypeFloat32;
-  input10->dims = {NUM_CLASSES, FEATURE_SIZE};
-  input10->offset = -1;
-  size_t input10_size = NUM_CLASSES * FEATURE_SIZE * sizeof(float);
-  input10->data.resize(input10_size);
-  std::fill(input10->data.data(), input10->data.data() + input10_size, 0.f);
-  meta_graph->allTensors.emplace_back(std::move(input10));
-  // tensor 11 - lr
-  {
-    auto lr = std::make_unique<schema::TensorT>();
-    lr->nodeType = lite::NodeType_ValueNode;
-    lr->format = schema::Format_NHWC;
-    lr->dataType = TypeId::kNumberTypeFloat32;
-    lr->dims = {1};
-    lr->offset = -1;
-    lr->data.resize(sizeof(float));
-    float *data = reinterpret_cast<float *>(lr->data.data());
-    *data = 0.01f;
-    meta_graph->allTensors.emplace_back(std::move(lr));
-  }
-  // tensor 12 - momentum
-  {
-    auto input12 = std::make_unique<schema::TensorT>();
-    input12->nodeType = lite::NodeType_ValueNode;
-    input12->format = schema::Format_NHWC;
-    input12->dataType = TypeId::kNumberTypeFloat32;
-    input12->dims = {1};
-    input12->offset = -1;
-    input12->data.resize(sizeof(float));
-    float *data = reinterpret_cast<float *>(input12->data.data());
-    *data = 0.f;
-    meta_graph->allTensors.emplace_back(std::move(input12));
-  }
-  // tensor 13 - bias accumulate
-  auto input13 = std::make_unique<schema::TensorT>();
-  input13->nodeType = lite::NodeType_ValueNode;
-  input13->format = schema::Format_NHWC;
-  input13->dataType = TypeId::kNumberTypeFloat32;
-  input13->dims = {NUM_CLASSES};
-  input13->offset = -1;
-  size_t input13_size = NUM_CLASSES * sizeof(float);
-  input13->data.resize(input13_size);
-  std::fill(input13->data.data(), input13->data.data() + input13_size, 0.f);
-  meta_graph->allTensors.emplace_back(std::move(input13));
-
-  // tensor 14 - loss
-  {
-    auto loss14 = std::make_unique<schema::TensorT>();
-    loss14->nodeType = lite::NodeType_ValueNode;
-    loss14->format = schema::Format_NHWC;
-    loss14->dataType = TypeId::kNumberTypeFloat32;
-    loss14->dims = {1};
-    loss14->offset = -1;
-    loss14->data.resize(sizeof(float));
-    float *data = reinterpret_cast<float *>(loss14->data.data());
-    *data = 0.0f;
-    meta_graph->allTensors.emplace_back(std::move(loss14));
-  }
-
-  //================================================================
-  buf = nullptr;
-
-  flatbuffers::FlatBufferBuilder builder(1024);
-  auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
-  builder.Finish(offset);
-  schema::FinishMetaGraphBuffer(builder, offset);
-  size_t size = builder.GetSize();
-  const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
-  std::cout << "build fb size= " << size << std::endl;
-
-  meta_graph.reset();
-  content = nullptr;
-
-  auto *model = mindspore::lite::Model::Import(content, size);
-  ASSERT_NE(nullptr, model);
-
-  lite::Context context;
-  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
-  context.thread_num_ = 1;
-  auto session = session::TrainSession::CreateSession(model, &context);
-  ASSERT_NE(nullptr, session);
-  session->Train();
-  session->Train();  // Just double check that calling Train twice does not cause a problem
-
-  auto inputs = session->GetInputs();
-  ASSERT_EQ(inputs.size(), 2);
-  auto inTensor = inputs.at(0);
-  ASSERT_NE(nullptr, inTensor);
-  auto data = inTensor->MutableData();
-  //===================================================
-  size_t input_size;
-  std::string input_path = "./test_data/train/train_input_32_1000.bin";
-  ReadFile(input_path.c_str(), &input_size, &buf);
-  ASSERT_NE(nullptr, buf);
-  auto input_data = reinterpret_cast<float *>(buf);
-  ASSERT_NE(nullptr, input_data);
-  //===================================================
-  ASSERT_EQ(input_size, inTensor->Size());
-  memcpy(data, input_data, input_size);
-  delete[] buf;
-  auto labelTensor = inputs.at(1);
-  ASSERT_NE(nullptr, labelTensor);
-  ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum());
-
-  auto labels = reinterpret_cast<float *>(labelTensor->MutableData());
-  std::fill(labels, labels + labelTensor->ElementsNum(), 0.f);
-  for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0;
-
-  auto ret = session->RunGraph();
-  ASSERT_EQ(lite::RET_OK, ret);
-  auto outputs = session->GetOutputsByNodeName("SoftmaxCrossEntropyWithLogits");
-  ASSERT_EQ(outputs.size(), 1);
-  auto outTensor = (outputs.at(0));
-  ASSERT_NE(nullptr, outTensor);
-  ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
-  auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
-  ASSERT_NE(nullptr, outData);
-  std::cout << "==============Initial=Loss=====================" << std::endl;
-  std::cout << outData[0] << ", " << std::endl;
-
-  session->Eval();
-  session->Eval();  // Just double check that calling eval twice does not cause a problem
-  ret = session->RunGraph();
-  outputs = session->GetOutputsByNodeName("BiasAdd");
-  ASSERT_EQ(outputs.size(), 1);
-  outTensor = (outputs.at(0));
-  ASSERT_NE(nullptr, outTensor);
-  ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
-  outData = reinterpret_cast<float *>(outTensor->MutableData());
-  ASSERT_NE(nullptr, outData);
-  std::cout << "==============Scores=after-single=train========" << std::endl;
-  for (int i = 0; i < 10; i++) {
-    std::cout << outData[i] << ", ";
-  }
-  std::cout << std::endl;
-  std::string output_path = "./test_data/train/train_output_32_10.bin";
-  auto error = RelativeOutputError(outData, output_path);
-  EXPECT_LT(error, 2e-3);
-
-  ret = session->RunGraph();
-  auto all_output_tensors = session->GetOutputs();
-  outTensor = (all_output_tensors["5"]);
-  ASSERT_NE(nullptr, outTensor);
-  ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
-  outData = reinterpret_cast<float *>(outTensor->MutableData());
-  ASSERT_NE(nullptr, outData);
-  std::cout << "==============Scores=eval-second-time==========" << std::endl;
-  for (int i = 0; i < 10; i++) {
-    std::cout << outData[i] << ", ";
-  }
-  std::cout << std::endl;
-  error = RelativeOutputError(outData, output_path);
-  EXPECT_LT(error, 2e-3);
-
-  session->Train();
-  session->Eval();  // do some more zig-zags
-  ret = session->RunGraph();
-  outTensor = session->GetOutputByTensorName("5");
-  ASSERT_NE(nullptr, outTensor);
-  ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
-  outData = reinterpret_cast<float *>(outTensor->MutableData());
-  ASSERT_NE(nullptr, outData);
-  std::cout << "==============Scores=Just Checking 3rd time====" << std::endl;
-  for (int i = 0; i < 10; i++) {
-    std::cout << outData[i] << ", ";
-  }
-  std::cout << std::endl;
-  error = RelativeOutputError(outData, output_path);
-  EXPECT_LT(error, 2e-3);
-}
-
 int32_t fileIterator(mindspore::session::TrainSession *session, const std::string &path,
                      std::function<int32_t(mindspore::session::TrainSession *session, const std::string &)> cb) {
   int32_t res = 0;
@@ -517,10 +101,7 @@ TEST_F(NetworkTest, efficient_net) {
   context->thread_num_ = 1;
 
   std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
 
-  auto *model = mindspore::lite::Model::Import(net.c_str());
-  ASSERT_NE(model, nullptr);
-  auto session = session::TrainSession::CreateSession(model, context, false);
+  auto session = session::TrainSession::CreateSession(net, context, false);
   ASSERT_NE(session, nullptr);
 
   std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
@@ -560,15 +141,13 @@ TEST_F(NetworkTest, mobileface_net) {
   delete context;
 }
 
-TEST_F(NetworkTest, setname) {
+TEST_F(NetworkTest, noname) {
   std::string net = "./test_data/nets/lenet_train.ms";
   lite::Context context;
   context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
   context.thread_num_ = 1;
 
-  auto *model = mindspore::lite::Model::Import(net.c_str());
-  ASSERT_NE(model, nullptr);
-  auto session = mindspore::session::TrainSession::CreateSession(model, &context);
+  auto session = mindspore::session::TrainSession::CreateSession(net, &context);
   ASSERT_NE(session, nullptr);
 
   auto tensors_map = session->GetOutputs();
@@ -577,20 +156,25 @@ TEST_F(NetworkTest, setname) {
   EXPECT_EQ(tensors_map.begin()->first, "24");
   EXPECT_EQ(tensor_names.size(), 1);
   EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op107");
+  delete session;
+}
 
-  auto res = session->SetLossName("nhwc");
-  EXPECT_EQ(res, RET_OK);
-  tensors_map = session->GetOutputs();
-  tensor_names = session->GetOutputTensorNames();
+TEST_F(NetworkTest, setname) {
+  std::string net = "./test_data/nets/lenet_train.ms";
+  lite::Context context;
+  context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
+  context.thread_num_ = 1;
+
+  lite::TrainCfg train_cfg;
+  train_cfg.loss_name_ = "nhwc";
+
+  auto session = mindspore::session::TrainSession::CreateSession(net, &context, true, &train_cfg);
+  ASSERT_NE(session, nullptr);
+
+  auto tensors_map = session->GetOutputs();
+  auto tensor_names = session->GetOutputTensorNames();
   EXPECT_EQ(tensors_map.begin()->first, "8");
   EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/MaxPool-op88");
 
-  res = session->SetLossName("loss");
-  EXPECT_EQ(res, RET_OK);
-  tensors_map = session->GetOutputs();
-  tensor_names = session->GetOutputTensorNames();
-  EXPECT_EQ(tensors_map.begin()->first, "24");
-  EXPECT_EQ(tensor_names.at(0), "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op107");
   delete session;
 }
@@ -317,33 +317,35 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session
   context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
   context.thread_num_ = flags_->num_threads_;
 
-  MS_LOG(INFO) << "start reading model file" << filename.c_str();
-  std::cout << "start reading model file " << filename.c_str() << std::endl;
-  auto *model = mindspore::lite::Model::Import(filename.c_str());
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "create model for train session failed";
-    return RET_ERROR;
+  TrainCfg train_cfg;
+  if (flags_->loss_name_ != "") {
+    train_cfg.loss_name_ = flags_->loss_name_;
   }
 
   session::LiteSession *session = nullptr;
   session::TrainSession *t_session = nullptr;
   if (train_session) {
-    t_session = session::TrainSession::CreateSession(model, &context);
+    MS_LOG(INFO) << "CreateSession from model file" << filename.c_str();
+    std::cout << "CreateSession from model file " << filename.c_str() << std::endl;
+    t_session = session::TrainSession::CreateSession(filename, &context, true, &train_cfg);
    if (t_session == nullptr) {
       MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str();
       std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl;
-      delete model;
       return RET_ERROR;
     }
 
-    if (flags_->loss_name_ != "") {
-      t_session->SetLossName(flags_->loss_name_);
-    }
    if (epochs > 0) {
       t_session->Train();
     }
     session = t_session;
   } else {
+    MS_LOG(INFO) << "start reading model file" << filename.c_str();
+    std::cout << "start reading model file " << filename.c_str() << std::endl;
+    auto *model = mindspore::lite::Model::Import(filename.c_str());
+    if (model == nullptr) {
+      MS_LOG(ERROR) << "create model for train session failed";
+      return RET_ERROR;
+    }
     session = session::LiteSession::CreateSession(&context);
     if (session == nullptr) {
       MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str();
@@ -378,7 +380,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, int train_session
       std::cout << "Run MarkPerformance error: " << status << std::endl;
       return status;
     }
-    SaveModels(t_session, model);  // save file if flags are on
+    SaveModels(t_session);  // save file if flags are on
   }
   if (!flags_->data_file_.empty()) {
     if (t_session != nullptr) {
@@ -406,10 +408,10 @@ int NetTrain::RunNetTrain() {
   return RET_OK;
 }
 
-int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model *model) {
+int NetTrain::SaveModels(session::TrainSession *session) {
   if (!flags_->export_file_.empty()) {
-    auto ret = Model::Export(model, flags_->export_file_.c_str());
-    if (ret != RET_OK) {
+    auto status = session->Export(flags_->export_file_);
+    if (status != RET_OK) {
       MS_LOG(ERROR) << "SaveToFile error";
       std::cout << "Run SaveToFile error";
       return RET_ERROR;
@@ -417,7 +419,7 @@ int NetTrain::SaveModels(session::TrainSession *session, mindspore::lite::Model
   }
   if (!flags_->inference_file_.empty()) {
     auto tick = GetTimeUs();
-    auto status = session->ExportInference(flags_->inference_file_);
+    auto status = session->Export(flags_->inference_file_, lite::MT_INFERENCE);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Save model error: " << status;
       std::cout << "Save model error: " << status << std::endl;
@@ -196,7 +196,7 @@ class MS_API NetTrain {
 
   int MarkAccuracy(session::LiteSession *lite_session);
   int CompareOutput(const session::LiteSession &lite_session);
-  int SaveModels(session::TrainSession *session, mindspore::lite::Model *model);
+  int SaveModels(session::TrainSession *session);
   int CheckExecutionOfSavedModels();
   NetTrainFlags *flags_;
 