unify lite session train and infer session
This commit is contained in:
parent
aaa80736a6
commit
dcb564e5f5
|
@ -29,6 +29,7 @@
|
|||
#include "include/train/ckpt_saver.h"
|
||||
#include "include/train/lr_scheduler.h"
|
||||
#include "include/train/accuracy_metrics.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "include/train/classification_train_accuracy_monitor.h"
|
||||
#include "src/utils.h"
|
||||
#include "include/dataset/datasets.h"
|
||||
|
@ -142,7 +143,7 @@ void NetRunner::InitAndFigureInputs() {
|
|||
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
|
||||
context.thread_num_ = 2;
|
||||
|
||||
session_ = mindspore::session::LiteSession::CreateTrainSession(ms_file_, &context, true);
|
||||
session_ = mindspore::session::TrainSession::CreateTrainSession(ms_file_, &context, true);
|
||||
MS_ASSERT(session_ != nullptr);
|
||||
|
||||
session_->SetupVirtualBatch(virtual_batch_);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <iostream>
|
||||
#include "include/context.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "src/utils.h"
|
||||
|
||||
static unsigned int seed = time(NULL);
|
||||
|
@ -77,7 +78,7 @@ void NetRunner::InitAndFigureInputs() {
|
|||
context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = enable_fp16_;
|
||||
context.thread_num_ = 1;
|
||||
|
||||
session_ = mindspore::session::LiteSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
|
||||
session_ = mindspore::session::TrainSession::CreateTransferSession(ms_backbone_file_, ms_head_file_, &context);
|
||||
MS_ASSERT(session_ != nullptr);
|
||||
|
||||
auto inputs = session_->GetInputs();
|
||||
|
|
|
@ -128,29 +128,6 @@ class MS_API LiteSession {
|
|||
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
|
||||
virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;
|
||||
|
||||
/// \brief Static method to create a TrainSession object
|
||||
///
|
||||
/// \param[in] filename name of flatbuffer that holds the flatbuffer
|
||||
/// \param[in] context Defines the context of the session to be created
|
||||
/// \param[in] train_mode training mode to initialize Session with
|
||||
/// \param[in] cfg training configuration, set to null for default configuration
|
||||
///
|
||||
/// \return Pointer of MindSpore LiteSession
|
||||
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
|
||||
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
|
||||
|
||||
/// \brief Static method to create a TransferSession object
|
||||
///
|
||||
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
|
||||
/// \param[in] filename_head Filename to read head net flatbuffer from
|
||||
/// \param[in] context Defines the context of the session to be created
|
||||
/// \param[in] train_mode training mode to initialize Session with
|
||||
///
|
||||
/// \return Pointer of MindSpore LiteSession
|
||||
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
|
||||
const lite::Context *context, bool train_mode = false,
|
||||
const lite::TrainCfg *cfg = nullptr);
|
||||
|
||||
/// \brief Set model to train mode
|
||||
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
|
||||
virtual int Train() { return mindspore::lite::RET_ERROR; }
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
|
||||
#include <string>
|
||||
#include "include/lite_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
class TrainSession {
|
||||
public:
|
||||
/// \brief Static method to create a TransferSession object
|
||||
///
|
||||
/// \param[in] filename_backbone Filename to read backbone net flatbuffer from
|
||||
/// \param[in] filename_head Filename to read head net flatbuffer from
|
||||
/// \param[in] context Defines the context of the session to be created
|
||||
/// \param[in] train_mode training mode to initialize Session with
|
||||
///
|
||||
/// \return Pointer of MindSpore LiteSession
|
||||
static LiteSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head,
|
||||
const lite::Context *context, bool train_mode = false,
|
||||
const lite::TrainCfg *cfg = nullptr);
|
||||
|
||||
/// \brief Static method to create a TrainSession object
|
||||
///
|
||||
/// \param[in] filename name of flatbuffer that holds the flatbuffer
|
||||
/// \param[in] context Defines the context of the session to be created
|
||||
/// \param[in] train_mode training mode to initialize Session with
|
||||
/// \param[in] cfg training configuration, set to null for default configuration
|
||||
///
|
||||
/// \return Pointer of MindSpore LiteSession
|
||||
static LiteSession *CreateTrainSession(const std::string &filename, const lite::Context *context,
|
||||
bool train_mode = false, const lite::TrainCfg *cfg = nullptr);
|
||||
};
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_SESSION_H_
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
|
@ -32,7 +32,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTra
|
|||
}
|
||||
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
|
||||
|
||||
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
|
||||
auto session = mindspore::session::TrainSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
|
||||
lite_context_ptr, train_mode, nullptr);
|
||||
if (session == nullptr) {
|
||||
MS_LOGE("CreateTrainSession failed");
|
||||
|
|
|
@ -747,8 +747,8 @@ int TrainSession::UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &featu
|
|||
}
|
||||
} // namespace lite
|
||||
|
||||
session::LiteSession *session::LiteSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
|
||||
bool train_mode, const lite::TrainCfg *cfg) {
|
||||
session::LiteSession *session::TrainSession::CreateTrainSession(const std::string &fn, const lite::Context *context,
|
||||
bool train_mode, const lite::TrainCfg *cfg) {
|
||||
auto session = std::make_unique<lite::TrainSession>();
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed";
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "src/lite_session.h"
|
||||
|
||||
/*
|
||||
|
|
|
@ -290,10 +290,10 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
|
|||
return session;
|
||||
}
|
||||
|
||||
session::LiteSession *session::LiteSession::CreateTransferSession(const std::string &filename_backbone,
|
||||
const std::string &filename_head,
|
||||
const lite::Context *ctxt, bool train_mode,
|
||||
const lite::TrainCfg *cfg) {
|
||||
session::LiteSession *session::TrainSession::CreateTransferSession(const std::string &filename_backbone,
|
||||
const std::string &filename_head,
|
||||
const lite::Context *ctxt, bool train_mode,
|
||||
const lite::TrainCfg *cfg) {
|
||||
size_t size_head = 0;
|
||||
size_t size_backbone = 0;
|
||||
std::string filename = filename_head;
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/train/train_session.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
@ -102,7 +103,7 @@ TEST_F(NetworkTest, efficient_net) {
|
|||
context->thread_num_ = 1;
|
||||
|
||||
std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
|
||||
auto session = session::LiteSession::CreateTrainSession(net, context, false);
|
||||
auto session = session::TrainSession::CreateTrainSession(net, context, false);
|
||||
ASSERT_NE(session, nullptr);
|
||||
|
||||
std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
|
||||
|
@ -150,7 +151,7 @@ TEST_F(NetworkTest, noname) {
|
|||
|
||||
lite::TrainCfg cfg;
|
||||
cfg.loss_name_ = "nhwc";
|
||||
auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &cfg);
|
||||
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &cfg);
|
||||
ASSERT_NE(session, nullptr);
|
||||
auto tensors_map = session->GetOutputs();
|
||||
auto tensor_names = session->GetOutputTensorNames();
|
||||
|
@ -170,7 +171,7 @@ TEST_F(NetworkTest, setname) {
|
|||
lite::TrainCfg train_cfg;
|
||||
train_cfg.loss_name_ = "nhwc";
|
||||
|
||||
auto session = mindspore::session::LiteSession::CreateTrainSession(net, &context, true, &train_cfg);
|
||||
auto session = mindspore::session::TrainSession::CreateTrainSession(net, &context, true, &train_cfg);
|
||||
ASSERT_NE(session, nullptr);
|
||||
|
||||
auto tensors_map = session->GetOutputs();
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "include/version.h"
|
||||
#include "include/model.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/train/train_session.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -338,7 +339,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
|
|||
MS_LOG(INFO) << "CreateTransferSession from models files" << filename << " and " << bb_filename;
|
||||
std::cout << "CreateTranferSession from model file " << filename << " and " << bb_filename << std::endl;
|
||||
session = std::unique_ptr<session::LiteSession>(
|
||||
session::LiteSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
|
||||
session::TrainSession::CreateTransferSession(bb_filename, filename, &context, true, &train_cfg));
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str();
|
||||
std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl;
|
||||
|
@ -349,7 +350,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
|
|||
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
|
||||
std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
|
||||
session = std::unique_ptr<session::LiteSession>(
|
||||
session::LiteSession::CreateTrainSession(filename, &context, true, &train_cfg));
|
||||
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg));
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str();
|
||||
std::cout << "RunNetTrain CreateTrainSession failed while running " << model_name.c_str() << std::endl;
|
||||
|
|
Loading…
Reference in New Issue