unify lite session train and infer session

This commit is contained in:
zhengjun10 2021-07-07 17:23:27 +08:00
parent aaa80736a6
commit dcb564e5f5
10 changed files with 70 additions and 38 deletions

View File

@ -29,6 +29,7 @@
#include "include/train/ckpt_saver.h"
#include "include/train/lr_scheduler.h"
#include "include/train/accuracy_metrics.h"
#include "include/train/train_session.h"
#include "include/train/classification_train_accuracy_monitor.h"
#include "src/utils.h"
#include "include/dataset/datasets.h"
@ -142,7 +143,7 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
session_ = mindspore::session::LiteSession::CreateTrainSession(ms_file_, &context, true);
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true);
MS_ASSERT(session_ != nullptr);
session_->SetupVirtualBatch(virtual_batch_);

View File

@ -25,6 +25,7 @@
#include <iostream>
#include "include/context.h"
#include "include/lite_session.h"
#include "include/train/train_session.h"
#include "src/utils.h"
static unsigned int seed = time(NULL);
@ -77,7 +78,7 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = enable_fp16_;
context.thread_num_ = 1;
session_ = mindspore::session::LiteSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
session_ = mindspore::session::TrainSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
MS_ASSERT(session_ != nullptr);
auto inputs = session_->GetInputs();

View File

@ -128,29 +128,6 @@ class MS_API LiteSession {
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;
/// \brief Static method to create a TrainSession object
///
/// \param[in] filename name of flatbuffer that holds the flatbuffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
/// \param[in] cfg training configuration, set to null for default configuration
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
/// \brief Static method to create a TransferSession object
///
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
/// \param[in] filename_head Filename to read head net flatbuffer from
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Train() { return mindspore::lite::RET_ERROR; }

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
#include <string>
#include "include/lite_session.h"
namespace mindspore {
namespace session {
class TrainSession {
public:
/// \brief Static method to create a TransferSession object
///
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
/// \param[in] filename_head Filename to read head net flatbuffer from
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
const lite::Context *context, bool train_mode = false,
const lite::TrainCfg *cfg = nullptr);
/// \brief Static method to create a TrainSession object
///
/// \param[in] filename name of flatbuffer that holds the flatbuffer
/// \param[in] context Defines the context of the session to be created
/// \param[in] train_mode training mode to initialize Session with
/// \param[in] cfg training configuration, set to null for default configuration
///
/// \return Pointer of MindSpore LiteSession
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_

View File

@ -16,7 +16,7 @@
#include <jni.h>
#include "common/ms_log.h"
#include "include/lite_session.h"
#include "include/train/train_session.h"
#include "include/train/train_cfg.h"
#include "include/errorcode.h"
@ -32,7 +32,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTra
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
auto session = mindspore::session::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
lite_context_ptr, train_mode, nullptr);
if (session == nullptr) {
MS_LOGE("CreateTrainSession failed");

View File

@ -747,8 +747,8 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu
}
} // namespace lite
session::LiteSession *session::LiteSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
bool train_mode, const lite::TrainCfg *cfg) {
auto session = std::make_unique<lite::TrainSession>();
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed";

View File

@ -22,6 +22,7 @@
#include <memory>
#include <map>
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/lite_session.h"
/*

View File

@ -290,10 +290,10 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
return session;
}
session::LiteSession *session::LiteSession::CreateTransferSession(const std::string &filename_backbone,
const std::string &filename_head,
const lite::Context *ctxt, bool train_mode,
const lite::TrainCfg *cfg) {
session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
const std::string &filename_head,
const lite::Context *ctxt, bool train_mode,
const lite::TrainCfg *cfg) {
size_t size_head = 0;
size_t size_backbone = 0;
std::string filename = filename_head;

View File

@ -28,6 +28,7 @@
#include "include/context.h"
#include "include/errorcode.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
#include "src/common/log_adapter.h"
#include "src/common/file_utils.h"
#include "src/kernel_registry.h"
@ -102,7 +103,7 @@ TEST_F(NetworkTest, efficient_net) {
context->thread_num_ = 1;
std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
auto session = session::LiteSession::CreateTrainSession(net, context, false);
auto session = session::TrainSession::CreateTrainSession(net, context, false);
ASSERT_NE(session, nullptr);
std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
@ -150,7 +151,7 @@ TEST_F(NetworkTest, noname) {
lite::TrainCfg cfg;
cfg.loss_name_ = "nhwc";
auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &cfg);
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();
auto tensor_names = session->GetOutputTensorNames();
@ -170,7 +171,7 @@ TEST_F(NetworkTest, setname) {
lite::TrainCfg train_cfg;
train_cfg.loss_name_ = "nhwc";
auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &train_cfg);
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
ASSERT_NE(session, nullptr);
auto tensors_map = session->GetOutputs();

View File

@ -30,6 +30,7 @@
#include "include/version.h"
#include "include/model.h"
#include "include/train/train_cfg.h"
#include "include/train/train_session.h"
namespace mindspore {
namespace lite {
@ -338,7 +339,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename;
std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::LiteSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl;
@ -349,7 +350,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
session = std::unique_ptr<session::LiteSession>(
session::LiteSession::CreateTrainSession(filename, &context, true, &train_cfg));
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
if (session == nullptr) {
MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;